Update app.py
app.py CHANGED
@@ -1,5 +1,6 @@
 import os
 import json
+import re
 import gradio as gr
 import pandas as pd
 from tempfile import NamedTemporaryFile
@@ -13,6 +14,8 @@ from langchain_text_splitters import RecursiveCharacterTextSplitter
 from langchain_community.llms import HuggingFaceHub
 from langchain_core.runnables import RunnableParallel, RunnablePassthrough
 from langchain_core.documents import Document
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.metrics.pairwise import cosine_similarity
 
 huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
 
@@ -58,15 +61,24 @@ def clear_cache():
     else:
         return "No cache to clear."
 
+def get_similarity(text1, text2):
+    vectorizer = TfidfVectorizer().fit_transform([text1, text2])
+    return cosine_similarity(vectorizer[0:1], vectorizer[1:2])[0][0]
+
 prompt = """
-Answer the question based on the following
+Answer the question based on the following information:
+
 Conversation History:
 {history}
 
 Context from documents:
 {context}
 
-Question: {question}
+Current Question: {question}
+
+If the question is referring to the conversation history, use that information to answer.
+If the question is not related to the conversation history, use the context from documents to answer.
+If you don't have enough information to answer, say so.
 
 Provide a concise and direct answer to the question:
 """
@@ -100,6 +112,13 @@ def manage_conversation_history(question, answer, history, max_history=5):
         history.pop(0)
     return history
 
+def is_related_to_history(question, history, threshold=0.3):
+    if not history:
+        return False
+    history_text = " ".join([f"{h['question']} {h['answer']}" for h in history])
+    similarity = get_similarity(question, history_text)
+    return similarity > threshold
+
 def ask_question(question, temperature, top_p, repetition_penalty):
     global conversation_history
 
@@ -114,14 +133,19 @@ def ask_question(question, temperature, top_p, repetition_penalty):
     model = get_model(temperature, top_p, repetition_penalty)
 
     history_str = "\n".join([f"Q: {item['question']}\nA: {item['answer']}" for item in conversation_history])
+
+    if is_related_to_history(question, conversation_history):
+        context_str = "No additional context needed. Please refer to the conversation history."
+    else:
+        retriever = database.as_retriever()
+        relevant_docs = retriever.get_relevant_documents(question)
+        context_str = "\n".join([doc.page_content for doc in relevant_docs])
+
     prompt_val = ChatPromptTemplate.from_template(prompt)
-    retriever = database.as_retriever()
-    relevant_docs = retriever.get_relevant_documents(question)
-    context_str = "\n".join([doc.page_content for doc in relevant_docs])
     formatted_prompt = prompt_val.format(history=history_str, context=context_str, question=question)
 
     answer = generate_chunked_response(model, formatted_prompt)
-    answer =
+    answer = re.split(r'Question:|Current Question:', answer)[-1].strip()
 
     memory_database[question] = answer
 
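For readers who want to try the new routing logic outside the Space, here is a minimal, self-contained sketch of the TF-IDF similarity gate this commit introduces. The two helpers are copied from app.py above; the one-turn sample history, the example questions, and the reliance on the default 0.3 threshold are illustrative assumptions, not part of the commit.

# Minimal sketch: exercising the similarity gate from this commit in isolation.
# Helpers copied from app.py; the sample history and questions are made up.
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

def get_similarity(text1, text2):
    # Fit TF-IDF on the two texts and return their cosine similarity (0.0 to 1.0).
    vectorizer = TfidfVectorizer().fit_transform([text1, text2])
    return cosine_similarity(vectorizer[0:1], vectorizer[1:2])[0][0]

def is_related_to_history(question, history, threshold=0.3):
    # Treat the question as a follow-up when it is lexically close to the
    # concatenated Q/A pairs already stored in the conversation history.
    if not history:
        return False
    history_text = " ".join([f"{h['question']} {h['answer']}" for h in history])
    similarity = get_similarity(question, history_text)
    return similarity > threshold

history = [{"question": "What is the refund policy?",
            "answer": "Refunds are issued within 30 days of purchase."}]

# A follow-up that reuses wording from the history should clear the threshold,
# so ask_question would skip retrieval and answer from the history.
print(is_related_to_history("What is the refund policy for purchases?", history))  # expected: True

# An unrelated question has little lexical overlap, so the document retriever is used instead.
print(is_related_to_history("Who is the CEO of the company?", history))  # expected: False

Note that the gate is purely lexical: a paraphrased follow-up with little word overlap can still fall below the threshold and trigger document retrieval, and the threshold argument controls how aggressively the history path is preferred.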