Spaces: Runtime error
Update processing.py
processing.py  CHANGED  (+21 -7)
@@ -4,8 +4,9 @@ from langchain_openai import OpenAIEmbeddings
 from langchain_community.vectorstores import FAISS
 from llm_loader import load_model
 from config import openai_api_key
-from langchain.chains import RetrievalQA
+from langchain.chains import RetrievalQA, LLMChain
 from langchain.retrievers import MultiQueryRetriever
+from langchain.prompts import PromptTemplate
 import os
 import json
 
@@ -41,14 +42,27 @@ text_retriever = text_faiss_index.as_retriever()
 attachments_retriever = attachments_faiss_index.as_retriever()
 personalities_retriever = personalities_faiss_index.as_retriever()
 
-#
-combined_retriever =
-
-
+# Create a combined retriever
+combined_retriever = text_retriever.add_retrievers([attachments_retriever, personalities_retriever])
+
+# Create prompt template for query generation
+prompt_template = PromptTemplate(
+    input_variables=["question"],
+    template="Generate multiple search queries for the following question: {question}"
+)
+
+# Create LLM chain for query generation
+llm_chain = LLMChain(llm=llm, prompt=prompt_template)
+
+# Initialize MultiQueryRetriever
+multi_query_retriever = MultiQueryRetriever(
+    retriever=combined_retriever,
+    llm_chain=llm_chain,
+    parser_key="lines"  # Assuming the LLM outputs queries line by line
 )
 
-# Create QA chain with
-qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=
+# Create QA chain with multi-query retriever
+qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=multi_query_retriever)
 
 def load_text(file_path: str) -> str:
     with open(file_path, 'r', encoding='utf-8') as file:
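
A note on the combined retriever added at new line 46: the object returned by FAISS's as_retriever() is a plain LangChain VectorStoreRetriever, which does not expose an add_retrievers method, so that line is a plausible cause of the Space's runtime error. One way to merge the three retrievers instead is MergerRetriever from langchain.retrievers, assuming a LangChain release that ships it; this is only a sketch, with the variable names taken from the diff:

    from langchain.retrievers import MergerRetriever

    # Fan each query out to all three FAISS retrievers and concatenate
    # their results; intended as a drop-in replacement for the
    # text_retriever.add_retrievers(...) call in the diff.
    combined_retriever = MergerRetriever(
        retrievers=[text_retriever, attachments_retriever, personalities_retriever]
    )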
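
The MultiQueryRetriever construction with llm_chain and parser_key follows LangChain's older custom-prompt pattern, but in the documented version of that pattern the LLMChain also carries an output parser that splits the model's text into a list of queries, which this commit omits. If a custom prompt is not essential, the from_llm factory is a simpler and more robust alternative, since it builds the query-generation prompt, chain, and output parsing internally. A minimal sketch, assuming the same llm and the combined_retriever from above:

    from langchain.retrievers import MultiQueryRetriever

    # from_llm wires up prompt, LLM chain, and line-based query parsing,
    # so no PromptTemplate/LLMChain plumbing is needed here.
    multi_query_retriever = MultiQueryRetriever.from_llm(
        retriever=combined_retriever,
        llm=llm,
    )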
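
For a quick smoke test of the assembled qa_chain, a call like the following exercises the whole retrieve-then-answer pipeline; the question is made up for illustration, and RetrievalQA expects its input under the "query" key and returns the answer under "result":

    # Hypothetical question, just to exercise retriever -> LLM end to end.
    response = qa_chain({"query": "Which attachments mention the onboarding checklist?"})
    print(response["result"])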