tdurzynski's picture
Update app.py
6f98b16 verified
raw
history blame
9.19 kB
import os
import logging
import gradio as gr
from dotenv import load_dotenv
from langchain.document_loaders import ArxivLoader, PyPDFLoader
from langchain.text_splitter import TokenTextSplitter
from langchain.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceHubEmbeddings
from langchain.chains import RetrievalQA
from langchain.chains.summarize import load_summarize_chain
from langchain_groq import ChatGroq
from transformers import pipeline
from PyPDF2 import PdfReader
from huggingface_hub import login
from groq import AsyncGroq, Groq
import asyncio
# --- Configuration ------------------------------------------------------------
# Populate os.environ from a .env file (when present), then read both secrets.
load_dotenv()
HUGGING_API_KEY = os.getenv("HUGGING_API_KEY")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")

# Every feature below needs both services, so refuse to start without them.
if not (HUGGING_API_KEY and GROQ_API_KEY):
    raise ValueError("API keys for HuggingFace or Groq are missing. Set them in your environment variables.")

# --- Logging -------------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Remote services -------------------------------------------------------------
# Authenticate with the Hugging Face Hub using the token read above.
login(token=HUGGING_API_KEY)

# Shared model handles.
# NOTE(review): `llm` is not referenced anywhere else in this file as shown.
embedding_model = HuggingFaceHubEmbeddings(huggingfacehub_api_token=HUGGING_API_KEY)
llm = ChatGroq(temperature=0, model_name="llama3-70b-8192", api_key=GROQ_API_KEY)
def display_results(result):
    """Render a sequence of Markdown fragments as one string.

    Args:
        result: Iterable of strings, emitted in their original order.

    Returns:
        The fragments joined with single newline characters.
    """
    newline = "\n"
    return newline.join(result)
def summarize_text(text, max_tokens=8192):
    """Summarize ``text`` with a single Groq chat-completion call.

    Args:
        text: Plain text to summarize; it is embedded verbatim in the prompt.
        max_tokens: Upper bound on generated tokens. Defaults to 8192, the value
            this function always used before the parameter existed.

    Returns:
        The model's summary (``""`` if the API returned no content), or
        ``"Error in summarization."`` if the request failed; the error is logged.
    """
    messages = [
        {"role": "system", "content": "You are an excellent analyst who excels in summarization task. If I give you the whole text, you should summarize it."},
        {"role": "user", "content": f"Summarize the paper: {text}"},
    ]
    try:
        # Using the client as a context manager closes its HTTP connection pool
        # after the call instead of leaving it for the garbage collector.
        with Groq(api_key=GROQ_API_KEY) as sum_client:
            response = sum_client.chat.completions.create(
                messages=messages,
                model="llama3-70b-8192",
                temperature=0,
                max_tokens=max_tokens,
                top_p=1,
            )
        # `content` may be None; callers join the returned strings.
        return response.choices[0].message.content or ""
    except Exception as e:
        logger.exception("Error summarizing text: %s", e)
        return "Error in summarization."


def summarize_pdf(pdf_file_path, max_length):
    """Extract the text of a PDF and summarize it chunk by chunk.

    Args:
        pdf_file_path: Value produced by the ``gr.File`` input; anything
            ``PdfReader`` accepts (a path or a readable file object), or None
            when nothing was uploaded.
        max_length: Value of the "Max Length" slider, used as the completion
            token limit for each chunk summary. ``None`` keeps the 8192 default.

    Returns:
        The chunk summaries separated by blank lines, or a short message when
        there is no file, no extractable text, or processing failed.
    """
    if pdf_file_path is None:
        return "Please upload a PDF file."
    try:
        reader = PdfReader(pdf_file_path)
        text = "\n".join(page.extract_text() or "" for page in reader.pages)
        if not text.strip():
            # e.g. scanned/image-only PDFs: nothing for the model to summarize.
            return "No extractable text found in the PDF."
        # NOTE(review): chunk_size is counted with TokenTextSplitter's default
        # tokenizer, not the model's; if llama3-70b-8192 has an 8192-token window,
        # 8192 leaves little room for the prompt and completion - confirm
        # (summarize_arxiv_pdf uses 5700).
        text_splitter = TokenTextSplitter(chunk_size=8192, chunk_overlap=1000)
        chunks = text_splitter.split_text(text)
        # Gradio sliders may deliver floats; the API expects an integer limit.
        max_tokens = int(max_length) if max_length else 8192
        # Blank lines keep consecutive chunk summaries from running together.
        return "\n\n".join(summarize_text(chunk, max_tokens=max_tokens) for chunk in chunks)
    except Exception as e:
        logger.exception("Error summarizing PDF: %s", e)
        return "Failed to process the PDF."
def summarize_arxiv_pdf(query):
    """Fetch papers from arXiv and build a Markdown summary of them.

    For each paper the arXiv metadata (title, authors, link) and the
    arXiv-provided summary are listed. One "Enhanced Summary" follows at the end,
    built by summarizing every loaded text chunk with the Groq model.

    Args:
        query: arXiv number (or search text) typed by the user.

    Returns:
        Markdown text, or a short message for empty input, no results, or
        failure (failures are logged).
    """
    query = (query or "").strip()
    if not query:
        return "Please enter an arXiv number."
    try:
        # NOTE(review): with a free-text query up to 10 papers may be loaded; the
        # enhanced summary then covers all of them combined.
        loader = ArxivLoader(query=query, load_max_docs=10)
        documents = loader.load()
        if not documents:
            return f"No arXiv paper found for '{query}'."
        text_splitter = TokenTextSplitter(chunk_size=5700, chunk_overlap=100)
        chunks = text_splitter.split_documents(documents)
        # Blank lines keep consecutive chunk summaries from running together.
        ref_summary = "\n\n".join(summarize_text(chunk.page_content) for chunk in chunks)

        summaries = []
        for doc in loader.get_summaries_as_docs():
            title = doc.metadata.get("Title", "Unknown Title")
            authors = doc.metadata.get("Authors", "Unknown Authors")
            url = doc.metadata.get("Entry ID", "No URL")
            summaries.append(f"**{title}**\n")
            summaries.append(f"**Authors:** {authors}\n")
            summaries.append(f"**View full paper:** [Link to paper]({url})\n")
            summaries.append(f"**Summary:** {doc.page_content}\n")
        summaries.append(f"**Enhanced Summary:**\n {ref_summary}")
        return display_results(summaries)
    except Exception as e:
        logger.exception("Error summarizing arXiv paper: %s", e)
        return "Failed to process arXiv paper."
# Module-level asynchronous Groq client, created once at import time.
# NOTE(review): AsyncGroq holds an async HTTP connection pool. Reusing one client
# across separate asyncio.run() calls (each creates and then closes its own event
# loop) can fail on pooled connections tied to a closed loop - confirm before
# relying on this shared instance from code driven by asyncio.run().
client = AsyncGroq(api_key=GROQ_API_KEY)
async def chat_with_replit(message, history):
    """Answer a chat message with the Groq model, replaying earlier turns.

    Args:
        message: The new user message.
        history: Earlier turns as ``(user_msg, assistant_msg)`` pairs, oldest
            first. ``None`` is treated as "no history".

    Returns:
        The assistant's reply (``""`` if the API returned no content), or
        ``"Error in chat response."`` if anything failed; the error is logged.
    """
    try:
        messages = [{"role": "system", "content": "You are an assistant answering user questions."}]
        for user_msg, assistant_msg in history or []:
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": assistant_msg})
        messages.append({"role": "user", "content": message})
        # Open a client per call and close it afterwards. In this file the
        # coroutine is driven through asyncio.run() (chat_with_replit_sync), which
        # creates and closes a new event loop each time; a long-lived async
        # client would try to reuse pooled connections from an already-closed loop.
        async with AsyncGroq(api_key=GROQ_API_KEY) as chat_client:
            response = await chat_client.chat.completions.create(
                messages=messages,
                model="llama3-70b-8192",
                temperature=0,
                max_tokens=1024,
                top_p=1,
                stream=False,  # Non-streaming keeps this integration simple.
            )
        return response.choices[0].message.content or ""
    except Exception as e:
        logger.exception("Chat error: %s", e)
        return "Error in chat response."
async def chat_with_replit_pdf(message, history, doi_num):
    """Answer a question using passages retrieved from an arXiv paper.

    The paper is fetched with ArxivLoader and split. The chunks are embedded
    into a temporary Chroma collection, and the 3 chunks most similar to
    ``message`` are given to the model together with the first chunk's metadata.

    Args:
        message: The user's question.
        history: Earlier turns as ``(user_msg, assistant_msg)`` pairs, oldest
            first. ``None`` is treated as "no history".
        doi_num: arXiv identifier (or search text) passed to ArxivLoader.

    Returns:
        The model's answer, or a short message when nothing was found or an
        error occurred (errors are logged).
    """
    import uuid  # only needed here, for per-call collection names

    vector_store = None
    try:
        paper_id = str(doi_num).strip()
        loader = ArxivLoader(query=paper_id, load_max_docs=10)
        documents = loader.load_and_split()
        if not documents:
            return f"No arXiv paper found for '{paper_id}'."
        metadata = documents[0].metadata

        # A unique collection per call: Chroma.from_documents otherwise writes to
        # its default collection, which the in-process client can share across
        # calls, mixing chunks from earlier requests into the search results.
        vector_store = Chroma.from_documents(
            documents,
            embedding_model,
            collection_name=f"arxiv-{uuid.uuid4().hex}",
        )
        results = vector_store.similarity_search(message, k=3)
        relevant_content = "\n\n".join(doc.page_content for doc in results)

        # System instructions first, then earlier turns, then the new question.
        messages = [
            {"role": "system", "content": f"Answer based on this arXiv paper {paper_id}.\n"
                                          f"Metadata: {metadata}.\n"
                                          f"Relevant Content: {relevant_content}"},
        ]
        for user_msg, assistant_msg in history or []:
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": assistant_msg})
        messages.append({"role": "user", "content": message})

        # Per-call client so no pooled connection outlives the event loop
        # that runs this coroutine.
        async with AsyncGroq(api_key=GROQ_API_KEY) as chat_client:
            response = await chat_client.chat.completions.create(
                messages=messages,
                model="llama3-70b-8192",
                temperature=0,
                max_tokens=1024,
                top_p=1,
                stream=False,
            )
        return response.choices[0].message.content or ""
    except Exception as e:
        logger.exception("Error in chat with PDF: %s", e)
        return "Error processing chat with PDF."
    finally:
        # Drop the temporary collection so repeated calls don't accumulate data.
        if vector_store is not None:
            try:
                vector_store.delete_collection()
            except Exception as cleanup_error:
                logger.warning("Could not delete temporary Chroma collection: %s", cleanup_error)
def chat_with_replit_sync(message, history):
    """Blocking entry point for the async ``chat_with_replit`` coroutine.

    Builds the coroutine and runs it to completion on a fresh event loop with
    ``asyncio.run``, returning its result. ``asyncio.run`` cannot be used from a
    thread that already has a running event loop.
    """
    pending_reply = chat_with_replit(message, history)
    return asyncio.run(pending_reply)
# Gradio UI
with gr.Blocks() as app:
    # Tab for Local PDF Summarization
    with gr.Tab(label="Local PDF Summarization"):
        with gr.Row():
            input_pdf = gr.File(label="Upload PDF file")
            max_length_slider = gr.Slider(512, 4096, value=2048, step=512, label="Max Length")
            summarize_pdf_btn = gr.Button(value="Summarize PDF")
        with gr.Row():
            output_pdf_summary = gr.Markdown(label="Summary", height=1000)
        summarize_pdf_btn.click(summarize_pdf, inputs=[input_pdf, max_length_slider], outputs=output_pdf_summary)

    # Tab for Arxiv Summarization
    with gr.Tab(label="Arxiv Summarization"):
        with gr.Column():
            arxiv_number = gr.Textbox(label="Enter arXiv number, i.e 2502.02523")
            summarize_btn = gr.Button(value="Summarize arXiv Paper")
        with gr.Column():
            output_summary = gr.Markdown(label="Summary", height=1000)
        summarize_btn.click(summarize_arxiv_pdf, inputs=arxiv_number, outputs=output_summary)

    # Tab for Chat functionality
    with gr.Tab(label="Chat with Assistant"):
        gr.Markdown("### Chat with the Assistant")
        with gr.Row():
            chat_input = gr.Textbox(placeholder="Type your message here...", label="Your Message")
            send_button = gr.Button("Send")
        # Markdown transcript of the conversation (a gr.Chatbot would also work).
        chat_output = gr.Markdown(label="Chat Output", height=300)
        # Per-session history: list of [user, assistant] pairs. An assistant
        # value of None marks a turn that is still waiting for its reply.
        chat_history = gr.State([])

        def format_history(history):
            """Render [user, assistant] pairs as a Markdown transcript string."""
            return "\n\n".join(f"**User:** {u}\n\n**Assistant:** {a or ''}" for u, a in history)

        def update_chat(user_message, history):
            """Record the user's message as a pending turn and clear the textbox."""
            history = history or []
            if not user_message or not user_message.strip():
                # Nothing to send: leave state, transcript and textbox unchanged.
                return history, format_history(history), user_message
            history = history + [[user_message, None]]
            # Markdown needs a string, not the raw history list.
            return history, format_history(history), ""

        def update_assistant_response(history):
            """Fill in the reply for the pending turn, if there is one."""
            history = history or []
            if not history or history[-1][1] is not None:
                # No pending turn (e.g. a blank message was ignored).
                return history, format_history(history)
            user_message = history[-1][0]
            response = chat_with_replit_sync(user_message, history[:-1])
            history[-1][1] = response if response is not None else ""
            return history, format_history(history)

        # Chain with .then(): two separate .click() listeners on the same button
        # are independent events, so the reply step could run before the user's
        # turn had been stored.
        send_button.click(
            update_chat,
            inputs=[chat_input, chat_history],
            outputs=[chat_history, chat_output, chat_input],
        ).then(
            update_assistant_response,
            inputs=chat_history,
            outputs=[chat_history, chat_output],
        )

app.launch()