Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -161,30 +161,60 @@ def extract_with_llm(text: str) -> List[str]:
|
|
161 |
st.error(str(e))
|
162 |
return []
|
163 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
|
165 |
# =================== Format Retrieved Chunks ===================
|
166 |
def format_docs(docs: List[Document]) -> str:
    """Collapse retrieved documents into one context string.

    Each document contributes its page_content, falling back to the
    'content' metadata field when page_content is empty; chunks are
    separated by a blank line.
    """
    parts = []
    for doc in docs:
        text = doc.page_content or doc.metadata.get("content", "")
        parts.append(text)
    return "\n\n".join(parts)
|
168 |
|
169 |
# =================== Generate Response from Hugging Face Model ===================
|
170 |
-
def generate_response(input_dict: Dict[str, Any]) -> str:
    """Answer the user's question via the hosted zephyr-7b-beta model.

    The formatted grantbuddy prompt is sent as the system message and the
    raw question as the user message. On any API failure the error is
    surfaced in the Streamlit UI and a fallback string is returned.
    """
    prompt = grantbuddy_prompt.format(**input_dict)
    chat_messages = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": input_dict["question"]},
    ]
    try:
        completion = client.chat.completions.create(
            model="HuggingFaceH4/zephyr-7b-beta",
            messages=chat_messages,
            max_tokens=1000,
            temperature=0.2,
        )
        return completion.choices[0].message.content
    except Exception as exc:
        st.error(f"β Error from model: {exc}")
        return "β οΈ Failed to generate response. Please check your model, HF token, or request format."
|
188 |
|
189 |
|
190 |
# =================== RAG Chain ===================
|
@@ -199,13 +229,12 @@ def main():
|
|
199 |
st.set_page_config(page_title="Grant Buddy RAG", page_icon="π€")
|
200 |
st.title("π€ Grant Buddy: Grant-Writing Assistant")
|
201 |
|
|
|
|
|
|
|
202 |
uploaded_file = st.file_uploader("Upload PDF or TXT for extra context (optional)", type=["pdf", "txt"])
|
203 |
uploaded_text = ""
|
204 |
|
205 |
-
retriever = init_vector_search().as_retriever(search_kwargs={"k": 10, "score_threshold": 0.75})
|
206 |
-
rag_chain = get_rag_chain(retriever) # β
Initialize before usage
|
207 |
-
|
208 |
-
# π Process uploaded file
|
209 |
if uploaded_file:
|
210 |
with st.spinner("π Processing uploaded file..."):
|
211 |
if uploaded_file.name.endswith(".pdf"):
|
@@ -214,26 +243,40 @@ def main():
|
|
214 |
elif uploaded_file.name.endswith(".txt"):
|
215 |
uploaded_text = uploaded_file.read().decode("utf-8")
|
216 |
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
221 |
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
answers
|
|
|
|
|
|
|
|
|
228 |
|
229 |
for item in answers:
|
230 |
st.markdown(f"### β {item['question']}")
|
231 |
st.markdown(f"π¬ {item['answer']}")
|
|
|
|
|
232 |
|
233 |
-
#
|
234 |
query = st.text_input("Ask a grant-related question")
|
235 |
if st.button("Submit"):
|
236 |
-
if not query
|
237 |
st.warning("Please enter a question.")
|
238 |
return
|
239 |
|
@@ -245,13 +288,14 @@ def main():
|
|
245 |
with st.expander("π Retrieved Chunks"):
|
246 |
context_docs = retriever.get_relevant_documents(full_query)
|
247 |
for doc in context_docs:
|
248 |
-
st.markdown(f"**Chunk ID:** {doc.metadata.get('chunk_id', 'unknown'
|
249 |
st.markdown(doc.page_content[:700] + "...")
|
250 |
st.markdown("---")
|
251 |
|
252 |
|
253 |
|
254 |
|
|
|
255 |
if __name__ == "__main__":
|
256 |
main()
|
257 |
|
|
|
161 |
st.error(str(e))
|
162 |
return []
|
163 |
|
164 |
+
# def is_meaningful_prompt(text: str) -> bool:
|
165 |
+
# too_short = len(text.strip()) < 10
|
166 |
+
# banned_keywords = ["phone", "email", "fax", "address", "date", "contact", "website"]
|
167 |
+
# contains_bad_word = any(word in text.lower() for word in banned_keywords)
|
168 |
+
# is_just_punctuation = all(c in ":.*- " for c in text.strip())
|
169 |
+
|
170 |
+
# return not (too_short or contains_bad_word or is_just_punctuation)
|
171 |
|
172 |
# =================== Format Retrieved Chunks ===================
|
173 |
def format_docs(docs: List[Document]) -> str:
|
174 |
return "\n\n".join(doc.page_content or doc.metadata.get("content", "") for doc in docs)
|
175 |
|
176 |
# =================== Generate Response from Hugging Face Model ===================
|
177 |
+
# def generate_response(input_dict: Dict[str, Any]) -> str:
|
178 |
+
# client = InferenceClient(api_key=HF_TOKEN.strip())
|
179 |
+
# prompt = grantbuddy_prompt.format(**input_dict)
|
180 |
+
|
181 |
+
# try:
|
182 |
+
# response = client.chat.completions.create(
|
183 |
+
# model="HuggingFaceH4/zephyr-7b-beta",
|
184 |
+
# messages=[
|
185 |
+
# {"role": "system", "content": prompt},
|
186 |
+
# {"role": "user", "content": input_dict["question"]},
|
187 |
+
# ],
|
188 |
+
# max_tokens=1000,
|
189 |
+
# temperature=0.2,
|
190 |
+
# )
|
191 |
+
# return response.choices[0].message.content
|
192 |
+
# except Exception as e:
|
193 |
+
# st.error(f"β Error from model: {e}")
|
194 |
+
# return "β οΈ Failed to generate response. Please check your model, HF token, or request format."
|
195 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
196 |
+
import torch
|
197 |
+
|
198 |
+
@st.cache_resource
def load_local_model():
    """Load the TinyLlama chat model and its tokenizer.

    Decorated with st.cache_resource so Streamlit reruns reuse the same
    in-memory model instead of reloading the weights each time.
    """
    checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    tok = AutoTokenizer.from_pretrained(checkpoint)
    lm = AutoModelForCausalLM.from_pretrained(checkpoint)
    return tok, lm


tokenizer, model = load_local_model()
|
206 |
+
|
207 |
+
def generate_response(input_dict: Dict[str, Any]) -> str:
    """Generate an answer with the locally loaded TinyLlama model.

    Args:
        input_dict: Mapping of the fields consumed by
            ``grantbuddy_prompt.format`` (must contain whatever keys the
            prompt template expects).

    Returns:
        The decoded model output with everything up to the last
        "QUESTION:" marker stripped off.
    """
    prompt = grantbuddy_prompt.format(**input_dict)
    inputs = tokenizer(prompt, return_tensors="pt")
    # Inference only: no_grad avoids building the autograd graph,
    # reducing memory use during generation.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            do_sample=True,
        )
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # NOTE(review): assumes grantbuddy_prompt contains a "QUESTION:" section
    # so the echoed prompt is stripped; if the template lacks that marker the
    # full prompt is returned verbatim — confirm against the template.
    return decoded.split("QUESTION:")[-1].strip()
|
217 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
218 |
|
219 |
|
220 |
# =================== RAG Chain ===================
|
|
|
229 |
st.set_page_config(page_title="Grant Buddy RAG", page_icon="π€")
|
230 |
st.title("π€ Grant Buddy: Grant-Writing Assistant")
|
231 |
|
232 |
+
retriever = init_vector_search().as_retriever(search_kwargs={"k": 10, "score_threshold": 0.75})
|
233 |
+
rag_chain = get_rag_chain(retriever)
|
234 |
+
|
235 |
uploaded_file = st.file_uploader("Upload PDF or TXT for extra context (optional)", type=["pdf", "txt"])
|
236 |
uploaded_text = ""
|
237 |
|
|
|
|
|
|
|
|
|
238 |
if uploaded_file:
|
239 |
with st.spinner("π Processing uploaded file..."):
|
240 |
if uploaded_file.name.endswith(".pdf"):
|
|
|
243 |
elif uploaded_file.name.endswith(".txt"):
|
244 |
uploaded_text = uploaded_file.read().decode("utf-8")
|
245 |
|
246 |
+
# π§ Extract prompts using LLM
|
247 |
+
questions = extract_with_llm(uploaded_text)
|
248 |
+
|
249 |
+
# π« Filter out irrelevant junk
|
250 |
+
def is_meaningful_prompt(text: str) -> bool:
    """Heuristic filter: keep only extracted prompts that look like real questions.

    Rejects strings that are too short, mention contact-info keywords,
    or consist purely of separator punctuation.
    """
    stripped = text.strip()
    if len(stripped) < 10:
        return False
    lowered = text.lower()
    for keyword in ("phone", "email", "fax", "address", "date", "contact", "website"):
        if keyword in lowered:
            return False
    # Reject lines made only of separator characters (e.g. "-----", "* * *").
    return not all(c in ":.*- " for c in stripped)
|
256 |
+
|
257 |
+
filtered_questions = [q for q in questions if is_meaningful_prompt(q)]
|
258 |
|
259 |
+
# π― Prompt selection UI
|
260 |
+
selected_questions = st.multiselect("β
Choose prompts to answer:", filtered_questions, default=filtered_questions)
|
261 |
+
|
262 |
+
if selected_questions:
|
263 |
+
with st.spinner("π‘ Generating answers..."):
|
264 |
+
answers = []
|
265 |
+
for q in selected_questions:
|
266 |
+
full_query = f"{q}\n\nAdditional context:\n{uploaded_text}"
|
267 |
+
response = rag_chain.invoke(full_query)
|
268 |
+
answers.append({"question": q, "answer": response})
|
269 |
|
270 |
for item in answers:
|
271 |
st.markdown(f"### β {item['question']}")
|
272 |
st.markdown(f"π¬ {item['answer']}")
|
273 |
+
else:
|
274 |
+
st.info("No prompts selected for answering.")
|
275 |
|
276 |
+
# βοΈ Manual single-question input
|
277 |
query = st.text_input("Ask a grant-related question")
|
278 |
if st.button("Submit"):
|
279 |
+
if not query:
|
280 |
st.warning("Please enter a question.")
|
281 |
return
|
282 |
|
|
|
288 |
with st.expander("π Retrieved Chunks"):
|
289 |
context_docs = retriever.get_relevant_documents(full_query)
|
290 |
for doc in context_docs:
|
291 |
+
st.markdown(f"**Chunk ID:** {doc.metadata.get('chunk_id', 'unknown')} | **Title:** {doc.metadata.get('title', 'unknown')}")
|
292 |
st.markdown(doc.page_content[:700] + "...")
|
293 |
st.markdown("---")
|
294 |
|
295 |
|
296 |
|
297 |
|
298 |
+
|
299 |
if __name__ == "__main__":
|
300 |
main()
|
301 |
|