Spaces:
Runtime error
Runtime error
Update processing.py
Browse files- processing.py +5 -4
processing.py
CHANGED
@@ -4,7 +4,7 @@ from langchain_community.embeddings import OpenAIEmbeddings
|
|
4 |
from langchain_community.vectorstores import FAISS
|
5 |
from llm_loader import load_model # Import the function to load the model
|
6 |
from config import openai_api_key # Import the API key from config.py
|
7 |
-
from
|
8 |
import os
|
9 |
|
10 |
# Initialize embeddings and FAISS index
|
@@ -37,7 +37,8 @@ faiss_index.save_local("faiss_index")
|
|
37 |
llm = load_model(openai_api_key) # Load the model using your custom loader
|
38 |
|
39 |
# Initialize the retrieval chain
|
40 |
-
|
|
|
41 |
|
42 |
def load_text(file_path: str) -> str:
|
43 |
with open(file_path, 'r', encoding='utf-8') as file:
|
@@ -56,8 +57,8 @@ def process_task(llm, input_text: str, general_task: str, specific_task: str, ou
|
|
56 |
truncated_input = truncate_text(input_text)
|
57 |
|
58 |
# Perform retrieval to get the most relevant context
|
59 |
-
|
60 |
-
retrieved_knowledge =
|
61 |
|
62 |
# Combine the retrieved knowledge with the original prompt
|
63 |
prompt = f"""{general_task}
|
|
|
4 |
from langchain_community.vectorstores import FAISS
|
5 |
from llm_loader import load_model # Import the function to load the model
|
6 |
from config import openai_api_key # Import the API key from config.py
|
7 |
+
from langchain.chains import RetrievalQA
|
8 |
import os
|
9 |
|
10 |
# Initialize embeddings and FAISS index
|
|
|
37 |
llm = load_model(openai_api_key) # Load the model using your custom loader
|
38 |
|
39 |
# Initialize the retrieval chain
|
40 |
+
retriever = faiss_index.as_retriever()
|
41 |
+
qa_chain = RetrievalQA.from_llm(llm=llm, retriever=retriever)
|
42 |
|
43 |
def load_text(file_path: str) -> str:
|
44 |
with open(file_path, 'r', encoding='utf-8') as file:
|
|
|
57 |
truncated_input = truncate_text(input_text)
|
58 |
|
59 |
# Perform retrieval to get the most relevant context
|
60 |
+
query = f"{general_task}\n\n{specific_task}\n\n{truncated_input}"
|
61 |
+
retrieved_knowledge = qa_chain.run(query)
|
62 |
|
63 |
# Combine the retrieved knowledge with the original prompt
|
64 |
prompt = f"""{general_task}
|