Kartheik Iyer committed
Commit 18e51e3 · 1 Parent(s): 036767e

update dataset and security fixes

Files changed: app_gradio.py (+75 -45)
app_gradio.py
CHANGED
@@ -34,7 +34,7 @@ from typing import List, Literal
 
 from nltk.corpus import stopwords
 import nltk
-from openai import OpenAI
+from openai import OpenAI, moderations
 # import anthropic
 import cohere
 import faiss
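Note: the bare `moderations` import leans on the module-level convenience client that recent versions of the openai package expose alongside the `OpenAI` class; the client-scoped equivalent is sketched under the `check_mod` hunk below.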
@@ -64,6 +64,12 @@ embed_model = "text-embedding-3-small"
 embeddings = OpenAIEmbeddings(model = embed_model, api_key = openai_key)
 nlp = load_nlp()
 
+def check_mod(query):
+    mod_report = moderations.create(input=query)
+    for i in mod_report.results[0].categories:
+        if i[1] == True:
+            return True
+    return False
 
 def get_keywords(text, nlp=nlp):
     result = []
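Note: the loop above works because the moderation result's `categories` is a pydantic model, so iterating it yields `(category_name, bool)` pairs. The response also carries an aggregate `flagged` field that the API sets when any category fires. A minimal equivalent sketch, assuming an openai>=1.x client with `OPENAI_API_KEY` in the environment:

from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

def check_mod(query: str) -> bool:
    # results[0].flagged is True when any moderation category fires,
    # matching the category-by-category loop in the committed version.
    mod_report = client.moderations.create(input=query)
    return mod_report.results[0].flagged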
@@ -77,8 +83,12 @@ def get_keywords(text, nlp=nlp):
     return result
 
 def load_arxiv_corpus():
-    arxiv_corpus = load_from_disk('data/')
-    arxiv_corpus.load_faiss_index('embed', 'data/astrophindex.faiss')
+    # arxiv_corpus = load_from_disk('data/')
+    # arxiv_corpus.load_faiss_index('embed', 'data/astrophindex.faiss')
+
+    # keeping it up to date with the dataset
+    arxiv_corpus = load_dataset('kiyer/pathfinder_arxiv_data', split='train')
+    arxiv_corpus.add_faiss_index(column='embed')
     print('loading arxiv corpus from disk')
     return arxiv_corpus
 
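Note: switching from `load_from_disk` to `load_dataset` pulls the latest `kiyer/pathfinder_arxiv_data` from the Hub, at the cost of rebuilding the FAISS index in memory on every startup instead of loading the prebuilt `astrophindex.faiss`. A usage sketch for querying the index via the `datasets` FAISS integration — `query_vector` is assumed to be a float32 vector from the same embedding model that produced the 'embed' column:

import numpy as np
from datasets import load_dataset

corpus = load_dataset('kiyer/pathfinder_arxiv_data', split='train')
corpus.add_faiss_index(column='embed')  # built in memory at startup

# `embeddings` is the OpenAIEmbeddings instance defined earlier in app_gradio.py;
# the query string here is purely illustrative.
query_vector = np.asarray(embeddings.embed_query('quenching in massive galaxies'), dtype=np.float32)
scores, examples = corpus.get_nearest_examples('embed', query_vector, k=10)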
@@ -344,6 +354,23 @@ def guess_question_type(query: str):
     messages = [("system",question_categorization_prompt,),("human", query),]
     return gen_client.invoke(messages).content
 
+def log_to_gist(strings):
+    # Adding query logs to prevent and account for possible malicious use.
+    # Logs will be deleted periodically if not needed.
+    github_token = os.environ['github_token']
+    gist_id = os.environ['gist_id']
+    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+    content = f"\n{timestamp}: {' '.join(strings)}\n"
+    headers = {'Authorization': f'token {github_token}','Accept': 'application/vnd.github.v3+json'}
+    response = requests.get(f'https://api.github.com/gists/{gist_id}', headers=headers)
+    if response.status_code == 200:
+        existing_content = response.json()['files']['log.txt']['content']
+        content = existing_content + content
+    data = {"description": "Logged Strings","public": False,"files": {"log.txt": {"content": content}}}
+    headers = {'Authorization': f'token {github_token}','Accept': 'application/vnd.github.v3+json'}
+    response = requests.patch(f'https://api.github.com/gists/{gist_id}', headers=headers, data=json.dumps(data)) # Update existing gist
+    return
+
 class OverallConsensusEvaluation(BaseModel):
     rewritten_statement: str = Field(
         ...,
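Note: `log_to_gist` appends by reading the current `log.txt` with a GET and writing the whole file back with a PATCH, since a gist PATCH replaces file contents outright; the read-modify-write is not atomic under concurrent queries. One small simplification: `requests` can serialize the payload itself, and `json=` also sets the Content-Type header.

response = requests.patch(f'https://api.github.com/gists/{gist_id}', headers=headers, json=data)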
@@ -459,48 +486,51 @@ def run_pathfinder(query, top_k, extra_keywords, toggles, prompt_type, rag_type,
     search_text_list = ['rooting around in the paper pile...','looking for clarity...','scanning the event horizon...','peering into the abyss...','potatoes power this ongoing search...']
     gen_text_list = ['making the LLM talk to the papers...','invoking arcane rituals...','gone to library, please wait...','is there really an answer to this...']
 
-    input_keywords = [kw.strip() for kw in extra_keywords.split(',')] if extra_keywords else []
-    query_keywords = get_keywords(query)
-    ec.query_input_keywords = input_keywords+query_keywords
-    ec.toggles = toggles
-    if rag_type == "Semantic Search":
-        ec.hyde = False
-        ec.rerank = False
-    elif rag_type == "Semantic + HyDE":
-        ec.hyde = True
-        ec.rerank = False
-    elif rag_type == "Semantic + HyDE + CoHERE":
-        ec.hyde = True
-        ec.rerank = True
-
-    progress(0.2, desc=search_text_list[np.random.choice(len(search_text_list))])
-    rs, small_df = ec.retrieve(query, top_k = top_k, return_scores=True)
-    formatted_df = ec.return_formatted_df(rs, small_df)
-    yield formatted_df, None, None, None, None
-
-    progress(0.4, desc=gen_text_list[np.random.choice(len(gen_text_list))])
-    rag_answer = run_rag_qa(query, formatted_df, prompt_type)
-    yield formatted_df, rag_answer['answer'], None, None, None
-
-    progress(0.6, desc="Generating consensus")
-    consensus_answer = evaluate_overall_consensus(query, [formatted_df['abstract'][i+1] for i in range(len(formatted_df))])
-    consensus = '## Consensus \n'+consensus_answer.consensus + '\n\n'+consensus_answer.explanation + '\n\n > Relevance of retrieved papers to answer: %.1f' %consensus_answer.relevance_score
-    yield formatted_df, rag_answer['answer'], consensus, None, None
-
-    progress(0.8, desc="Analyzing question type")
-    question_type_gen = guess_question_type(query)
-    if '<categorization>' in question_type_gen:
-        question_type_gen = question_type_gen.split('<categorization>')[1]
-    if '</categorization>' in question_type_gen:
-        question_type_gen = question_type_gen.split('</categorization>')[0]
-    question_type_gen = question_type_gen.replace('\n',' \n')
-    qn_type = question_type_gen
-    yield formatted_df, rag_answer['answer'], consensus, qn_type, None
-
-    progress(1.0, desc="Visualizing embeddings")
-    fig = make_embedding_plot(formatted_df, top_k, consensus_answer)
-
-    yield formatted_df, rag_answer['answer'], consensus, qn_type, fig
+    log_to_gist(['[mod flag: '+str(check_mod(query))+']', query])
+    if check_mod(query) == False:
+
+        input_keywords = [kw.strip() for kw in extra_keywords.split(',')] if extra_keywords else []
+        query_keywords = get_keywords(query)
+        ec.query_input_keywords = input_keywords+query_keywords
+        ec.toggles = toggles
+        if rag_type == "Semantic Search":
+            ec.hyde = False
+            ec.rerank = False
+        elif rag_type == "Semantic + HyDE":
+            ec.hyde = True
+            ec.rerank = False
+        elif rag_type == "Semantic + HyDE + CoHERE":
+            ec.hyde = True
+            ec.rerank = True
+
+        progress(0.2, desc=search_text_list[np.random.choice(len(search_text_list))])
+        rs, small_df = ec.retrieve(query, top_k = top_k, return_scores=True)
+        formatted_df = ec.return_formatted_df(rs, small_df)
+        yield formatted_df, None, None, None, None
+
+        progress(0.4, desc=gen_text_list[np.random.choice(len(gen_text_list))])
+        rag_answer = run_rag_qa(query, formatted_df, prompt_type)
+        yield formatted_df, rag_answer['answer'], None, None, None
+
+        progress(0.6, desc="Generating consensus")
+        consensus_answer = evaluate_overall_consensus(query, [formatted_df['abstract'][i+1] for i in range(len(formatted_df))])
+        consensus = '## Consensus \n'+consensus_answer.consensus + '\n\n'+consensus_answer.explanation + '\n\n > Relevance of retrieved papers to answer: %.1f' %consensus_answer.relevance_score
+        yield formatted_df, rag_answer['answer'], consensus, None, None
+
+        progress(0.8, desc="Analyzing question type")
+        question_type_gen = guess_question_type(query)
+        if '<categorization>' in question_type_gen:
+            question_type_gen = question_type_gen.split('<categorization>')[1]
+        if '</categorization>' in question_type_gen:
+            question_type_gen = question_type_gen.split('</categorization>')[0]
+        question_type_gen = question_type_gen.replace('\n',' \n')
+        qn_type = question_type_gen
+        yield formatted_df, rag_answer['answer'], consensus, qn_type, None
+
+        progress(1.0, desc="Visualizing embeddings")
+        fig = make_embedding_plot(formatted_df, top_k, consensus_answer)
+
+        yield formatted_df, rag_answer['answer'], consensus, qn_type, fig
 
 def create_interface():
     custom_css = """
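Note: `check_mod(query)` is evaluated twice per request (once for the log entry, once for the gate), doubling the moderation round-trips, and a flagged query falls through the `if` without yielding anything, so the UI gets no feedback. A hedged single-call sketch using the same names as the diff (the flagged-branch message is illustrative, not from the commit):

flagged = check_mod(query)
log_to_gist([f'[mod flag: {flagged}]', query])
if flagged:
    # Give the interface something to render instead of returning silently.
    yield None, 'Query flagged by the moderation endpoint; please rephrase.', None, None, None
    return
# ...otherwise proceed with retrieval, RAG answer, consensus, question type, and the plot.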