reab5555 committed
Commit c8ea22d · verified · 1 Parent(s): bac36b6

Update processing.py

Files changed (1)
  1. processing.py +19 -4
processing.py CHANGED
@@ -5,7 +5,8 @@ from langchain_community.vectorstores import FAISS
 from llm_loader import load_model
 from config import openai_api_key
 from langchain.chains import RetrievalQA, LLMChain
-from langchain.retrievers import MultiQueryRetriever
+from langchain.retrievers import MultiQueryRetriever, MultiVectorRetriever
+from langchain.schema import Document
 from langchain.prompts import PromptTemplate
 import os
 import json
@@ -42,8 +43,22 @@ text_retriever = text_faiss_index.as_retriever()
 attachments_retriever = attachments_faiss_index.as_retriever()
 personalities_retriever = personalities_faiss_index.as_retriever()
 
-# Create a combined retriever
-combined_retriever = text_retriever.add_retrievers([attachments_retriever, personalities_retriever])
+# Create a list of all retrievers
+all_retrievers = [text_retriever, attachments_retriever, personalities_retriever]
+
+# Create a function to combine results from all retrievers
+def combined_retriever_function(query):
+    combined_docs = []
+    for retriever in all_retrievers:
+        docs = retriever.get_relevant_documents(query)
+        combined_docs.extend(docs)
+    return combined_docs
+
+# Create a MultiVectorRetriever
+combined_retriever = MultiVectorRetriever(
+    retrievers=all_retrievers,
+    retriever_query=combined_retriever_function
+)
 
 # Create prompt template for query generation
 prompt_template = PromptTemplate(
@@ -61,7 +76,7 @@ multi_query_retriever = MultiQueryRetriever(
     parser_key="lines"  # Assuming the LLM outputs queries line by line
 )
 
-# Create QA chain with multi-query retriever
+# Use the multi-query retriever in the QA chain
 qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=multi_query_retriever)
 
 def load_text(file_path: str) -> str:
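
Note on the combined retriever: as far as I can tell, MultiVectorRetriever in langchain.retrievers is built around a single vector store plus a document store (vectorstore=, docstore=, id_key=) rather than a list of retrievers, so the retrievers=/retriever_query= call added above may not run as written. Below is a minimal sketch, not part of this commit, of one alternative that merges results from the three FAISS retrievers using EnsembleRetriever, which does accept a list of retrievers; the weights shown are hypothetical.

# Sketch only (assumes text_retriever, attachments_retriever and
# personalities_retriever are defined as earlier in processing.py).
from langchain.retrievers import EnsembleRetriever

combined_retriever = EnsembleRetriever(
    retrievers=[text_retriever, attachments_retriever, personalities_retriever],
    weights=[0.4, 0.3, 0.3],  # hypothetical weights; roughly equal weighting also works
)

# Merged, re-ranked documents from all three indexes for a single query.
docs = combined_retriever.get_relevant_documents("example query")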
 
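For the multi-query step, the last hunk keeps a direct MultiQueryRetriever(...) construction with a parser_key argument. A minimal sketch of the more common MultiQueryRetriever.from_llm constructor, wired into the same RetrievalQA chain (assuming llm and a combined_retriever are defined as in this file):

# Sketch only: generate query variants with the LLM, retrieve with the
# combined retriever, and answer with a "stuff" RetrievalQA chain.
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.chains import RetrievalQA

multi_query_retriever = MultiQueryRetriever.from_llm(
    retriever=combined_retriever,  # assumes the combined retriever defined above
    llm=llm,
)

qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=multi_query_retriever,
)
answer = qa_chain.run("example question")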