import json

from langchain.chains import RetrievalQA
from langchain.schema import HumanMessage
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings

from config import openai_api_key
from llm_loader import load_model
from output_parser import attachment_parser, bigfive_parser, personality_parser

# Initialize the embedding model used to build the FAISS index
embedding_model = OpenAIEmbeddings(model="text-embedding-3-large", openai_api_key=openai_api_key)

# Paths to the knowledge files
knowledge_files = {
"attachments": "knowledge/bartholomew_attachments_definitions.txt",
"bigfive": "knowledge/bigfive_definitions.txt",
"personalities": "knowledge/personalities_definitions.txt"
}
# Load the content of each knowledge file into a list of documents
documents = []
for file_path in knowledge_files.values():
    with open(file_path, 'r', encoding='utf-8') as file:
        documents.append(file.read().strip())

# Build a FAISS index over the knowledge documents
# (each file is indexed as a single document; no chunking is applied)
faiss_index = FAISS.from_texts(documents, embedding_model)
# Save FAISS index locally (optional, in case you want to persist it)
faiss_index.save_local("faiss_index")
# To reload the persisted index later (recent LangChain releases may require
# allow_dangerous_deserialization=True for pickle-backed FAISS indexes):
# faiss_index = FAISS.load_local("faiss_index", embedding_model, allow_dangerous_deserialization=True)

# Load the LLM via the custom loader in llm_loader.py
llm = load_model(openai_api_key)

# Initialize the retrieval chain
retriever = faiss_index.as_retriever()
qa_chain = RetrievalQA.from_llm(llm=llm, retriever=retriever)
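
# Illustrative direct use of the QA chain (the query string is hypothetical):
#   qa_chain.invoke({"query": "Define fearful-avoidant attachment"})["result"]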

def load_text(file_path: str) -> str:
    """Read a UTF-8 text file and return its stripped contents."""
    with open(file_path, 'r', encoding='utf-8') as file:
        return file.read().strip()

def truncate_text(text: str, max_words: int = 10000) -> str:
    """Truncate text to at most max_words whitespace-separated words.

    Note: this counts words, not model tokens, so it is only a rough
    safeguard against overlong prompts.
    """
    words = text.split()
    if len(words) > max_words:
        print(f"Text truncated from {len(words)} to {max_words} words")
        return ' '.join(words[:max_words])
    print(f"Text not truncated, contains {len(words)} words")
    return text
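
# Illustrative example of the truncation behavior (hypothetical input):
#   truncate_text("alpha beta gamma", max_words=2)  ->  "alpha beta"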

def process_task(llm, input_text: str, general_task: str, specific_task: str, output_parser, qa_chain):
    truncated_input = truncate_text(input_text)
    # Fetch the most relevant knowledge documents through the chain's
    # retriever. Calling the RetrievalQA chain itself returns a generated
    # answer string, not the underlying documents, so we query the
    # retriever directly.
    relevant_docs = qa_chain.retriever.get_relevant_documents(truncated_input)
    retrieved_knowledge = "\n".join(doc.page_content for doc in relevant_docs)
# Combine the retrieved knowledge with the original prompt
prompt = f"""{general_task}
{specific_task}
Retrieved Knowledge: {retrieved_knowledge}
Input: {truncated_input}
{output_parser.get_format_instructions()}
Analysis:"""
    messages = [HumanMessage(content=prompt)]
    response = llm.invoke(messages)
    print(response)
try:
# Parse the response as JSON
parsed_json = json.loads(response.content)
# Validate and convert each item in the list
parsed_output = [output_parser.parse_object(item) for item in parsed_json]
return parsed_output
except Exception as e:
print(f"Error parsing output: {e}")
return None
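
# Illustrative shape of a raw LLM response that process_task can parse
# (field names here are hypothetical; the real schema comes from the
# parsers defined in output_parser.py):
# [
#   {"speaker": "Speaker 1", "classification": "...", "evidence": "..."},
#   {"speaker": "Speaker 2", "classification": "...", "evidence": "..."}
# ]
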
def process_input(input_text: str, llm):
general_task = load_text("tasks/general_task.txt")
tasks = [
("attachments", "tasks/Attachments_task.txt", attachment_parser),
("bigfive", "tasks/BigFive_task.txt", bigfive_parser),
("personalities", "tasks/Personalities_task.txt", personality_parser)
]
results = {}
for task_name, task_file, parser in tasks:
specific_task = load_text(task_file)
task_result = process_task(llm, input_text, general_task, specific_task, parser, qa_chain)
if task_result:
for i, speaker_result in enumerate(task_result):
speaker_id = f"Speaker {i+1}"
if speaker_id not in results:
results[speaker_id] = {}
results[speaker_id][task_name] = speaker_result
return results
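
# Minimal usage sketch (the transcript path below is hypothetical; swap in a
# real session transcript):
if __name__ == "__main__":
    transcript = load_text("transcripts/example_session.txt")  # hypothetical path
    model = load_model(openai_api_key)
    analysis = process_input(transcript, model)
    print(json.dumps(analysis, indent=2, default=str))  # default=str handles parser objects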