minor
paper_chat_tab.py  CHANGED  (+16, -37)
@@ -73,7 +73,6 @@ def fetch_paper_info_neurips(paper_id):
|
|
| 73 |
else:
|
| 74 |
abstract = 'Abstract not found'
|
| 75 |
|
| 76 |
-
# Construct preamble
|
| 77 |
link = f"https://openreview.net/forum?id={paper_id}"
|
| 78 |
return title, author_list, f"**Abstract:** {abstract}\n\n[View on OpenReview]({link})"
|
| 79 |
|
|
@@ -110,12 +109,9 @@ def fetch_paper_content_arxiv(paper_id):
 
 
 def fetch_paper_info_paperpage(paper_id_value):
-    # Extract paper_id from paper_page link or input
     def extract_paper_id(input_string):
-        # Already in correct form?
        if re.fullmatch(r'\d+\.\d+', input_string.strip()):
            return input_string.strip()
-        # If URL
        match = re.search(r'https://huggingface\.co/papers/(\d+\.\d+)', input_string)
        if match:
            return match.group(1)
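For reference, the id extraction shown in this hunk accepts either a bare arXiv-style id or a full Hugging Face paper-page URL. A minimal standalone sketch of that logic (the `None` fallthrough is an assumption; the rest of the helper sits outside this hunk):

```python
import re

def extract_paper_id(input_string):
    # Bare arXiv-style id, e.g. "2406.12345"
    if re.fullmatch(r'\d+\.\d+', input_string.strip()):
        return input_string.strip()
    # Hugging Face paper-page URL
    match = re.search(r'https://huggingface\.co/papers/(\d+\.\d+)', input_string)
    if match:
        return match.group(1)
    return None  # assumption: unknown inputs yield no id

print(extract_paper_id("2406.12345"))                                # -> "2406.12345"
print(extract_paper_id("https://huggingface.co/papers/2406.12345"))  # -> "2406.12345"
```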
@@ -141,7 +137,6 @@ def fetch_paper_info_paperpage(paper_id_value):
 
 
 def fetch_paper_content_paperpage(paper_id_value):
-    # Extract paper_id
     def extract_paper_id(input_string):
        if re.fullmatch(r'\d+\.\d+', input_string.strip()):
            return input_string.strip()
@@ -155,7 +150,6 @@ def fetch_paper_content_paperpage(paper_id_value):
     return text
 
 
-# Dictionary for paper sources
 PAPER_SOURCES = {
     "neurips": {
         "fetch_info": fetch_paper_info_neurips,
@@ -170,16 +164,13 @@ PAPER_SOURCES = {
 
 def create_chat_interface(provider_dropdown, model_dropdown, paper_content, hf_token_input, default_type,
                           provider_max_total_tokens):
-    # Define the function to handle the chat
     def get_fn(message, history, paper_content_value, hf_token_value, provider_name_value, model_name_value,
                max_total_tokens):
         provider_info = PROVIDERS[provider_name_value]
         endpoint = provider_info['endpoint']
         api_key_env_var = provider_info['api_key_env_var']
-        models = provider_info['models']
         max_total_tokens = int(max_total_tokens)
 
-        # Load tokenizer
         tokenizer_key = f"{provider_name_value}_{model_name_value}"
         if tokenizer_key not in tokenizer_cache:
             tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct",
@@ -188,44 +179,36 @@ def create_chat_interface(provider_dropdown, model_dropdown, paper_content, hf_t
         else:
             tokenizer = tokenizer_cache[tokenizer_key]
 
-        # Include the paper content as context
         if paper_content_value:
             context = f"The discussion is about the following paper:\n{paper_content_value}\n\n"
         else:
             context = ""
 
-        # Tokenize the context
         context_tokens = tokenizer.encode(context)
         context_token_length = len(context_tokens)
 
-        # Prepare the messages without context
         messages = []
         message_tokens_list = []
-        total_tokens = context_token_length
+        total_tokens = context_token_length
 
         for user_msg, assistant_msg in history:
-            # Tokenize user message
             user_tokens = tokenizer.encode(user_msg)
             messages.append({"role": "user", "content": user_msg})
             message_tokens_list.append(len(user_tokens))
             total_tokens += len(user_tokens)
 
-            # Tokenize assistant message
             if assistant_msg:
                 assistant_tokens = tokenizer.encode(assistant_msg)
                 messages.append({"role": "assistant", "content": assistant_msg})
                 message_tokens_list.append(len(assistant_tokens))
                 total_tokens += len(assistant_tokens)
 
-        # Tokenize the new user message
         message_tokens = tokenizer.encode(message)
         messages.append({"role": "user", "content": message})
         message_tokens_list.append(len(message_tokens))
         total_tokens += len(message_tokens)
 
-        # Check if total tokens exceed the maximum allowed tokens
         if total_tokens > max_total_tokens:
-            # Attempt to truncate context
             available_tokens = max_total_tokens - (total_tokens - context_token_length)
             if available_tokens > 0:
                 truncated_context_tokens = context_tokens[:available_tokens]
@@ -237,24 +220,20 @@ def create_chat_interface(provider_dropdown, model_dropdown, paper_content, hf_t
                 total_tokens -= context_token_length
                 context_token_length = 0
 
-        # Truncate message history if needed
         while total_tokens > max_total_tokens and len(messages) > 1:
             removed_message = messages.pop(0)
             removed_tokens = message_tokens_list.pop(0)
             total_tokens -= removed_tokens
 
-        # Rebuild the final messages
         final_messages = []
         if context:
             final_messages.append({"role": "system", "content": f"{context}"})
         final_messages.extend(messages)
 
-        # Use the provider's API key
         api_key = hf_token_value or os.environ.get(api_key_env_var)
         if not api_key:
             raise ValueError("API token is not provided.")
 
-        # Initialize the OpenAI client
         client = OpenAI(
             base_url=endpoint,
             api_key=api_key,
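Taken together, the two hunks above implement a simple token budget: the paper context is counted alongside the chat history, the context is truncated first, and only then are the oldest messages dropped. A self-contained sketch of that strategy, with plain lists standing in for tokenizer output (function name and toy data are illustrative, not from the file):

```python
def fit_to_budget(context_tokens, messages, message_token_counts, max_total_tokens):
    """Trim context first, then drop oldest messages, mirroring get_fn's strategy."""
    total = len(context_tokens) + sum(message_token_counts)
    if total > max_total_tokens:
        # Budget left for context once every message is counted
        available = max_total_tokens - (total - len(context_tokens))
        context_tokens = context_tokens[:max(available, 0)]
        total = len(context_tokens) + sum(message_token_counts)
        # Still over budget: drop oldest messages, keeping at least the newest
        while total > max_total_tokens and len(messages) > 1:
            messages.pop(0)
            total -= message_token_counts.pop(0)
    return context_tokens, messages

# Toy usage with fake "tokens" (one int per token):
ctx, msgs = fit_to_budget(list(range(50)),
                          [{"role": "user", "content": "hi"}] * 3,
                          [20, 20, 20], max_total_tokens=70)
print(len(ctx), len(msgs))  # -> 10 3: context shrunk, all messages kept
```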
@@ -289,6 +268,7 @@ def create_chat_interface(provider_dropdown, model_dropdown, paper_content, hf_t
 
 
 def paper_chat_tab(paper_id, paper_from, paper_central_df):
+    # First row with two columns
     with gr.Row():
         # Left column: Paper selection and display
         with gr.Column(scale=1):
@@ -316,10 +296,10 @@ def paper_chat_tab(paper_id, paper_from, paper_central_df):
             )
             select_paper_button = gr.Button("Load this paper")
 
-            # Paper info display
+            # Paper info display
             content = gr.HTML(value="", elem_id="paper_info_card")
 
-        # Right column: Provider and model selection
+        # Right column: Provider and model selection
         with gr.Column(scale=1, visible=False) as provider_section:
             gr.Markdown("### LLM Provider and Model")
             provider_names = list(PROVIDERS.keys())
@@ -354,7 +334,10 @@ def paper_chat_tab(paper_id, paper_from, paper_central_df):
 
     paper_content = gr.State()
 
-
+    # Now a new row, full width, for the chat
+    with gr.Row(visible=False) as chat_row:
+        with gr.Column():
+            # Create chat interface below the two columns
             chat_interface, chatbot = create_chat_interface(provider_dropdown, model_dropdown, paper_content,
                                                             hf_token_input, default_type, default_max_total_tokens)
 
@@ -385,7 +368,6 @@ def paper_chat_tab(paper_id, paper_from, paper_central_df):
         )
 
     def update_paper_info(paper_id_value, paper_from_value, selected_model, old_content):
-        # Use PAPER_SOURCES to fetch info
         source_info = PAPER_SOURCES.get(paper_from_value, {})
         fetch_info_fn = source_info.get("fetch_info")
         fetch_pdf_fn = source_info.get("fetch_pdf")
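This dispatch leans on the PAPER_SOURCES registry introduced earlier in the file: each source name maps to its fetch functions, so update_paper_info never hard-codes a source. A minimal sketch of the pattern with stub fetchers (the stubs and the `None` return are illustrative):

```python
# Stub standing in for fetch_paper_info_neurips and friends
def fetch_info_stub(paper_id):
    return "Title", "Authors", f"**Abstract:** ... (id {paper_id})"

PAPER_SOURCES = {
    "neurips": {"fetch_info": fetch_info_stub, "fetch_pdf": None},
}

def update_paper_info(paper_id_value, paper_from_value):
    # Look up the source; unknown sources simply yield nothing
    source_info = PAPER_SOURCES.get(paper_from_value, {})
    fetch_info_fn = source_info.get("fetch_info")
    if fetch_info_fn is None:
        return None
    return fetch_info_fn(paper_id_value)

print(update_paper_info("some-id", "neurips"))
```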
@@ -401,7 +383,6 @@ def paper_chat_tab(paper_id, paper_from, paper_central_df):
         if text is None:
             text = "Paper content could not be retrieved."
 
-        # Create a styled card for the paper info
         card_html = f"""
         <div style="border:1px solid #ccc; border-radius:6px; background:#f9f9f9; padding:15px; margin-bottom:10px;">
             <center><h3 style="margin-top:0; text-decoration:underline;">You are talking with:</h3></center>
@@ -414,7 +395,6 @@ def paper_chat_tab(paper_id, paper_from, paper_central_df):
         return gr.update(value=card_html), text, []
 
     def select_paper(paper_title):
-        # Find the corresponding paper_page from the title
         for t, ppage in paper_choices:
             if t == paper_title:
                 return ppage, "paper_page"
@@ -426,32 +406,34 @@ def paper_chat_tab(paper_id, paper_from, paper_central_df):
         outputs=[paper_id, paper_from]
     )
 
-    # After updating paper_id, we update paper info
     paper_id.change(
         fn=update_paper_info,
         inputs=[paper_id, paper_from, model_dropdown, content],
         outputs=[content, paper_content, chatbot]
     )
 
-    # Function to toggle visibility of the right column based on paper_id
     def toggle_provider_visibility(paper_id_value):
         if paper_id_value and paper_id_value.strip():
             return gr.update(visible=True)
         else:
             return gr.update(visible=False)
 
-    #
+    # Toggle provider section visibility
     paper_id.change(
         fn=toggle_provider_visibility,
         inputs=[paper_id],
         outputs=[provider_section]
     )
 
+    # Toggle chat row visibility
+    paper_id.change(
+        fn=toggle_provider_visibility,
+        inputs=[paper_id],
+        outputs=[chat_row]
+    )
+
 
 def main():
-    """
-    Launches the Gradio app.
-    """
     with gr.Blocks(css_paths="style.css") as demo:
         paper_id = gr.Textbox(label="Paper ID", value="")
         paper_from = gr.Radio(
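Both wirings follow the usual Gradio show/hide pattern: a component built with visible=False is revealed by returning gr.update(visible=True) from a .change handler. A minimal runnable sketch of the pattern, assuming only the component names shown in the diff:

```python
import gradio as gr

with gr.Blocks() as demo:
    paper_id = gr.Textbox(label="Paper ID")
    with gr.Row(visible=False) as chat_row:  # hidden until a paper id is entered
        gr.Markdown("Chat goes here")

    def toggle(paper_id_value):
        # Show the row only when the textbox holds a non-blank id
        return gr.update(visible=bool(paper_id_value and paper_id_value.strip()))

    paper_id.change(fn=toggle, inputs=[paper_id], outputs=[chat_row])

if __name__ == "__main__":
    demo.launch()
```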
@@ -460,9 +442,6 @@ def main():
             value="neurips"
         )
 
-        # Build the paper chat tab
-        dummy_calendar = gr.State(datetime.now().strftime("%Y-%m-%d"))
-
         class MockPaperCentral:
             def __init__(self):
                 import pandas as pd