reab5555 commited on
Commit
f9eef9b
·
verified ·
1 Parent(s): c8ea22d

Update processing.py

Browse files
Files changed (1) hide show
  1. processing.py +25 -20
processing.py CHANGED
@@ -1,13 +1,13 @@
1
- from langchain.schema import HumanMessage
2
  from output_parser import output_parser
3
  from langchain_openai import OpenAIEmbeddings
4
  from langchain_community.vectorstores import FAISS
5
  from llm_loader import load_model
6
  from config import openai_api_key
7
  from langchain.chains import RetrievalQA, LLMChain
8
- from langchain.retrievers import MultiQueryRetriever, MultiVectorRetriever
9
- from langchain.schema import Document
10
  from langchain.prompts import PromptTemplate
 
11
  import os
12
  import json
13
 
@@ -43,22 +43,27 @@ text_retriever = text_faiss_index.as_retriever()
43
  attachments_retriever = attachments_faiss_index.as_retriever()
44
  personalities_retriever = personalities_faiss_index.as_retriever()
45
 
46
- # Create a list of all retrievers
47
- all_retrievers = [text_retriever, attachments_retriever, personalities_retriever]
48
-
49
- # Create a function to combine results from all retrievers
50
- def combined_retriever_function(query):
51
- combined_docs = []
52
- for retriever in all_retrievers:
53
- docs = retriever.get_relevant_documents(query)
54
- combined_docs.extend(docs)
55
- return combined_docs
56
-
57
- # Create a MultiVectorRetriever
58
- combined_retriever = MultiVectorRetriever(
59
- retrievers=all_retrievers,
60
- retriever_query=combined_retriever_function
61
- )
 
 
 
 
 
62
 
63
  # Create prompt template for query generation
64
  prompt_template = PromptTemplate(
@@ -76,7 +81,7 @@ multi_query_retriever = MultiQueryRetriever(
76
  parser_key="lines" # Assuming the LLM outputs queries line by line
77
  )
78
 
79
- # Use the multi-query retriever in the QA chain
80
  qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=multi_query_retriever)
81
 
82
  def load_text(file_path: str) -> str:
 
1
+ from langchain.schema import HumanMessage, BaseRetriever
2
  from output_parser import output_parser
3
  from langchain_openai import OpenAIEmbeddings
4
  from langchain_community.vectorstores import FAISS
5
  from llm_loader import load_model
6
  from config import openai_api_key
7
  from langchain.chains import RetrievalQA, LLMChain
8
+ from langchain.retrievers import MultiQueryRetriever
 
9
  from langchain.prompts import PromptTemplate
10
+ from typing import List
11
  import os
12
  import json
13
 
 
43
  attachments_retriever = attachments_faiss_index.as_retriever()
44
  personalities_retriever = personalities_faiss_index.as_retriever()
45
 
46
+ # Create a custom combined retriever
47
+ class CombinedRetriever(BaseRetriever):
48
+ def __init__(self, retrievers: List[BaseRetriever]):
49
+ self.retrievers = retrievers
50
+
51
+ def get_relevant_documents(self, query: str):
52
+ combined_docs = []
53
+ for retriever in self.retrievers:
54
+ docs = retriever.get_relevant_documents(query)
55
+ combined_docs.extend(docs)
56
+ return combined_docs
57
+
58
+ async def aget_relevant_documents(self, query: str):
59
+ combined_docs = []
60
+ for retriever in self.retrievers:
61
+ docs = await retriever.aget_relevant_documents(query)
62
+ combined_docs.extend(docs)
63
+ return combined_docs
64
+
65
+ # Create an instance of the combined retriever
66
+ combined_retriever = CombinedRetriever([text_retriever, attachments_retriever, personalities_retriever])
67
 
68
  # Create prompt template for query generation
69
  prompt_template = PromptTemplate(
 
81
  parser_key="lines" # Assuming the LLM outputs queries line by line
82
  )
83
 
84
+ # Create QA chain with multi-query retriever
85
  qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=multi_query_retriever)
86
 
87
  def load_text(file_path: str) -> str: