Spaces:
Runtime error
Runtime error
| import os | |
| from dataclasses import asdict | |
| import pandas as pd | |
| import wandb | |
| from langchain.callbacks import get_openai_callback | |
| from langchain.chains.summarize import load_summarize_chain | |
| from langchain.chat_models import ChatOpenAI | |
| from langchain.document_loaders import DataFrameLoader | |
| from langchain.prompts import PromptTemplate | |
| from langchain.text_splitter import TokenTextSplitter | |
| from tqdm import tqdm | |
| from wandb.integration.langchain import WandbTracer | |
| from config import config | |
def get_data(
    artifact_name: str = "gladiator/gradient_dissent_bot/yt_podcast_data:latest",
    total_episodes: int = None,
):
    """Fetch the scraped YouTube podcast dataset from a W&B artifact.

    The artifact is downloaded into ``config.root_data_dir`` and its
    ``yt_data.csv`` file is read into a DataFrame.

    Args:
        artifact_name: W&B artifact reference to use (type ``"dataset"``).
        total_episodes: If given, keep only the first ``total_episodes`` rows;
            ``None`` keeps every row.

    Returns:
        The loaded, optionally truncated, DataFrame.
    """
    artifact = wandb.use_artifact(artifact_name, type="dataset")
    download_dir = artifact.download(config.root_data_dir)
    csv_path = os.path.join(download_dir, "yt_data.csv")
    episodes = pd.read_csv(csv_path)
    # `iloc` slices by position, so this keeps the first N rows in file order.
    return episodes if total_episodes is None else episodes.iloc[:total_episodes]
# Prompt applied to every transcript chunk in the "map" step of the chain.
MAP_PROMPT = """Write a concise summary of the following short transcript from a podcast.
Don't add your opinions or interpretations.
{text}
CONCISE SUMMARY:"""

# Prompt that merges the per-chunk summaries in the "combine" (reduce) step.
COMBINE_PROMPT = """You have been provided with summaries of chunks of transcripts from a podcast.
Your task is to merge these intermediate summaries to create a brief and comprehensive summary of the entire podcast.
The summary should encompass all the crucial points of the podcast.
Ensure that the summary is atleast 2 paragraph long and effectively captures the essence of the podcast.
{text}
SUMMARY:"""


def summarize_episode(
    episode_df: pd.DataFrame,
    *,
    model_name: str = "gpt-3.5-turbo",
    temperature: float = 0,
    chunk_size: int = 1000,
    chunk_overlap: int = 0,
):
    """Summarize a podcast episode with a map-reduce LangChain summarize chain.

    The transcript is split into token-sized chunks. Each chunk is summarized
    with ``MAP_PROMPT``, and the chunk summaries are then merged with
    ``COMBINE_PROMPT``.

    Args:
        episode_df: Episode rows to summarize (the caller passes a single row).
            Must contain a ``transcript`` column. A ``title`` column, if
            present, is only used in the progress message.
        model_name: OpenAI chat model name passed to ``ChatOpenAI``.
        temperature: Sampling temperature passed to ``ChatOpenAI``.
        chunk_size: Maximum number of tokens per transcript chunk.
        chunk_overlap: Number of tokens shared by consecutive chunks.

    Returns:
        The chain's output mapping. The final summary is under
        ``"output_text"``. Because ``return_intermediate_steps=True``, the
        per-chunk summaries are presumably included too (verify key name).

    Raises:
        ValueError: If ``episode_df`` has no rows.
    """
    # Fail fast rather than with a bare IndexError on `data[0]` below,
    # and before any LLM call is made.
    if episode_df.empty:
        raise ValueError("episode_df is empty; nothing to summarize")

    # Load rows into LangChain documents; the transcript becomes page content.
    loader = DataFrameLoader(episode_df, page_content_column="transcript")
    data = loader.load()

    # Split into token-bounded chunks so each map call fits the model context.
    text_splitter = TokenTextSplitter.from_tiktoken_encoder(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    docs = text_splitter.split_documents(data)
    title = data[0].metadata.get("title", "<untitled>")
    print(f"Number of documents for podcast {title}: {len(docs)}")

    llm = ChatOpenAI(model_name=model_name, temperature=temperature)
    map_prompt_template = PromptTemplate(template=MAP_PROMPT, input_variables=["text"])
    combine_prompt_template = PromptTemplate(
        template=COMBINE_PROMPT, input_variables=["text"]
    )

    chain = load_summarize_chain(
        llm,
        chain_type="map_reduce",
        return_intermediate_steps=True,
        map_prompt=map_prompt_template,
        combine_prompt=combine_prompt_template,
    )
    return chain({"input_documents": docs})
if __name__ == "__main__":
    # Initialise the W&B LangChain tracer; the wandb.log / log_artifact calls
    # below rely on the run it starts.
    WandbTracer.init(
        {
            "project": "gradient_dissent_bot",
            "name": "summarize_3",
            "job_type": "summarize",
            "config": asdict(config),
        }
    )

    # Load the first three scraped episodes.
    df = get_data(artifact_name=config.yt_podcast_data_artifact, total_episodes=3)

    # Summarize every episode while the callback tallies OpenAI token usage/cost.
    with get_openai_callback() as cb:
        rows = tqdm(df.iterrows(), total=len(df), desc="Summarizing episodes")
        # Each row is turned back into a one-row DataFrame for the summarizer.
        summaries = [
            summarize_episode(row.to_frame().T)["output_text"] for _, row in rows
        ]

        banner = "*" * 25
        print(banner)
        print(cb)
        print(banner)

        usage = {
            "total_prompt_tokens": cb.prompt_tokens,
            "total_completion_tokens": cb.completion_tokens,
            "total_tokens": cb.total_tokens,
            "total_cost": cb.total_cost,
        }
        wandb.log(usage)

    df["summary"] = summaries

    # Save the summaries to CSV and version the file as a W&B dataset artifact.
    path_to_save = os.path.join(config.root_data_dir, "summary_data.csv")
    df.to_csv(path_to_save)
    summary_artifact = wandb.Artifact("summary_data", type="dataset")
    summary_artifact.add_file(path_to_save)
    wandb.log_artifact(summary_artifact)

    # Log the same data as a W&B table.
    wandb.log({"summary_data": wandb.Table(dataframe=df)})

    WandbTracer.finish()