import google.generativeai as genai
import gradio as gr
from PyPDF2 import PdfReader
from bs4 import BeautifulSoup
import openai
import requests
from io import BytesIO
from transformers import AutoTokenizer
import json
from datetime import datetime
import os
from openai import OpenAI
import re

# Cache for tokenizers to avoid reloading
tokenizer_cache = {}
# Global registry of providers
PROVIDERS = {
    "Gemini": {
        "name": "Gemini",
        "logo": "https://www.gstatic.com/lamda/images/gemini_thumbnail_c362e5eadc46ca9f617e2.png",
        # No endpoint is needed for Gemini, since google.generativeai is called directly
        "endpoint": "https://example-gemini-endpoint",
        "api_key_env_var": "GEMINI_API_KEY",  # env var used when no token is entered in the UI
        "models": [
            "gemini-2.0-flash-exp",
            "gemini-1.5-flash",
        ],
        "type": "tuples",
        "max_total_tokens": "50000",
    },
    "SambaNova": {
        "name": "SambaNova",
        "logo": "https://venturebeat.com/wp-content/uploads/2020/02/SambaNovaLogo_H_F.jpg",
        "endpoint": "https://api.sambanova.ai/v1/",
        "api_key_env_var": "SAMBANOVA_API_KEY",
        "models": [
            "Meta-Llama-3.1-70B-Instruct",
            "Meta-Llama-3.3-70B-Instruct",
        ],
        "type": "tuples",
        "max_total_tokens": "50000",
    },
    "Hyperbolic": {
        "name": "Hyperbolic",
        "logo": "https://www.nftgators.com/wp-content/uploads/2024/07/Hyperbolic.jpg",
        "endpoint": "https://api.hyperbolic.xyz/v1",
        "api_key_env_var": "HYPERBOLIC_API_KEY",
        "models": [
            "meta-llama/Llama-3.3-70B-Instruct",
            "meta-llama/Meta-Llama-3.1-405B-Instruct",
        ],
        "type": "tuples",
        "max_total_tokens": "50000",
    },
}
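
# Illustrative only: any OpenAI-compatible provider can be registered with the
# same shape. The "Together" entry below is a hypothetical example, not part of
# the app; verify the endpoint, env var, and model names before enabling it.
# PROVIDERS["Together"] = {
#     "name": "Together",
#     "logo": "https://example.com/together-logo.png",
#     "endpoint": "https://api.together.xyz/v1",
#     "api_key_env_var": "TOGETHER_API_KEY",
#     "models": ["meta-llama/Llama-3.3-70B-Instruct-Turbo"],
#     "type": "tuples",
#     "max_total_tokens": "50000",
# }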
# Functions for paper fetching
def fetch_paper_info_neurips(paper_id):
    url = f"https://openreview.net/forum?id={paper_id}"
    response = requests.get(url)
    if response.status_code != 200:
        return None, None, None
    soup = BeautifulSoup(response.content, 'html.parser')
    # Extract title
    title_tag = soup.find('h2', class_='citation_title')
    title = title_tag.get_text(strip=True) if title_tag else 'Title not found'
    # Extract authors
    authors = []
    author_div = soup.find('div', class_='forum-authors')
    if author_div:
        author_tags = author_div.find_all('a')
        authors = [tag.get_text(strip=True) for tag in author_tags]
    author_list = ', '.join(authors) if authors else 'Authors not found'
    # Extract abstract ('string=' replaces the deprecated 'text=' keyword in bs4)
    abstract_div = soup.find('strong', string='Abstract:')
    if abstract_div:
        abstract_paragraph = abstract_div.find_next_sibling('div')
        abstract = abstract_paragraph.get_text(strip=True) if abstract_paragraph else 'Abstract not found'
    else:
        abstract = 'Abstract not found'
    link = f"https://openreview.net/forum?id={paper_id}"
    return title, author_list, f"**Abstract:** {abstract}\n\n[View on OpenReview]({link})"
def fetch_paper_content_neurips(paper_id):
    try:
        url = f"https://openreview.net/pdf?id={paper_id}"
        response = requests.get(url)
        response.raise_for_status()
        pdf_content = BytesIO(response.content)
        reader = PdfReader(pdf_content)
        text = ""
        for page in reader.pages:
            text += page.extract_text()
        return text
    except Exception as e:  # avoid a bare except so KeyboardInterrupt still propagates
        print(f"Error fetching NeurIPS paper content: {e}")
        return None
def fetch_paper_content_arxiv(paper_id):
    try:
        url = f"https://arxiv.org/pdf/{paper_id}.pdf"
        response = requests.get(url)
        response.raise_for_status()
        pdf_content = BytesIO(response.content)
        reader = PdfReader(pdf_content)
        text = ""
        for page in reader.pages:
            text += page.extract_text()
        return text
    except Exception as e:
        print(f"Error fetching paper content: {e}")
        return None
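
# Quick manual check for the fetcher above (illustrative; the ID is a placeholder):
# text = fetch_paper_content_arxiv("1234.56789")
# print(f"{len(text)} characters extracted" if text else "fetch failed")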
def _extract_hf_paper_id(input_string):
    """Return the bare arXiv-style ID from either an ID or a Hugging Face papers URL."""
    if re.fullmatch(r'\d+\.\d+', input_string.strip()):
        return input_string.strip()
    match = re.search(r'https://huggingface\.co/papers/(\d+\.\d+)', input_string)
    if match:
        return match.group(1)
    return input_string.strip()

def fetch_paper_info_paperpage(paper_id_value):
    paper_id_value = _extract_hf_paper_id(paper_id_value)
    url = f"https://huggingface.co/api/papers/{paper_id_value}?field=comments"
    response = requests.get(url)
    if response.status_code != 200:
        return None, None, None
    paper_info = response.json()
    title = paper_info.get('title', 'No Title')
    authors_list = [author.get('name', 'Unknown') for author in paper_info.get('authors', [])]
    authors = ', '.join(authors_list)
    summary = paper_info.get('summary', 'No Summary')
    num_comments = len(paper_info.get('comments', []))
    num_upvotes = paper_info.get('upvotes', 0)
    link = f"https://huggingface.co/papers/{paper_id_value}"
    details = f"{summary}<br/>👍{num_upvotes} 💬{num_comments}<br/> <a href='{link}' " \
              f"target='_blank'>View on 🤗 Hugging Face</a>"
    return title, authors, details

def fetch_paper_content_paperpage(paper_id_value):
    # Hugging Face paper IDs are arXiv IDs, so the PDF is fetched from arXiv
    paper_id_value = _extract_hf_paper_id(paper_id_value)
    return fetch_paper_content_arxiv(paper_id_value)
PAPER_SOURCES = {
    "neurips": {
        "fetch_info": fetch_paper_info_neurips,
        "fetch_pdf": fetch_paper_content_neurips
    },
    "paper_page": {
        "fetch_info": fetch_paper_info_paperpage,
        "fetch_pdf": fetch_paper_content_paperpage
    }
}
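
# Minimal sketch of how PAPER_SOURCES is consumed (mirrors update_paper_info
# further down; the ID is the same placeholder used by the mock in main()):
# source = PAPER_SOURCES["paper_page"]
# title, authors, details = source["fetch_info"]("1234.56789")
# full_text = source["fetch_pdf"]("1234.56789")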
def create_chat_interface(provider_dropdown, model_dropdown, paper_content, hf_token_input, default_type,
                          provider_max_total_tokens):
    def get_fn(message, history, paper_content_value, hf_token_value, provider_name_value, model_name_value,
               max_total_tokens):
        provider_info = PROVIDERS[provider_name_value]
        endpoint = provider_info['endpoint']
        api_key_env_var = provider_info['api_key_env_var']
        max_total_tokens = int(max_total_tokens)
        # A single Llama tokenizer serves as an approximate token counter for every
        # provider; counts will not exactly match each provider's own tokenizer.
        tokenizer_key = f"{provider_name_value}_{model_name_value}"
        if tokenizer_key not in tokenizer_cache:
            tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct",
                                                      token=os.environ.get("HF_TOKEN"))
            tokenizer_cache[tokenizer_key] = tokenizer
        else:
            tokenizer = tokenizer_cache[tokenizer_key]
        if paper_content_value:
            context = f"The discussion is about the following paper:\n{paper_content_value}\n\n"
        else:
            context = ""
        context_tokens = tokenizer.encode(context)
        context_token_length = len(context_tokens)
        messages = []
        message_tokens_list = []
        total_tokens = context_token_length
        # Reconstruct the conversation from history and the current user message
        for user_msg, assistant_msg in history:
            user_tokens = tokenizer.encode(user_msg)
            messages.append({"role": "user", "content": user_msg})
            message_tokens_list.append(len(user_tokens))
            total_tokens += len(user_tokens)
            if assistant_msg:
                assistant_tokens = tokenizer.encode(assistant_msg)
                messages.append({"role": "assistant", "content": assistant_msg})
                message_tokens_list.append(len(assistant_tokens))
                total_tokens += len(assistant_tokens)
        message_tokens = tokenizer.encode(message)
        messages.append({"role": "user", "content": message})
        message_tokens_list.append(len(message_tokens))
        total_tokens += len(message_tokens)
        # Token truncation: shorten the paper context first, then drop the oldest turns
        if total_tokens > max_total_tokens:
            available_tokens = max_total_tokens - (total_tokens - context_token_length)
            if available_tokens > 0:
                truncated_context_tokens = context_tokens[:available_tokens]
                context = tokenizer.decode(truncated_context_tokens)
                context_token_length = available_tokens
                total_tokens = total_tokens - len(context_tokens) + context_token_length
            else:
                context = ""
                total_tokens -= context_token_length
                context_token_length = 0
            while total_tokens > max_total_tokens and len(messages) > 1:
                messages.pop(0)
                total_tokens -= message_tokens_list.pop(0)
        final_messages = []
        if context:
            # Gemini has no system role, so the paper context is sent as a user turn there
            final_messages.append(
                {"role": "user" if provider_name_value == "Gemini" else "system", "content": context})
        final_messages.extend(messages)
        api_key = hf_token_value or os.environ.get(api_key_env_var)
        if not api_key:
            raise ValueError("API token is not provided.")
        # Gemini logic
        if provider_name_value == "Gemini":
            genai.configure(api_key=api_key)
            # Models can also be addressed by their full name, e.g. "models/gemini-1.5-flash";
            # ensure PROVIDERS lists names that genai.GenerativeModel accepts.
            model = genai.GenerativeModel(model_name=model_name_value)
            # Convert final_messages into Gemini's format: a list of
            # {"role": "user"/"model", "parts": ["..."]} dicts, where "model"
            # is Gemini's name for the assistant role.
            gemini_messages = []
            for m in final_messages:
                role = "model" if m["role"] == "assistant" else "user"
                gemini_messages.append({"role": role, "parts": [m["content"]]})
            # Call generate_content with stream=True
            try:
                response = model.generate_content(gemini_messages, stream=True)
                response_text = ""
                for chunk in response:
                    if chunk.text:
                        response_text += chunk.text
                        yield response_text
            except Exception as ex:
                yield f"Error calling Gemini: {ex}"
        else:
            # Default OpenAI-compatible logic
            client = OpenAI(
                base_url=endpoint,
                api_key=api_key,
            )
            try:
                completion = client.chat.completions.create(
                    model=model_name_value,
                    messages=final_messages,
                    stream=True,
                )
                response_text = ""
                for chunk in completion:
                    delta = chunk.choices[0].delta.content or ""
                    response_text += delta
                    yield response_text
            except json.JSONDecodeError as e:
                yield f"JSON decoding error: {e.msg}"
            except openai.OpenAIError as openai_err:
                yield f"OpenAI error: {openai_err}"
            except Exception as ex:
                yield f"Unexpected error: {ex}"

    chatbot = gr.Chatbot(label="Chatbot", scale=1, height=800, autoscroll=True)
    chat_interface = gr.ChatInterface(
        fn=get_fn,
        chatbot=chatbot,
        additional_inputs=[paper_content, hf_token_input, provider_dropdown, model_dropdown,
                           provider_max_total_tokens],
        type="tuples",
    )
    return chat_interface, chatbot
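
# Illustrative restatement of the truncation rule above as a standalone helper
# (not called anywhere in the app): the paper context is sacrificed first, then
# the oldest conversation turns, until the total fits the provider's budget.
def fit_budget(context_len, msg_lens, budget):
    """Return (tokens of context to keep, surviving per-message token counts)."""
    msg_lens = list(msg_lens)
    msgs_total = sum(msg_lens)
    keep_context = max(0, min(context_len, budget - msgs_total))
    while msgs_total + keep_context > budget and len(msg_lens) > 1:
        msgs_total -= msg_lens.pop(0)  # drop the oldest turn first
    return keep_context, msg_lens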
def paper_chat_tab(paper_id, paper_from, paper_central_df):
    # A top-level "Chat with another paper" button, visible only once a paper is loaded
    chat_another_button = gr.Button("Chat with another paper", variant="primary", visible=False)
    # First row with two columns
    with gr.Row():
        # Left column: paper selection and display
        with gr.Column(scale=1):
            todays_date = datetime.today().strftime('%Y-%m-%d')
            # Keep only today's papers that have a paper_page
            selectable_papers = paper_central_df.df_prettified
            selectable_papers = selectable_papers[
                selectable_papers['paper_page'].notna() &
                (selectable_papers['paper_page'] != "") &
                (selectable_papers['date'] == todays_date)
            ]
            paper_choices = [(row['title'], row['paper_page']) for _, row in selectable_papers.iterrows()]
            paper_choices = sorted(paper_choices, key=lambda x: x[0])
            if not paper_choices:
                paper_choices = [("No available papers for today", "")]
            paper_select = gr.Dropdown(
                label="Select a paper to chat with (from today's 🤗 Hugging Face paper page):",
                choices=[p[0] for p in paper_choices],
                value=paper_choices[0][0] if paper_choices else None
            )
            # Textbox for entering a paper_id (arXiv ID) directly
            paper_id_input = gr.Textbox(
                label="Or enter a 🤗 paper_id directly",
                placeholder="e.g. 1234.56789"
            )
            select_paper_button = gr.Button("Load this paper")
            # Paper info display
            content = gr.HTML(value="", elem_id="paper_info_card")
        # Right column: provider and model selection
        with gr.Column(scale=1, visible=False) as provider_section:
            gr.Markdown("### LLM Provider and Model")
            provider_names = list(PROVIDERS.keys())
            default_provider = provider_names[0]
            default_type = gr.State(value=PROVIDERS[default_provider]["type"])
            default_max_total_tokens = gr.State(value=PROVIDERS[default_provider]["max_total_tokens"])
            provider_dropdown = gr.Dropdown(
                label="Select Provider",
                choices=provider_names,
                value=default_provider
            )
            hf_token_input = gr.Textbox(
                label=f"Enter your {default_provider} API token (optional)",
                type="password",
                placeholder=f"Enter your {default_provider} API token to avoid rate limits"
            )
            model_dropdown = gr.Dropdown(
                label="Select Model",
                choices=PROVIDERS[default_provider]['models'],
                value=PROVIDERS[default_provider]['models'][0]
            )
            logo_html = gr.HTML(
                value=f'<img src="{PROVIDERS[default_provider]["logo"]}" width="100px" />'
            )
            note_markdown = gr.Markdown(f"**Note:** This model is supported by {default_provider}.")
    paper_content = gr.State()
    # A new full-width row for the chat, below the two columns
    with gr.Row(visible=False) as chat_row:
        with gr.Column():
            chat_interface, chatbot = create_chat_interface(provider_dropdown, model_dropdown, paper_content,
                                                            hf_token_input, default_type, default_max_total_tokens)
    def update_provider(selected_provider):
        provider_info = PROVIDERS[selected_provider]
        models = provider_info['models']
        logo_url = provider_info['logo']
        max_total_tokens = provider_info['max_total_tokens']
        model_dropdown_choices = gr.update(choices=models, value=models[0])
        logo_html_content = f'<img src="{logo_url}" width="100px" />'
        logo_html_update = gr.update(value=logo_html_content)
        note_markdown_update = gr.update(value=f"**Note:** This model is supported by {selected_provider}.")
        hf_token_input_update = gr.update(
            label=f"Enter your {selected_provider} API token (optional)",
            placeholder=f"Enter your {selected_provider} API token to avoid rate limits"
        )
        chatbot_reset = []
        return (model_dropdown_choices, logo_html_update, note_markdown_update, hf_token_input_update,
                provider_info['type'], max_total_tokens, chatbot_reset)

    provider_dropdown.change(
        fn=update_provider,
        inputs=provider_dropdown,
        outputs=[model_dropdown, logo_html, note_markdown, hf_token_input, default_type, default_max_total_tokens,
                 chatbot],
        queue=False
    )
    def update_paper_info(paper_id_value, paper_from_value, selected_model, old_content):
        source_info = PAPER_SOURCES.get(paper_from_value, {})
        fetch_info_fn = source_info.get("fetch_info")
        fetch_pdf_fn = source_info.get("fetch_pdf")
        if not fetch_info_fn or not fetch_pdf_fn:
            return gr.update(value="<div>No information available.</div>"), None, []
        title, authors, details = fetch_info_fn(paper_id_value)
        if title is None and authors is None and details is None:
            return gr.update(value="<div>No information could be retrieved.</div>"), None, []
        text = fetch_pdf_fn(paper_id_value)
        if text is None:
            text = "Paper content could not be retrieved."
        card_html = f"""
        <div style="border:1px solid #ccc; border-radius:6px; background:#f9f9f9; padding:15px; margin-bottom:10px;">
            <center><h3 style="margin-top:0; text-decoration:underline;">You are talking with:</h3></center>
            <h3>{title}</h3>
            <p><strong>Authors:</strong> {authors}</p>
            <p>{details}</p>
        </div>
        """
        return gr.update(value=card_html), text, []
    def select_paper(paper_title, paper_id_val):
        # If the user provided a paper_id_val (arXiv ID) directly, prefer it
        if paper_id_val and paper_id_val.strip():
            # Check that it exists in the dataset as a paper with a non-empty paper_page.
            # This requires an `arxiv_id` column in paper_central_df.df_raw.
            df = paper_central_df.df_raw
            if 'arxiv_id' not in df.columns:
                return gr.update(value="<div>arxiv_id column not found in dataset</div>"), None
            found = df[
                (df['arxiv_id'] == paper_id_val.strip()) &
                df['paper_page'].notna() & (df['paper_page'] != "")
            ]
            if len(found) > 0:
                return paper_id_val.strip(), "paper_page"
            # Not found: return empty values and let check_paper_id_error (below)
            # surface the error, since the state outputs and the info card are separate.
            return "", ""
        # Otherwise fall back to the dropdown selection
        for t, ppage in paper_choices:
            if t == paper_title:
                return ppage, "paper_page"
        return "", ""

    select_paper_button.click(
        fn=select_paper,
        inputs=[paper_select, paper_id_input],
        outputs=[paper_id, paper_from]
    )
    # Once paper_id/paper_from are set, refresh the paper info card and content
    paper_id_update = paper_id.change(
        fn=update_paper_info,
        inputs=[paper_id, paper_from, model_dropdown, content],
        outputs=[content, paper_content, chatbot]
    )
    def toggle_provider_visibility(paper_id_value):
        if paper_id_value and paper_id_value.strip():
            return gr.update(visible=True)
        return gr.update(visible=False)

    paper_id.change(
        fn=toggle_provider_visibility,
        inputs=[paper_id],
        outputs=[provider_section]
    )
    paper_id.change(
        fn=toggle_provider_visibility,
        inputs=[paper_id],
        outputs=[chat_row]
    )

    # Show the "Chat with another paper" button only when a paper is loaded
    def toggle_chat_another_button(paper_id_value):
        if paper_id_value and paper_id_value.strip():
            return gr.update(visible=True)
        return gr.update(visible=False)

    paper_id.change(
        fn=toggle_chat_another_button,
        inputs=[paper_id],
        outputs=[chat_another_button]
    )

    # Reset paper_id, paper_from, and the info card so another paper can be chosen
    def reset_paper_id():
        return "", "neurips", gr.update(value="<div></div>")

    chat_another_button.click(
        fn=reset_paper_id,
        outputs=[paper_id, paper_from, content]
    )

    # select_paper returns an empty paper_id/paper_from on failure; surface that
    # as an error message in the info card after the button click resolves.
    def check_paper_id_error(p_id, p_from):
        if not p_id:
            return gr.update(value="<div style='color:red;'>No valid paper found for the given input.</div>")
        return gr.update()

    select_paper_button.click(
        fn=check_paper_id_error,
        inputs=[paper_id, paper_from],
        outputs=[content],
        queue=False
    )
def main():
    with gr.Blocks(css_paths="style.css") as demo:
        paper_id = gr.Textbox(label="Paper ID", value="")
        paper_from = gr.Radio(
            label="Paper Source",
            choices=["neurips", "paper_page"],
            value="neurips"
        )

        class MockPaperCentral:
            def __init__(self):
                import pandas as pd
                data = {
                    'date': [datetime.today().strftime('%Y-%m-%d')],
                    'paper_page': ['1234.56789'],
                    'arxiv_id': ['1234.56789'],  # select_paper checks this column
                    'title': ['An Example Paper']
                }
                self.df_prettified = pd.DataFrame(data)
                # select_paper reads df_raw, so expose the same frame under both names
                self.df_raw = self.df_prettified

        paper_central_df = MockPaperCentral()
        paper_chat_tab(paper_id, paper_from, paper_central_df)
    demo.launch(ssr_mode=False)

if __name__ == "__main__":
    main()
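
# To try this locally (assuming the file is saved as app.py with a style.css
# next to it, since gr.Blocks(css_paths="style.css") expects one), export the
# key for whichever provider you plan to use, e.g.:
#   GEMINI_API_KEY=... python app.py
# or paste a token into the UI field instead.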