Update processing.py

processing.py  CHANGED  (+18 -15)
@@ -1,15 +1,17 @@
+# processing.py
+
 from langchain.schema import HumanMessage
 from output_parser import attachment_parser, bigfive_parser, personality_parser
-from
+from langchain_community.embeddings import OpenAIEmbeddings
 from langchain_community.vectorstores import FAISS
 from llm_loader import load_model
 from config import openai_api_key
 from langchain.chains import RetrievalQA
-import json
 import os
+import json

 # Initialize embeddings and FAISS index
-embedding_model = OpenAIEmbeddings(
+embedding_model = OpenAIEmbeddings(openai_api_key=openai_api_key)

 # Path to the knowledge files
 knowledge_files = {
@@ -28,18 +30,11 @@ for key, file_path in knowledge_files.items():
 # Create a FAISS index from the knowledge documents
 faiss_index = FAISS.from_texts(documents, embedding_model)

-#
-
-
-# If you want to load the FAISS index later, use this:
-# faiss_index = FAISS.load_local("faiss_index", embedding_model)
-
-# Load the LLM using llm_loader.py
-llm = load_model(openai_api_key)  # Load the model using your custom loader
+# Load the LLM
+llm = load_model(openai_api_key)

 # Initialize the retrieval chain
-
-qa_chain = RetrievalQA.from_llm(llm=llm, retriever=retriever)
+qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=faiss_index.as_retriever())

 def load_text(file_path: str) -> str:
     with open(file_path, 'r', encoding='utf-8') as file:
@@ -59,7 +54,15 @@ def process_task(llm, input_text: str, general_task: str, specific_task: str, ou

     # Perform retrieval to get the most relevant context
     relevant_docs = qa_chain({"query": truncated_input})
-
+
+    # Print the structure of relevant_docs for debugging
+    print("Structure of relevant_docs:", json.dumps(relevant_docs, indent=2, default=str))
+
+    # Extract the retrieved knowledge
+    if isinstance(relevant_docs, dict) and 'result' in relevant_docs:
+        retrieved_knowledge = relevant_docs['result']
+    else:
+        retrieved_knowledge = str(relevant_docs)

     # Combine the retrieved knowledge with the original prompt
     prompt = f"""{general_task}
@@ -111,4 +114,4 @@ def process_input(input_text: str, llm):
             results[speaker_id] = {}
         results[speaker_id][task_name] = speaker_result

-        return results
+    return results
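For reference, a minimal sketch of how the rebuilt chain can be exercised after this change. It assumes the llm, embedding_model, and faiss_index objects defined above in processing.py; the example query and the return_source_documents flag are illustrative and not part of the commit. The old RetrievalQA.from_llm(llm=llm, retriever=retriever) call referenced an undefined retriever, which looks like one source of the runtime error; from_chain_type with faiss_index.as_retriever() avoids that.

# Sketch only: exercising the RetrievalQA chain wired to the FAISS retriever.
# Assumes llm, embedding_model, and faiss_index from processing.py above.
from langchain.chains import RetrievalQA

qa_chain = RetrievalQA.from_chain_type(
    llm=llm,                               # model returned by load_model(openai_api_key)
    chain_type="stuff",                    # stuff retrieved chunks directly into the prompt
    retriever=faiss_index.as_retriever(),  # retriever backed by the FAISS index built above
    return_source_documents=True,          # optional: also return the retrieved chunks
)

response = qa_chain({"query": "Example question about the knowledge files."})
print(response["result"])                  # the chain's answer text
for doc in response.get("source_documents", []):
    print(doc.page_content[:100])          # preview of each retrieved chunk

With return_source_documents=True the response dict carries both a 'result' answer and the retrieved 'source_documents', which lines up with the 'result' key the updated process_task now checks for.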
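The commit also drops the commented-out FAISS.load_local hint. If persisting the index is still wanted, a rough sketch along those lines, assuming documents and embedding_model from processing.py; the "faiss_index" folder name is hypothetical.

# Sketch only: embed the knowledge files once and reuse the saved index.
# Assumes documents and embedding_model from processing.py.
import os
from langchain_community.vectorstores import FAISS

INDEX_DIR = "faiss_index"  # hypothetical folder name

if os.path.isdir(INDEX_DIR):
    # Recent langchain-community releases may also require
    # allow_dangerous_deserialization=True here, since the index is pickled.
    faiss_index = FAISS.load_local(INDEX_DIR, embedding_model)
else:
    faiss_index = FAISS.from_texts(documents, embedding_model)
    faiss_index.save_local(INDEX_DIR)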