reab5555 commited on
Commit
5cb6a2f
·
verified ·
1 Parent(s): 82400d9

Update processing.py

Browse files
Files changed (1) hide show
  1. processing.py +5 -4
processing.py CHANGED
@@ -4,7 +4,7 @@ from langchain_community.embeddings import OpenAIEmbeddings
4
  from langchain_community.vectorstores import FAISS
5
  from llm_loader import load_model # Import the function to load the model
6
  from config import openai_api_key # Import the API key from config.py
7
- from langchain_community.chains import RetrievalQAChain # Assuming this is the updated class
8
  import os
9
 
10
  # Initialize embeddings and FAISS index
@@ -37,7 +37,8 @@ faiss_index.save_local("faiss_index")
37
  llm = load_model(openai_api_key) # Load the model using your custom loader
38
 
39
  # Initialize the retrieval chain
40
- qa_chain = RetrievalQAChain(llm=llm, retriever=faiss_index.as_retriever()) # Replace with the correct chain class
 
41
 
42
  def load_text(file_path: str) -> str:
43
  with open(file_path, 'r', encoding='utf-8') as file:
@@ -56,8 +57,8 @@ def process_task(llm, input_text: str, general_task: str, specific_task: str, ou
56
  truncated_input = truncate_text(input_text)
57
 
58
  # Perform retrieval to get the most relevant context
59
- relevant_docs = qa_chain({"query": truncated_input})
60
- retrieved_knowledge = "\n".join([doc.page_content for doc in relevant_docs['documents']])
61
 
62
  # Combine the retrieved knowledge with the original prompt
63
  prompt = f"""{general_task}
 
4
  from langchain_community.vectorstores import FAISS
5
  from llm_loader import load_model # Import the function to load the model
6
  from config import openai_api_key # Import the API key from config.py
7
+ from langchain.chains import RetrievalQA
8
  import os
9
 
10
  # Initialize embeddings and FAISS index
 
37
  llm = load_model(openai_api_key) # Load the model using your custom loader
38
 
39
  # Initialize the retrieval chain
40
+ retriever = faiss_index.as_retriever()
41
+ qa_chain = RetrievalQA.from_llm(llm=llm, retriever=retriever)
42
 
43
  def load_text(file_path: str) -> str:
44
  with open(file_path, 'r', encoding='utf-8') as file:
 
57
  truncated_input = truncate_text(input_text)
58
 
59
  # Perform retrieval to get the most relevant context
60
+ query = f"{general_task}\n\n{specific_task}\n\n{truncated_input}"
61
+ retrieved_knowledge = qa_chain.run(query)
62
 
63
  # Combine the retrieved knowledge with the original prompt
64
  prompt = f"""{general_task}