from langchain.schema import HumanMessage
from output_parser import attachment_parser, bigfive_parser, personality_parser
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from llm_loader import load_model
from config import openai_api_key
from langchain.chains import RetrievalQA
import json

# Initialize the embedding model used to build the FAISS index
embedding_model = OpenAIEmbeddings(model="text-embedding-3-large", openai_api_key=openai_api_key)

# Paths to the knowledge files
knowledge_files = {
    "attachments": "knowledge/bartholomew_attachments_definitions.txt",
    "bigfive": "knowledge/bigfive_definitions.txt",
    "personalities": "knowledge/personalities_definitions.txt"
}

# Load the content of each knowledge file into a list of documents
documents = []
for key, file_path in knowledge_files.items():
    with open(file_path, 'r', encoding='utf-8') as file:
        documents.append(file.read().strip())

# Create a FAISS index from the knowledge documents
faiss_index = FAISS.from_texts(documents, embedding_model)

# Save the FAISS index locally (optional, in case you want to persist it)
faiss_index.save_local("faiss_index")

# To load the persisted index later, use this:
# faiss_index = FAISS.load_local("faiss_index", embedding_model, allow_dangerous_deserialization=True)

# Load the LLM using the custom loader in llm_loader.py
llm = load_model(openai_api_key)

# Initialize the retrieval chain; return_source_documents=True exposes the
# retrieved documents in the chain's output
retriever = faiss_index.as_retriever()
qa_chain = RetrievalQA.from_llm(llm=llm, retriever=retriever, return_source_documents=True)


def load_text(file_path: str) -> str:
    with open(file_path, 'r', encoding='utf-8') as file:
        return file.read().strip()


def truncate_text(text: str, max_tokens: int = 10000) -> str:
    # Truncate on whitespace-separated words as a rough proxy for tokens
    words = text.split()
    if len(words) > max_tokens:
        truncated_text = ' '.join(words[:max_tokens])
        print(f"Text truncated from {len(words)} to {max_tokens} words")
        return truncated_text
    print(f"Text not truncated, contains {len(words)} words")
    return text


def process_task(llm, input_text: str, general_task: str, specific_task: str, output_parser, qa_chain):
    truncated_input = truncate_text(input_text)

    # Run the retrieval chain; only the retrieved source documents are used here
    retrieval_result = qa_chain.invoke({"query": truncated_input})
    retrieved_knowledge = "\n".join(doc.page_content for doc in retrieval_result["source_documents"])

    # Combine the retrieved knowledge with the original prompt
    prompt = f"""{general_task}

{specific_task}

Retrieved Knowledge:
{retrieved_knowledge}

Input:
{truncated_input}

{output_parser.get_format_instructions()}

Analysis:"""

    messages = [HumanMessage(content=prompt)]
    response = llm.invoke(messages)
    print(response)

    try:
        # Parse the response as a JSON list, then validate each item with the task's parser
        parsed_json = json.loads(response.content)
        parsed_output = [output_parser.parse_object(item) for item in parsed_json]
        return parsed_output
    except Exception as e:
        print(f"Error parsing output: {e}")
        return None


def process_input(input_text: str, llm):
    general_task = load_text("tasks/general_task.txt")

    tasks = [
        ("attachments", "tasks/Attachments_task.txt", attachment_parser),
        ("bigfive", "tasks/BigFive_task.txt", bigfive_parser),
        ("personalities", "tasks/Personalities_task.txt", personality_parser)
    ]

    results = {}
    for task_name, task_file, parser in tasks:
        specific_task = load_text(task_file)
        task_result = process_task(llm, input_text, general_task, specific_task, parser, qa_chain)

        if task_result:
            # One result per speaker; group each task's output under its speaker
            for i, speaker_result in enumerate(task_result):
                speaker_id = f"Speaker {i+1}"
                if speaker_id not in results:
                    results[speaker_id] = {}
                results[speaker_id][task_name] = speaker_result

    return results
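

# --- Illustrative usage sketch (assumptions, not part of the pipeline above) ---
# The transcript path below is hypothetical; any plain-text conversation file works.
# Reuses the module-level llm and qa_chain built above.
if __name__ == "__main__":
    transcript = load_text("transcripts/example_session.txt")  # hypothetical input file
    analysis = process_input(transcript, llm)
    # Results are keyed by speaker, then by task name ("attachments", "bigfive", "personalities")
    print(json.dumps(analysis, indent=2, default=str))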