Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -350,7 +350,7 @@ def load_local_model():
|
|
350 |
|
351 |
tokenizer, model = load_local_model()
|
352 |
|
353 |
-
def generate_response(input_dict, use_openai=False):
|
354 |
prompt = grantbuddy_prompt.format(**input_dict)
|
355 |
|
356 |
if use_openai:
|
@@ -419,7 +419,7 @@ def get_rag_chain(retriever, use_openai=False):
|
|
419 |
}
|
420 |
|
421 |
return RunnableLambda(merge_contexts) | RunnableLambda(
|
422 |
-
lambda input_dict: generate_response(input_dict, use_openai=use_openai)
|
423 |
)
|
424 |
|
425 |
|
@@ -428,13 +428,21 @@ def main():
|
|
428 |
# st.set_page_config(page_title="Grant Buddy RAG", page_icon="🤖")
|
429 |
st.title("🤖 Grant Buddy: Grant-Writing Assistant")
|
430 |
USE_OPENAI = st.sidebar.checkbox("Use OpenAI (Costs Tokens)", value=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
431 |
if "generated_queries" not in st.session_state:
|
432 |
st.session_state.generated_queries = {}
|
433 |
|
434 |
manual_context = st.text_area("📝 Optional: Add your own context (e.g., mission, goals)", height=150)
|
435 |
|
436 |
-
retriever = init_vector_search().as_retriever(search_kwargs={"k":
|
437 |
-
rag_chain = get_rag_chain(retriever, use_openai=USE_OPENAI)
|
438 |
|
439 |
uploaded_file = st.file_uploader("Upload PDF or TXT for extra context (optional)", type=["pdf", "txt"])
|
440 |
uploaded_text = ""
|
@@ -475,16 +483,13 @@ def main():
|
|
475 |
for q in selected_questions:
|
476 |
# full_query = f"{q}\n\nAdditional context:\n{uploaded_text}"
|
477 |
combined_context = "\n\n".join(filter(None, [manual_context.strip(), uploaded_text.strip()]))
|
478 |
-
response = rag_chain.invoke({
|
479 |
-
"question": q,
|
480 |
-
"manual_context": combined_context
|
481 |
-
})
|
482 |
-
# response = rag_chain.invoke(full_query)
|
483 |
-
# answers.append({"question": q, "answer": response})
|
484 |
if q in st.session_state.generated_queries:
|
485 |
response = st.session_state.generated_queries[q]
|
486 |
else:
|
487 |
-
response = rag_chain.invoke(
|
|
|
|
|
|
|
488 |
st.session_state.generated_queries[q] = response
|
489 |
answers.append({"question": q, "answer": response})
|
490 |
for item in answers:
|
|
|
350 |
|
351 |
tokenizer, model = load_local_model()
|
352 |
|
353 |
+
def generate_response(input_dict, use_openai=False, max_tokens=700):
|
354 |
prompt = grantbuddy_prompt.format(**input_dict)
|
355 |
|
356 |
if use_openai:
|
|
|
419 |
}
|
420 |
|
421 |
return RunnableLambda(merge_contexts) | RunnableLambda(
|
422 |
+
lambda input_dict: generate_response(input_dict, use_openai=use_openai, max_tokens=max_tokens)
|
423 |
)
|
424 |
|
425 |
|
|
|
428 |
# st.set_page_config(page_title="Grant Buddy RAG", page_icon="🤖")
|
429 |
st.title("🤖 Grant Buddy: Grant-Writing Assistant")
|
430 |
USE_OPENAI = st.sidebar.checkbox("Use OpenAI (Costs Tokens)", value=False)
|
431 |
+
st.sidebar.markdown("### Retrieval Settings")
|
432 |
+
|
433 |
+
k_value = st.sidebar.slider("How many chunks to retrieve (k)", min_value=5, max_value=40, step=5, value=10)
|
434 |
+
score_threshold = st.sidebar.slider("Minimum relevance score", min_value=0.0, max_value=1.0, step=0.05, value=0.75)
|
435 |
+
|
436 |
+
st.sidebar.markdown("### Generation Settings")
|
437 |
+
max_tokens = st.sidebar.number_input("Max tokens in response", min_value=100, max_value=1500, value=700, step=50)
|
438 |
+
|
439 |
if "generated_queries" not in st.session_state:
|
440 |
st.session_state.generated_queries = {}
|
441 |
|
442 |
manual_context = st.text_area("📝 Optional: Add your own context (e.g., mission, goals)", height=150)
|
443 |
|
444 |
+
retriever = init_vector_search().as_retriever(search_kwargs={"k": k_value, "score_threshold": score_threshold})
|
445 |
+
rag_chain = get_rag_chain(retriever, use_openai=USE_OPENAI, max_tokens=max_tokens)
|
446 |
|
447 |
uploaded_file = st.file_uploader("Upload PDF or TXT for extra context (optional)", type=["pdf", "txt"])
|
448 |
uploaded_text = ""
|
|
|
483 |
for q in selected_questions:
|
484 |
# full_query = f"{q}\n\nAdditional context:\n{uploaded_text}"
|
485 |
combined_context = "\n\n".join(filter(None, [manual_context.strip(), uploaded_text.strip()]))
|
|
|
|
|
|
|
|
|
|
|
|
|
486 |
if q in st.session_state.generated_queries:
|
487 |
response = st.session_state.generated_queries[q]
|
488 |
else:
|
489 |
+
response = rag_chain.invoke({
|
490 |
+
"question": q,
|
491 |
+
"manual_context": combined_context
|
492 |
+
})
|
493 |
st.session_state.generated_queries[q] = response
|
494 |
answers.append({"question": q, "answer": response})
|
495 |
for item in answers:
|