Update app.py

app.py CHANGED
@@ -12,7 +12,7 @@ from pymongo import MongoClient
 from PyPDF2 import PdfReader
 st.set_page_config(page_title="Grant Buddy RAG", page_icon="🤖")
 
-
+from typing import List
 
 from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
 from langchain.embeddings import HuggingFaceEmbeddings
@@ -229,9 +229,7 @@ def init_vector_search() -> MongoDBAtlasVectorSearch:
 # if len(clean) > 10 and not any(bad in clean.lower() for bad in ["phone", "email", "address", "website"]):
 # prompts.append(clean)
 # return prompts
-
-import os
-import openai
+
 
 def extract_with_llm_local(text: str, use_openai: bool = False) -> List[str]:
     # Example context to prime the model
@@ -266,7 +264,7 @@ PROMPTS:
         return "⚠️ OpenAI key missing."
     try:
         response = client.chat.completions.create(
-            model="gpt-
+            model="gpt-4o-mini",
             messages=[
                 {"role": "system", "content": "You extract prompts and headers from grant text."},
                 {"role": "user", "content": prompt},
@@ -276,6 +274,8 @@ PROMPTS:
         )
         # raw_output = response["choices"][0]["message"]["content"]
         raw_output = response.choices[0].message.content
+        st.markdown(f"🧮 Extract Tokens: Prompt = {response.usage.prompt_tokens}, "
+                    f"Completion = {response.usage.completion_tokens}, Total = {response.usage.total_tokens}")
     except Exception as e:
         st.error(f"❌ OpenAI extraction failed: {e}")
         return []
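The two added lines above report token counts from `response.usage`. As a standalone sketch of the same pattern (assuming the v1 `openai` Python client; the model name and message here are placeholders):

from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

response = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Say hello"}],
)
usage = response.usage  # carries prompt_tokens, completion_tokens, total_tokens
print(usage.prompt_tokens, usage.completion_tokens, usage.total_tokens)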
@@ -351,16 +351,12 @@ def load_local_model():
 tokenizer, model = load_local_model()
 
 def generate_response(input_dict, use_openai=False):
-
-    if not openai.api_key:
-        st.error("❌ OPENAI_API_KEY is not set.")
-        return "⚠️ OpenAI key missing."
-
-    prompt = grantbuddy_prompt.format(**input_dict)
+    prompt = grantbuddy_prompt.format(**input_dict)
 
+    if use_openai:
         try:
             response = client.chat.completions.create(
-            model="gpt-
+                model="gpt-4o-mini",
                 messages=[
                     {"role": "system", "content": prompt},
                     {"role": "user", "content": input_dict["question"]},
@@ -368,14 +364,30 @@ def generate_response(input_dict, use_openai=False):
                 temperature=0.2,
                 max_tokens=700,
             )
-
+            answer = response.choices[0].message.content.strip()
+
+            # ✅ Token logging
+            prompt_tokens = response.usage.prompt_tokens
+            completion_tokens = response.usage.completion_tokens
+            total_tokens = response.usage.total_tokens
+
+            return {
+                "answer": answer,
+                "tokens": {
+                    "prompt": prompt_tokens,
+                    "completion": completion_tokens,
+                    "total": total_tokens
+                }
+            }
+
         except Exception as e:
             st.error(f"❌ OpenAI error: {e}")
-            return
+            return {
+                "answer": "⚠️ OpenAI request failed.",
+                "tokens": {}
+            }
 
     else:
-        # Local TinyLlama path
-        prompt = grantbuddy_prompt.format(**input_dict)
         inputs = tokenizer(prompt, return_tensors="pt")
         outputs = model.generate(
             **inputs,
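After this hunk, `generate_response` returns a dict rather than a plain string, so callers read `result["answer"]` and `result["tokens"]`. A hypothetical in-app call (the input values are made up for illustration):

result = generate_response(
    {"context": "Retrieved grant text ...", "question": "What is the project budget?"},
    use_openai=True,
)
st.write(result["answer"])
if result["tokens"]:
    st.caption(f"Total tokens: {result['tokens'].get('total')}")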
@@ -385,17 +397,31 @@ def generate_response(input_dict, use_openai=False):
             pad_token_id=tokenizer.eos_token_id
         )
         decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        return
+        return {
+            "answer": decoded[len(prompt):].strip(),
+            "tokens": {}
+        }
 
 
 
 
 # =================== RAG Chain ===================
 def get_rag_chain(retriever, use_openai=False):
-
-
-    "
-
+    def merge_contexts(inputs):
+        retrieved_chunks = format_docs(retriever.invoke(inputs["question"]))
+        combined = "\n\n".join(filter(None, [
+            inputs.get("manual_context", ""),
+            retrieved_chunks
+        ]))
+        return {
+            "context": combined,
+            "question": inputs["question"]
+        }
+
+    return RunnableLambda(merge_contexts) | RunnableLambda(
+        lambda input_dict: generate_response(input_dict, use_openai=use_openai)
+    )
+
 
 # =================== Streamlit UI ===================
 def main():
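The rewritten `get_rag_chain` composes two `RunnableLambda` steps with `|`. A minimal, self-contained sketch of that composition, assuming `langchain-core` is installed; the toy lambdas stand in for `merge_contexts` and `generate_response`:

from langchain_core.runnables import RunnableLambda

prepare = RunnableLambda(lambda d: {"context": d.get("manual_context", ""), "question": d["question"]})
respond = RunnableLambda(lambda d: f"Answering {d['question']!r} with {len(d['context'])} chars of context")

chain = prepare | respond  # piping yields a RunnableSequence
print(chain.invoke({"question": "What is the budget?", "manual_context": "Our mission is ..."}))

Piping the runnables means `chain.invoke(...)` feeds the first step's output dict straight into the second, which is exactly how the retrieved chunks and manual context reach `generate_response` above.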
@@ -404,7 +430,8 @@ def main():
     USE_OPENAI = st.sidebar.checkbox("Use OpenAI (Costs Tokens)", value=False)
     if "generated_queries" not in st.session_state:
         st.session_state.generated_queries = {}
-
+
+    manual_context = st.text_area("📝 Optional: Add your own context (e.g., mission, goals)", height=150)
 
     retriever = init_vector_search().as_retriever(search_kwargs={"k": 10, "score_threshold": 0.75})
     rag_chain = get_rag_chain(retriever, use_openai=USE_OPENAI)
@@ -440,12 +467,18 @@ def main():
                 selected_questions.append(q)
         submit_button = st.form_submit_button("Submit")
 
+    # Multi-Select Question
     if 'submit_button' in locals() and submit_button:
         if selected_questions:
             with st.spinner("💡 Generating answers..."):
                 answers = []
                 for q in selected_questions:
-                    full_query = f"{q}\n\nAdditional context:\n{uploaded_text}"
+                    # full_query = f"{q}\n\nAdditional context:\n{uploaded_text}"
+                    combined_context = "\n\n".join(filter(None, [manual_context.strip(), uploaded_text.strip()]))
+                    response = rag_chain.invoke({
+                        "question": q,
+                        "manual_context": combined_context
+                    })
                     # response = rag_chain.invoke(full_query)
                     # answers.append({"question": q, "answer": response})
                     if q in st.session_state.generated_queries:
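`combined_context` is built with `"\n\n".join(filter(None, ...))`, where `filter(None, ...)` drops empty strings so a missing manual or uploaded context does not leave blank sections in the prompt. In isolation:

manual_context = "Our mission is to expand STEM access."
uploaded_text = ""
combined_context = "\n\n".join(filter(None, [manual_context.strip(), uploaded_text.strip()]))
print(repr(combined_context))  # 'Our mission is to expand STEM access.'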
@@ -456,29 +489,16 @@ def main():
                         answers.append({"question": q, "answer": response})
                 for item in answers:
                     st.markdown(f"### ❓ {item['question']}")
-                    st.markdown(f"💬 {item['answer']}")
+                    st.markdown(f"💬 {item['answer']['answer']}")
+                    tokens = item['answer'].get("tokens", {})
+                    if tokens:
+                        st.markdown(f"🧮 **Token Usage:** Prompt = {tokens.get('prompt')}, "
+                                    f"Completion = {tokens.get('completion')}, Total = {tokens.get('total')}")
+
         else:
             st.info("No prompts selected for answering.")
 
 
-
-        # #select prompts to answer
-        # selected_questions = st.multiselect("✅ Choose prompts to answer:", filtered_questions, default=filtered_questions)
-
-        # if selected_questions:
-        #     with st.spinner("💡 Generating answers..."):
-        #         answers = []
-        #         for q in selected_questions:
-        #             full_query = f"{q}\n\nAdditional context:\n{uploaded_text}"
-        #             response = rag_chain.invoke(full_query)
-        #             answers.append({"question": q, "answer": response})
-
-        # for item in answers:
-        #     st.markdown(f"### ❓ {item['question']}")
-        #     st.markdown(f"💬 {item['answer']}")
-        # else:
-        #     st.info("No prompts selected for answering.")
-
     # ✍️ Manual single-question input
     query = st.text_input("Ask a grant-related question")
     if st.button("Submit"):
@@ -486,13 +506,19 @@ def main():
            st.warning("Please enter a question.")
            return
 
-        full_query = f"{query}\n\nAdditional context:\n{uploaded_text}" if uploaded_text else query
+        # full_query = f"{query}\n\nAdditional context:\n{uploaded_text}" if uploaded_text else query
+        combined_context = "\n\n".join(filter(None, [manual_context.strip(), uploaded_text.strip()]))
         with st.spinner("🤖 Thinking..."):
-            response = rag_chain.invoke(full_query)
-
-
+            # response = rag_chain.invoke(full_query)
+            response = rag_chain.invoke({"question": query, "manual_context": combined_context})
+            st.text_area("Grant Buddy says:", value=response["answer"], height=250, disabled=True)
+            tokens = response.get("tokens", {})
+            if tokens:
+                st.markdown(f"🧮 **Token Usage:** Prompt = {tokens.get('prompt')}, "
+                            f"Completion = {tokens.get('completion')}, Total = {tokens.get('total')}")
+
        with st.expander("🔍 Retrieved Chunks"):
-            context_docs = retriever.get_relevant_documents(
+            context_docs = retriever.get_relevant_documents(query)
            for doc in context_docs:
                # st.json(doc.metadata)
                st.markdown(f"**Chunk ID:** {doc.metadata.get('chunk_id', 'unknown')} | **Title:** {doc.metadata['metadata'].get('title', 'unknown')}")
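One follow-up on the last hunk: the expander still uses `retriever.get_relevant_documents(query)`, while the new `merge_contexts` calls `retriever.invoke(...)`. Recent LangChain releases deprecate `get_relevant_documents` in favor of the Runnable interface, so the same lookup could be written as (a sketch, assuming a current `langchain-core`):

context_docs = retriever.invoke(query)  # same list of Documents via the newer Runnable API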