import json
import os

from langchain.chains import RetrievalQA
from langchain.schema import HumanMessage
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings

from config import openai_api_key
from llm_loader import load_model
from output_parser import attachment_parser, bigfive_parser, personality_parser
# Initialize the embedding model used to build the FAISS index
embedding_model = OpenAIEmbeddings(model="text-embedding-3-large", openai_api_key=openai_api_key)

# Paths to the knowledge files
knowledge_files = {
    "attachments": "knowledge/bartholomew_attachments_definitions.txt",
    "bigfive": "knowledge/bigfive_definitions.txt",
    "personalities": "knowledge/personalities_definitions.txt"
}
# Load the content of the knowledge files and create a list of documents
documents = []
for key, file_path in knowledge_files.items():
    with open(file_path, 'r', encoding='utf-8') as file:
        content = file.read().strip()
        documents.append(content)

# Create a FAISS index from the knowledge documents
faiss_index = FAISS.from_texts(documents, embedding_model)

# Save the FAISS index locally (optional, in case you want to persist it)
faiss_index.save_local("faiss_index")
# To load the FAISS index later, use this (recent langchain_community versions
# also require allow_dangerous_deserialization=True):
# faiss_index = FAISS.load_local("faiss_index", embedding_model)
# Load the LLM via the custom loader in llm_loader.py
llm = load_model(openai_api_key)

# Build the retrieval chain on top of the FAISS index
retriever = faiss_index.as_retriever()
qa_chain = RetrievalQA.from_llm(llm=llm, retriever=retriever)
def load_text(file_path: str) -> str:
    with open(file_path, 'r', encoding='utf-8') as file:
        return file.read().strip()
def truncate_text(text: str, max_words: int = 10000) -> str:
    # Rough length guard: counts whitespace-separated words, not model tokens.
    words = text.split()
    if len(words) > max_words:
        truncated_text = ' '.join(words[:max_words])
        print(f"Text truncated from {len(words)} to {max_words} words")
        return truncated_text
    print(f"Text not truncated, contains {len(words)} words")
    return text
def process_task(llm, input_text: str, general_task: str, specific_task: str, output_parser, qa_chain):
    truncated_input = truncate_text(input_text)

    # Retrieve the most relevant knowledge documents for this input
    relevant_docs = qa_chain.retriever.get_relevant_documents(truncated_input)
    retrieved_knowledge = "\n".join(doc.page_content for doc in relevant_docs)

    # Combine the retrieved knowledge with the task instructions and the input
    prompt = f"""{general_task}
{specific_task}
Retrieved Knowledge: {retrieved_knowledge}
Input: {truncated_input}
{output_parser.get_format_instructions()}
Analysis:"""

    messages = [HumanMessage(content=prompt)]
    response = llm.invoke(messages)
    print(response)

    try:
        # Parse the response as JSON
        parsed_json = json.loads(response.content)
        # Validate and convert each item in the list
        parsed_output = [output_parser.parse_object(item) for item in parsed_json]
        return parsed_output
    except Exception as e:
        print(f"Error parsing output: {e}")
        return None
def process_input(input_text: str, llm):
    general_task = load_text("tasks/general_task.txt")

    tasks = [
        ("attachments", "tasks/Attachments_task.txt", attachment_parser),
        ("bigfive", "tasks/BigFive_task.txt", bigfive_parser),
        ("personalities", "tasks/Personalities_task.txt", personality_parser)
    ]

    results = {}
    for task_name, task_file, parser in tasks:
        specific_task = load_text(task_file)
        task_result = process_task(llm, input_text, general_task, specific_task, parser, qa_chain)
        if task_result:
            for i, speaker_result in enumerate(task_result):
                speaker_id = f"Speaker {i+1}"
                if speaker_id not in results:
                    results[speaker_id] = {}
                results[speaker_id][task_name] = speaker_result
    return results
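
# Minimal usage sketch (not part of the original script): run the full analysis on a
# sample transcript and print the per-speaker results. The input path and the use of
# json.dumps(default=str) to serialize the parsed objects are assumptions for illustration.
if __name__ == "__main__":
    sample_text = load_text("input/sample_transcript.txt")  # hypothetical input file
    analysis = process_input(sample_text, llm)
    print(json.dumps(analysis, indent=2, default=str))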