# Hugging Face Space status banner captured with this file: "Sleeping"
import os | |
import logging | |
import gradio as gr | |
from dotenv import load_dotenv | |
from langchain.document_loaders import ArxivLoader, PyPDFLoader | |
from langchain.text_splitter import TokenTextSplitter | |
from langchain.vectorstores import Chroma | |
from langchain_community.embeddings import HuggingFaceHubEmbeddings | |
from langchain.chains import RetrievalQA | |
from langchain.chains.summarize import load_summarize_chain | |
from langchain_groq import ChatGroq | |
from transformers import pipeline | |
from PyPDF2 import PdfReader | |
from huggingface_hub import login | |
from groq import AsyncGroq, Groq | |
import asyncio | |
# Pull variables from a local .env file (when one exists) into os.environ.
load_dotenv()
HUGGING_API_KEY = os.environ.get("HUGGING_API_KEY")
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")

# Refuse to start unless both credentials are present and non-empty.
if not (HUGGING_API_KEY and GROQ_API_KEY):
    raise ValueError("API keys for HuggingFace or Groq are missing. Set them in your environment variables.")

# Module-wide logging at INFO level.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Authenticate against the Hugging Face Hub, then build the shared embedding
# model and the Groq-backed LangChain chat model.
login(token=HUGGING_API_KEY)
embedding_model = HuggingFaceHubEmbeddings(huggingfacehub_api_token=HUGGING_API_KEY)
llm = ChatGroq(temperature=0, model_name="llama3-70b-8192", api_key=GROQ_API_KEY)
def display_results(result):
    """Combine formatted result fragments into a single string.

    Args:
        result: Iterable of strings (e.g. Markdown fragments).

    Returns:
        The fragments joined with a newline between each pair.
    """
    separator = "\n"
    return separator.join(result)
def summarize_text(text):
    """Request a summary of *text* from the Groq chat-completions API.

    Args:
        text: The text to summarize; it is embedded in the user prompt.

    Returns:
        The model's reply content, or "Error in summarization." when any
        exception occurs (the exception is logged).
    """
    system_prompt = "You are an excellent analyst who excels in summarization task. If I give you the whole text, you should summarize it."
    try:
        groq_client = Groq(api_key=GROQ_API_KEY)
        completion = groq_client.chat.completions.create(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": f"Summarize the paper: {text}"},
            ],
            model="llama3-70b-8192",
            temperature=0,
            # NOTE(review): 8192 presumably equals the model's whole context
            # window (per the model name); confirm how the API treats
            # prompt + max_tokens above that limit.
            max_tokens=8192,
            top_p=1,
        )
        first_choice = completion.choices[0]
        return first_choice.message.content
    except Exception as err:
        logger.error(f"Error summarizing text: {err}")
        return "Error in summarization."
def summarize_pdf(pdf_file_path, max_length):
    """Extract the text of a PDF and summarize it chunk by chunk.

    Args:
        pdf_file_path: Path (or file object) accepted by ``PdfReader``; it is
            supplied by the Gradio ``File`` upload component.
        max_length: Value of the "Max Length" slider. NOTE(review): it is not
            used by the summarization logic yet; the parameter is kept so the
            Gradio handler signature stays unchanged.

    Returns:
        The per-chunk summaries separated by blank lines. Returns a short
        user-facing message instead when no file was given or no text could be
        extracted. Returns "Failed to process the PDF." on any error, which is
        logged with its traceback.
    """
    if not pdf_file_path:
        return "Please upload a PDF file."
    try:
        reader = PdfReader(pdf_file_path)
        text = "\n".join(page.extract_text() or "" for page in reader.pages)
        if not text.strip():
            # Nothing extractable (e.g. image-only pages); don't call the model
            # with an empty prompt.
            return "No extractable text found in the PDF."
        # Use the same sizing as summarize_arxiv_pdf. An 8192-token chunk could
        # fill the model's whole window before the prompt and reply are added.
        text_splitter = TokenTextSplitter(chunk_size=5700, chunk_overlap=100)
        chunks = text_splitter.split_text(text)
        # Keep chunk summaries visually separate instead of running them together.
        return "\n\n".join(summarize_text(chunk) for chunk in chunks)
    except Exception:
        logger.exception("Error summarizing PDF")
        return "Failed to process the PDF."
def summarize_arxiv_pdf(query):
    """Summarize the arXiv paper(s) returned for *query*.

    *query* is passed to ``ArxivLoader``; the UI asks for an arXiv number such
    as ``2502.02523``. Up to 10 matching papers are loaded and processed as
    follows:

    - Their full text is split into 5700-token chunks.
    - Each chunk is summarized with ``summarize_text``; together these form the
      "Enhanced Summary".
    - Title, authors, link and summary for each paper come from
      ``loader.get_summaries_as_docs()``.

    Args:
        query: arXiv number or search string for ``ArxivLoader``.

    Returns:
        A Markdown string, or a "not found" message when the loader returns no
        documents. Returns "Failed to process arXiv paper." on any error, which
        is logged with its traceback.
    """
    try:
        loader = ArxivLoader(query=query, load_max_docs=10)
        documents = loader.load()
        if not documents:
            return f"No arXiv paper found for '{query}'."

        text_splitter = TokenTextSplitter(chunk_size=5700, chunk_overlap=100)
        chunks = text_splitter.split_documents(documents)
        # Separate the per-chunk summaries so they don't run together.
        ref_summary = "\n\n".join(summarize_text(chunk.page_content) for chunk in chunks)

        summaries = []
        for doc in loader.get_summaries_as_docs():
            title = doc.metadata.get("Title", "Unknown Title")
            authors = doc.metadata.get("Authors", "Unknown Authors")
            url = doc.metadata.get("Entry ID", "No URL")
            summaries.append(f"**{title}**\n")
            summaries.append(f"**Authors:** {authors}\n")
            summaries.append(f"**View full paper:** [Link to paper]({url})\n")
            summaries.append(f"**Summary:** {doc.page_content}\n")
        # ref_summary is built from all loaded papers together, so emit it once
        # after the per-paper entries rather than repeating it for each paper.
        summaries.append(f"**Enhanced Summary:**\n {ref_summary}")
        return display_results(summaries)
    except Exception:
        logger.exception("Error summarizing arXiv paper")
        return "Failed to process arXiv paper."
# Shared async Groq client (used by chat_with_replit_pdf).
client = AsyncGroq(api_key=GROQ_API_KEY)


async def chat_with_replit(message, history):
    """Answer *message* with the Groq chat model, replaying earlier turns.

    Args:
        message: The new user message.
        history: Iterable of ``(user_msg, assistant_msg)`` pairs from earlier
            turns; ``None`` is treated as no history, and a ``None`` half of a
            pair is skipped.

    Returns:
        The assistant's reply text, or "Error in chat response." on any error
        (logged with its traceback).
    """
    try:
        messages = [{"role": "system", "content": "You are an assistant answering user questions."}]
        for user_msg, assistant_msg in history or []:
            if user_msg is not None:
                messages.append({"role": "user", "content": user_msg})
            if assistant_msg is not None:
                messages.append({"role": "assistant", "content": assistant_msg})
        messages.append({"role": "user", "content": message})

        # Create a client per call. This coroutine is driven through
        # asyncio.run(), which uses a new event loop each time. Pooled
        # connections of a long-lived async HTTP client must not be reused
        # after the loop that opened them has closed.
        async with AsyncGroq(api_key=GROQ_API_KEY) as chat_client:
            response = await chat_client.chat.completions.create(
                messages=messages,
                model="llama3-70b-8192",
                temperature=0,
                max_tokens=1024,
                top_p=1,
                stream=False,  # Non-streaming keeps the synchronous wrapper simple.
            )
        return response.choices[0].message.content
    except Exception:
        logger.exception("Chat error")
        return "Error in chat response."
async def chat_with_replit_pdf(message, history, doi_num):
    """Answer *message* about one arXiv paper using retrieval over its text.

    Steps:

    1. Load and split the paper for *doi_num* with ``ArxivLoader``.
    2. Embed the chunks into a temporary Chroma collection.
    3. Retrieve the three chunks most similar to *message*.
    4. Put those chunks and the paper's metadata in the system prompt.

    Args:
        message: The user's question.
        history: Iterable of earlier ``(user_msg, assistant_msg)`` pairs; they
            are replayed between the system prompt and the new question.
        doi_num: arXiv identifier of the paper (converted with ``str``).

    Returns:
        The assistant's reply, or a "not found" message when the loader returns
        no documents. Returns "Error processing chat with PDF." on any error,
        which is logged with its traceback.
    """
    import uuid

    vector_store = None
    try:
        loader = ArxivLoader(query=str(doi_num), load_max_docs=10)
        documents = loader.load_and_split()
        if not documents:
            return f"No arXiv paper found for '{doi_num}'."
        metadata = documents[0].metadata

        # Use a unique collection name. With the default name, Chroma's
        # in-process client can hand back the same collection on every call.
        # Chunks indexed for earlier calls (even other papers) would then leak
        # into retrieval.
        vector_store = Chroma.from_documents(
            documents,
            embedding_model,
            collection_name=f"arxiv_{uuid.uuid4().hex}",
        )
        results = vector_store.similarity_search(message, k=3)
        relevant_content = "\n\n".join(doc.page_content for doc in results)

        # The system prompt goes first so the model treats it as instructions.
        messages = [
            {"role": "system", "content": f"Answer based on this arXiv paper {doi_num}.\n"
                                          f"Metadata: {metadata}.\n"
                                          f"Relevant Content: {relevant_content}"},
        ]
        for user_msg, assistant_msg in history or []:
            if user_msg is not None:
                messages.append({"role": "user", "content": user_msg})
            if assistant_msg is not None:
                messages.append({"role": "assistant", "content": assistant_msg})
        messages.append({"role": "user", "content": message})

        response = await client.chat.completions.create(
            messages=messages,
            model="llama3-70b-8192",
            temperature=0,
            max_tokens=1024,
            top_p=1,
            stream=False,
        )
        return response.choices[0].message.content
    except Exception:
        logger.exception("Error in chat with PDF")
        return "Error processing chat with PDF."
    finally:
        if vector_store is not None:
            # Drop the per-call collection so repeated calls don't pile up in memory.
            try:
                vector_store.delete_collection()
            except Exception:
                logger.warning("Could not delete temporary Chroma collection", exc_info=True)
# Blocking bridge so ordinary (non-async) callbacks can use the async chat function.
def chat_with_replit_sync(message, history):
    """Run ``chat_with_replit(message, history)`` to completion and return its reply.

    The coroutine is executed with ``asyncio.run``, i.e. on a fresh event loop,
    so this must not be called from a thread whose event loop is already running.
    """
    pending_reply = chat_with_replit(message, history)
    return asyncio.run(pending_reply)
# Gradio UI
def _format_history(history):
    """Render ``[user, assistant]`` pairs as Markdown; a pending reply shows as '…'."""
    return "\n\n".join(
        f"**User:** {u}\n\n**Assistant:** {a if a else '…'}" for u, a in history
    )


with gr.Blocks() as app:
    # Tab for Local PDF Summarization
    with gr.Tab(label="Local PDF Summarization"):
        with gr.Row():
            input_pdf = gr.File(label="Upload PDF file")
            max_length_slider = gr.Slider(512, 4096, value=2048, step=512, label="Max Length")
            summarize_pdf_btn = gr.Button(value="Summarize PDF")
        with gr.Row():
            output_pdf_summary = gr.Markdown(label="Summary", height=1000)
        summarize_pdf_btn.click(summarize_pdf, inputs=[input_pdf, max_length_slider], outputs=output_pdf_summary)

    # Tab for Arxiv Summarization
    with gr.Tab(label="Arxiv Summarization"):
        with gr.Column():
            arxiv_number = gr.Textbox(label="Enter arXiv number, i.e 2502.02523")
            summarize_btn = gr.Button(value="Summarize arXiv Paper")
        with gr.Column():
            output_summary = gr.Markdown(label="Summary", height=1000)
        summarize_btn.click(summarize_arxiv_pdf, inputs=arxiv_number, outputs=output_summary)

    # Tab for Chat functionality
    with gr.Tab(label="Chat with Assistant"):
        gr.Markdown("### Chat with the Assistant")
        with gr.Row():
            chat_input = gr.Textbox(placeholder="Type your message here...", label="Your Message")
            send_button = gr.Button("Send")
        # Markdown view of the whole conversation.
        chat_output = gr.Markdown(label="Chat Output", height=300)
        # Per-session conversation state: list of [user_message, assistant_reply] pairs.
        chat_history = gr.State([])

        def update_chat(user_message, history):
            """Append the user's message with a pending reply and clear the textbox.

            Blank messages are ignored. Returns (new history, rendered Markdown,
            new textbox value).
            """
            history = list(history or [])
            if user_message and user_message.strip():
                history.append([user_message, ""])
            return history, _format_history(history), ""

        def update_assistant_response(history):
            """Fill in the assistant reply for the newest turn if it is still pending."""
            if not history or history[-1][1]:
                # Nothing new to answer (e.g. a blank message was submitted).
                return history, _format_history(history or [])
            user_message = history[-1][0]
            history[-1][1] = chat_with_replit_sync(user_message, history[:-1])
            return history, _format_history(history)

        # Chain the two steps with .then(). Two independent .click() listeners
        # are not ordered, so the second could read the state before the new
        # message was appended. On the very first message that means an empty
        # list and an IndexError.
        send_button.click(
            update_chat,
            inputs=[chat_input, chat_history],
            outputs=[chat_history, chat_output, chat_input],
        ).then(
            update_assistant_response,
            inputs=chat_history,
            outputs=[chat_history, chat_output],
        )

app.launch()