Spaces:
Runtime error
Runtime error
Update processing.py
Browse files- processing.py +25 -20
processing.py
CHANGED
@@ -1,13 +1,13 @@
|
|
1 |
-
from langchain.schema import HumanMessage
|
2 |
from output_parser import output_parser
|
3 |
from langchain_openai import OpenAIEmbeddings
|
4 |
from langchain_community.vectorstores import FAISS
|
5 |
from llm_loader import load_model
|
6 |
from config import openai_api_key
|
7 |
from langchain.chains import RetrievalQA, LLMChain
|
8 |
-
from langchain.retrievers import MultiQueryRetriever
|
9 |
-
from langchain.schema import Document
|
10 |
from langchain.prompts import PromptTemplate
|
|
|
11 |
import os
|
12 |
import json
|
13 |
|
@@ -43,22 +43,27 @@ text_retriever = text_faiss_index.as_retriever()
|
|
43 |
attachments_retriever = attachments_faiss_index.as_retriever()
|
44 |
personalities_retriever = personalities_faiss_index.as_retriever()
|
45 |
|
46 |
-
# Create a
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
)
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
# Create prompt template for query generation
|
64 |
prompt_template = PromptTemplate(
|
@@ -76,7 +81,7 @@ multi_query_retriever = MultiQueryRetriever(
|
|
76 |
parser_key="lines" # Assuming the LLM outputs queries line by line
|
77 |
)
|
78 |
|
79 |
-
#
|
80 |
qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=multi_query_retriever)
|
81 |
|
82 |
def load_text(file_path: str) -> str:
|
|
|
1 |
+
from langchain.schema import HumanMessage, BaseRetriever
|
2 |
from output_parser import output_parser
|
3 |
from langchain_openai import OpenAIEmbeddings
|
4 |
from langchain_community.vectorstores import FAISS
|
5 |
from llm_loader import load_model
|
6 |
from config import openai_api_key
|
7 |
from langchain.chains import RetrievalQA, LLMChain
|
8 |
+
from langchain.retrievers import MultiQueryRetriever
|
|
|
9 |
from langchain.prompts import PromptTemplate
|
10 |
+
from typing import List
|
11 |
import os
|
12 |
import json
|
13 |
|
|
|
43 |
attachments_retriever = attachments_faiss_index.as_retriever()
|
44 |
personalities_retriever = personalities_faiss_index.as_retriever()
|
45 |
|
46 |
+
# Create a custom combined retriever
class CombinedRetriever(BaseRetriever):
    """Retriever that fans one query out to several sub-retrievers and
    concatenates their results.

    Results keep the order of ``retrievers``; duplicates are not removed.
    """

    # BaseRetriever is a Pydantic model: attributes must be declared as
    # model fields.  The previous version assigned ``self.retrievers``
    # inside a hand-written __init__ without declaring the field or
    # calling super().__init__(), which raises at runtime.
    retrievers: List[BaseRetriever]

    def __init__(self, retrievers: List[BaseRetriever]):
        # Keep the original positional-argument constructor while routing
        # through Pydantic's validated initializer.
        super().__init__(retrievers=retrievers)

    def get_relevant_documents(self, query: str):
        """Synchronously query every sub-retriever and merge the hits."""
        combined_docs = []
        for retriever in self.retrievers:
            combined_docs.extend(retriever.get_relevant_documents(query))
        return combined_docs

    async def aget_relevant_documents(self, query: str):
        """Async variant: query each sub-retriever in turn and merge."""
        combined_docs = []
        for retriever in self.retrievers:
            combined_docs.extend(await retriever.aget_relevant_documents(query))
        return combined_docs

# Create an instance of the combined retriever
combined_retriever = CombinedRetriever([text_retriever, attachments_retriever, personalities_retriever])
|
67 |
|
68 |
# Create prompt template for query generation
|
69 |
prompt_template = PromptTemplate(
|
|
|
81 |
parser_key="lines" # Assuming the LLM outputs queries line by line
|
82 |
)
|
83 |
|
84 |
# Create QA chain with multi-query retriever; chain_type="stuff" concatenates
# all retrieved documents into a single prompt passed to the LLM.
qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=multi_query_retriever)
|
86 |
|
87 |
def load_text(file_path: str) -> str:
|