# Source: app.py as shown on a Hugging Face repository page (raw / history / blame view),
# commit 0bfd27d "Create app.py" (verified) by tdurzynski, 7.31 kB.
import os
import logging
import gradio as gr
from dotenv import load_dotenv
from langchain.document_loaders import ArxivLoader, PyPDFLoader
from langchain.text_splitter import TokenTextSplitter
from langchain.vectorstores import Chroma
from langchain.embeddings.huggingface_hub import HuggingFaceHubEmbeddings
from langchain.chains import RetrievalQA
from langchain.chains.summarize import load_summarize_chain
from langchain_groq import ChatGroq
from transformers import pipeline
from PyPDF2 import PdfReader
from huggingface_hub import login
from groq import AsyncGroq, Groq
# Read configuration: load_dotenv() runs first so the lookups below can see its values.
load_dotenv()
HUGGING_API_KEY = os.environ.get("HUGGING_API_KEY")
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")

# Both credentials are mandatory; refuse to start when either one is absent or empty.
if not (HUGGING_API_KEY and GROQ_API_KEY):
    raise ValueError("API keys for HuggingFace or Groq are missing. Set them in your environment variables.")

# Module-wide logger, INFO level and above.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Log in to the Hugging Face Hub with the configured token.
login(HUGGING_API_KEY)

# Shared model handles: the embedding model used for retrieval and a Groq-hosted chat LLM.
embedding_model = HuggingFaceHubEmbeddings(huggingfacehub_api_token=HUGGING_API_KEY)
llm = ChatGroq(temperature=0, model_name="llama3-70b-8192", api_key=GROQ_API_KEY)
def display_results(result):
    """Join the given text fragments into one string, separated by newlines."""
    line_break = "\n"
    return line_break.join(result)
def summarize_text(text):
    """Request a summary of ``text`` from the Groq chat-completion API.

    Returns the content of the first completion choice. Any exception is
    logged and replaced by the fixed string "Error in summarization.".
    """
    try:
        system_prompt = {
            "role": "system",
            "content": "You are a summarizer. If I give you the whole text, you should summarize it.",
        }
        user_prompt = {"role": "user", "content": f"Summarize the paper: {text}"}
        # A fresh synchronous client is created for every call.
        completion = Groq(api_key=GROQ_API_KEY).chat.completions.create(
            messages=[system_prompt, user_prompt],
            model="llama3-70b-8192",
            temperature=0,
            max_tokens=8192,
            top_p=1,
        )
        first_choice = completion.choices[0]
        return first_choice.message.content
    except Exception as err:
        logger.error(f"Error summarizing text: {err}")
        return "Error in summarization."
def summarize_pdf(pdf_file_path, max_length):
    """Extract the text of a PDF and summarize it chunk by chunk.

    Args:
        pdf_file_path: The uploaded PDF, in whatever form ``PdfReader``
            accepts. ``None`` (nothing uploaded) is rejected up front.
        max_length: Value of the UI "Max Length" slider. Accepted so the
            Gradio wiring keeps working; it is currently not used.

    Returns:
        The per-chunk summaries joined by blank lines, or a short message
        when there is no input, no extractable text, or processing fails.
    """
    if pdf_file_path is None:
        return "Please upload a PDF file."
    try:
        reader = PdfReader(pdf_file_path)
        # extract_text() may return None for pages without a text layer.
        text = "\n".join(page.extract_text() or "" for page in reader.pages)
        if not text.strip():
            return "No extractable text found in the PDF."
        # The summarizer model id advertises an 8192-token window, so a chunk
        # of 8192 tokens leaves no room for the prompt or the reply. Use the
        # same 5700/100 budget as the arXiv summarizer in this file.
        # NOTE(review): the splitter's tokenizer may differ from the model's,
        # so these counts are approximate.
        text_splitter = TokenTextSplitter(chunk_size=5700, chunk_overlap=100)
        chunks = text_splitter.split_text(text)
        # Blank-line separators keep consecutive chunk summaries from running
        # together in the Markdown output.
        return "\n\n".join(summarize_text(chunk) for chunk in chunks)
    except Exception:
        # logger.exception records the traceback along with the message.
        logger.exception("Error summarizing PDF")
        return "Failed to process the PDF."
def summarize_arxiv_pdf(query):
    """Summarize the arXiv paper(s) matching ``query`` as Markdown.

    Args:
        query: arXiv identifier or search text entered by the user.

    Returns:
        Markdown with title, authors, link and abstract for each matching
        paper, followed by one LLM-generated "Enhanced Summary" built from
        the full text of all loaded documents. A short message is returned
        for blank input, when nothing is found, or when processing fails.
    """
    if query is None or not str(query).strip():
        return "Please enter an arXiv number."
    query = str(query).strip()
    try:
        loader = ArxivLoader(query=query, load_max_docs=10)
        documents = loader.load()
        if not documents:
            return f"No arXiv paper found for '{query}'."

        text_splitter = TokenTextSplitter(chunk_size=5700, chunk_overlap=100)
        chunks = text_splitter.split_documents(documents)
        # Blank-line separators keep consecutive chunk summaries from running
        # together in the Markdown output.
        ref_summary = "\n\n".join(
            summarize_text(chunk.page_content) for chunk in chunks
        )

        summaries = []
        for doc in loader.get_summaries_as_docs():
            title = doc.metadata.get("Title", "Unknown Title")
            authors = doc.metadata.get("Authors", "Unknown Authors")
            url = doc.metadata.get("Entry ID", "No URL")
            summaries.extend(
                [
                    f"**{title}**\n",
                    f"**Authors:** {authors}\n",
                    f"**View full paper:** [Link to paper]({url})\n",
                    f"**Summary:** {doc.page_content}\n",
                ]
            )
        # ref_summary covers every loaded document, so it is emitted once,
        # after the per-paper entries.
        summaries.append(f"**Enhanced Summary:**\n {ref_summary}")
        return display_results(summaries)
    except Exception:
        # logger.exception records the traceback along with the message.
        logger.exception("Error summarizing arXiv paper")
        return "Failed to process arXiv paper."
# Module-level asynchronous Groq client; the async chat functions in this file await its
# chat.completions.create() calls.
client = AsyncGroq(api_key=GROQ_API_KEY)
async def chat_with_replit(message, history):
    """Stream a Groq chat reply to ``message`` given the prior conversation.

    Args:
        message: The latest user message.
        history: Prior turns. Each entry is either a ``(user, assistant)``
            pair or a dict with ``role`` and ``content`` keys. ``None`` is
            treated as an empty history.

    Yields:
        The accumulated reply text each time new content arrives. On any
        failure the error is logged and "Error in chat response." is yielded.
    """
    try:
        messages = [{"role": "system", "content": "You are an assistant answering user questions."}]
        for turn in history or []:
            if isinstance(turn, dict):
                # Dict-style history: forward plain-text entries unchanged.
                role = turn.get("role")
                content = turn.get("content")
                if role and isinstance(content, str) and content:
                    messages.append({"role": role, "content": content})
                continue
            user, assistant = turn
            # Skip empty halves (e.g. a turn whose reply is still None) so no
            # message with missing content is sent to the API.
            if user:
                messages.append({"role": "user", "content": user})
            if assistant:
                messages.append({"role": "assistant", "content": assistant})
        messages.append({"role": "user", "content": message})

        stream = await client.chat.completions.create(
            messages=messages,
            model="llama3-70b-8192",
            temperature=0,
            max_tokens=1024,
            top_p=1,
            stream=True,
        )
        response_content = ""
        async for chunk in stream:
            # Guard against chunks that carry no choices at all.
            if not chunk.choices:
                continue
            piece = chunk.choices[0].delta.content
            if piece:
                response_content += piece
                yield response_content
    except Exception:
        # logger.exception records the traceback along with the message.
        logger.exception("Chat error")
        yield "Error in chat response."
async def chat_with_replit_pdf(message, history, doi_num):
    """Answer ``message`` about an arXiv paper using document retrieval.

    The paper is loaded and split, indexed in a temporary Chroma collection,
    and the three chunks most similar to ``message`` are passed to the model
    as context.

    Args:
        message: The user's question.
        history: Prior chat turns. Accepted for interface compatibility; it
            is not forwarded to the model.
        doi_num: arXiv identifier (any value convertible with ``str``).

    Returns:
        The model's answer, or a short message for blank input, when no
        paper is found, or when processing fails.
    """
    query = "" if doi_num is None else str(doi_num).strip()
    if not query:
        return "Please provide an arXiv number."

    import uuid  # local import: only needed to name the temporary collection

    vector_store = None
    try:
        loader = ArxivLoader(query=query, load_max_docs=10)
        documents = loader.load_and_split()
        if not documents:
            return f"No arXiv paper found for '{query}'."
        metadata = documents[0].metadata

        # Use a unique collection per call. With Chroma's default collection
        # name, documents from earlier calls in the same process can
        # accumulate and leak into later retrievals.
        vector_store = Chroma.from_documents(
            documents,
            embedding_model,
            collection_name=f"arxiv-{uuid.uuid4().hex}",
        )
        results = vector_store.similarity_search(message, k=3)
        relevant_content = "\n\n".join(doc.page_content for doc in results)

        # The system instruction goes first so the user question is the
        # final turn the model responds to.
        messages = [
            {"role": "system", "content": f"Answer based on this arXiv paper {doi_num}.\n"
                                          f"Metadata: {metadata}.\n"
                                          f"Relevant Content: {relevant_content}"},
            {"role": "user", "content": message},
        ]
        response = await client.chat.completions.create(
            messages=messages,
            model="llama3-70b-8192",
            temperature=0,
            max_tokens=1024,
            top_p=1,
            stream=False,
        )
        return response.choices[0].message.content
    except Exception:
        # logger.exception records the traceback along with the message.
        logger.exception("Error in chat with PDF")
        return "Error processing chat with PDF."
    finally:
        # Best-effort cleanup of the temporary collection; a failure here
        # must not mask the answer or the original error.
        if vector_store is not None:
            try:
                vector_store.delete_collection()
            except Exception:
                logger.warning("Could not delete temporary Chroma collection", exc_info=True)
# Gradio UI: two tabs, one per summarization workflow.
# NOTE(review): the nesting below was reconstructed from a source whose indentation
# had been flattened — confirm the intended layout.
with gr.Blocks() as app:
    # Tab 1: summarize a paper looked up on arXiv.
    with gr.Tab(label="Arxiv Summarization"):
        with gr.Column():
            arxiv_number = gr.Textbox(label="Enter arXiv number")
            summarize_btn = gr.Button(value="Summarize arXiv Paper")
        with gr.Column():
            output_summary = gr.Markdown(label="Summary", height=1000)
        # The textbox value is passed to summarize_arxiv_pdf; its return value fills the Markdown pane.
        summarize_btn.click(summarize_arxiv_pdf, inputs=arxiv_number, outputs=output_summary)
    # Tab 2: summarize an uploaded PDF.
    with gr.Tab(label="Local PDF Summarization"):
        with gr.Row():
            input_pdf = gr.File(label="Upload PDF file")
            max_length_slider = gr.Slider(512, 4096, value=2048, step=512, label="Max Length")
            summarize_pdf_btn = gr.Button(value="Summarize PDF")
        with gr.Row():
            output_pdf_summary = gr.Markdown(label="Summary", height=1000)
        # The uploaded file and the slider value are passed, in that order, to summarize_pdf.
        summarize_pdf_btn.click(summarize_pdf, inputs=[input_pdf, max_length_slider], outputs=output_pdf_summary)
# Launch the app when this module is executed.
app.launch()