reab5555 commited on
Commit
bac36b6
·
verified ·
1 Parent(s): 049120c

Update processing.py

Browse files
Files changed (1) hide show
  1. processing.py +21 -7
processing.py CHANGED
@@ -4,8 +4,9 @@ from langchain_openai import OpenAIEmbeddings
4
  from langchain_community.vectorstores import FAISS
5
  from llm_loader import load_model
6
  from config import openai_api_key
7
- from langchain.chains import RetrievalQA
8
  from langchain.retrievers import MultiQueryRetriever
 
9
  import os
10
  import json
11
 
@@ -41,14 +42,27 @@ text_retriever = text_faiss_index.as_retriever()
41
  attachments_retriever = attachments_faiss_index.as_retriever()
42
  personalities_retriever = personalities_faiss_index.as_retriever()
43
 
44
- # Combine retrievers
45
- combined_retriever = MultiQueryRetriever(
46
- retrievers=[text_retriever, attachments_retriever, personalities_retriever],
47
- llm=llm
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  )
49
 
50
- # Create QA chain with combined retriever
51
- qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=combined_retriever)
52
 
53
  def load_text(file_path: str) -> str:
54
  with open(file_path, 'r', encoding='utf-8') as file:
 
4
  from langchain_community.vectorstores import FAISS
5
  from llm_loader import load_model
6
  from config import openai_api_key
7
+ from langchain.chains import RetrievalQA, LLMChain
8
  from langchain.retrievers import MultiQueryRetriever
9
+ from langchain.prompts import PromptTemplate
10
  import os
11
  import json
12
 
 
42
  attachments_retriever = attachments_faiss_index.as_retriever()
43
  personalities_retriever = personalities_faiss_index.as_retriever()
44
 
45
+ # Create a combined retriever
46
+ combined_retriever = text_retriever.add_retrievers([attachments_retriever, personalities_retriever])
47
+
48
+ # Create prompt template for query generation
49
+ prompt_template = PromptTemplate(
50
+ input_variables=["question"],
51
+ template="Generate multiple search queries for the following question: {question}"
52
+ )
53
+
54
+ # Create LLM chain for query generation
55
+ llm_chain = LLMChain(llm=llm, prompt=prompt_template)
56
+
57
+ # Initialize MultiQueryRetriever
58
+ multi_query_retriever = MultiQueryRetriever(
59
+ retriever=combined_retriever,
60
+ llm_chain=llm_chain,
61
+ parser_key="lines" # Assuming the LLM outputs queries line by line
62
  )
63
 
64
+ # Create QA chain with multi-query retriever
65
+ qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=multi_query_retriever)
66
 
67
  def load_text(file_path: str) -> str:
68
  with open(file_path, 'r', encoding='utf-8') as file: