Spaces:
Runtime error
Runtime error
Update processing.py
Browse files- processing.py +19 -4
processing.py
CHANGED
@@ -5,7 +5,8 @@ from langchain_community.vectorstores import FAISS
|
|
5 |
from llm_loader import load_model
|
6 |
from config import openai_api_key
|
7 |
from langchain.chains import RetrievalQA, LLMChain
|
8 |
-
from langchain.retrievers import MultiQueryRetriever
|
|
|
9 |
from langchain.prompts import PromptTemplate
|
10 |
import os
|
11 |
import json
|
@@ -42,8 +43,22 @@ text_retriever = text_faiss_index.as_retriever()
|
|
42 |
attachments_retriever = attachments_faiss_index.as_retriever()
|
43 |
personalities_retriever = personalities_faiss_index.as_retriever()
|
44 |
|
45 |
-
# Create a
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
# Create prompt template for query generation
|
49 |
prompt_template = PromptTemplate(
|
@@ -61,7 +76,7 @@ multi_query_retriever = MultiQueryRetriever(
|
|
61 |
parser_key="lines" # Assuming the LLM outputs queries line by line
|
62 |
)
|
63 |
|
64 |
-
#
|
65 |
qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=multi_query_retriever)
|
66 |
|
67 |
def load_text(file_path: str) -> str:
|
|
|
5 |
from llm_loader import load_model
|
6 |
from config import openai_api_key
|
7 |
from langchain.chains import RetrievalQA, LLMChain
|
8 |
+
from langchain.retrievers import MultiQueryRetriever, MultiVectorRetriever
|
9 |
+
from langchain.schema import Document
|
10 |
from langchain.prompts import PromptTemplate
|
11 |
import os
|
12 |
import json
|
|
|
43 |
attachments_retriever = attachments_faiss_index.as_retriever()
|
44 |
personalities_retriever = personalities_faiss_index.as_retriever()
|
45 |
|
46 |
+
# Create a list of all retrievers
|
47 |
+
all_retrievers = [text_retriever, attachments_retriever, personalities_retriever]
|
48 |
+
|
49 |
+
# Create a function to combine results from all retrievers
|
50 |
+
def combined_retriever_function(query):
|
51 |
+
combined_docs = []
|
52 |
+
for retriever in all_retrievers:
|
53 |
+
docs = retriever.get_relevant_documents(query)
|
54 |
+
combined_docs.extend(docs)
|
55 |
+
return combined_docs
|
56 |
+
|
57 |
+
# Create a MultiVectorRetriever
|
58 |
+
combined_retriever = MultiVectorRetriever(
|
59 |
+
retrievers=all_retrievers,
|
60 |
+
retriever_query=combined_retriever_function
|
61 |
+
)
|
62 |
|
63 |
# Create prompt template for query generation
|
64 |
prompt_template = PromptTemplate(
|
|
|
76 |
parser_key="lines" # Assuming the LLM outputs queries line by line
|
77 |
)
|
78 |
|
79 |
+
# Use the multi-query retriever in the QA chain
|
80 |
qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=multi_query_retriever)
|
81 |
|
82 |
def load_text(file_path: str) -> str:
|