Spaces:
Runtime error
Runtime error
File size: 3,984 Bytes
e3551a8 94700b8 2355280 6724fb5 5cb6a2f 6724fb5 b88e12c 94700b8 b88e12c 86c1fa0 e3551a8 b88e12c 5cb6a2f e3551a8 1a67fa6 3b9f50e e3551a8 b88e12c e3551a8 b88e12c 6724fb5 b88e12c e3551a8 b88e12c e3551a8 3b9f50e e3551a8 1db0375 1a67fa6 6724fb5 e3551a8 1a67fa6 b88e12c 1a67fa6 e3551a8 b88e12c e3551a8 5bcac9d 6724fb5 5bcac9d e3551a8 b88e12c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
from langchain.schema import HumanMessage
from output_parser import attachment_parser, bigfive_parser, personality_parser
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from llm_loader import load_model
from config import openai_api_key
from langchain.chains import RetrievalQA
import json
import os
# Initialize the embedding model used to vectorize the knowledge base.
embedding_model = OpenAIEmbeddings(model="text-embedding-3-large", openai_api_key=openai_api_key)

# Locations of the domain-knowledge definition files.
knowledge_files = {
    "attachments": "knowledge/bartholomew_attachments_definitions.txt",
    "bigfive": "knowledge/bigfive_definitions.txt",
    "personalities": "knowledge/personalities_definitions.txt"
}

# Read each knowledge file into a flat list of document strings.
documents = []
for file_path in knowledge_files.values():
    with open(file_path, 'r', encoding='utf-8') as handle:
        documents.append(handle.read().strip())

# Build a FAISS vector index over the knowledge documents.
faiss_index = FAISS.from_texts(documents, embedding_model)

# Persist the index locally (optional, in case you want to reuse it later).
faiss_index.save_local("faiss_index")
# To reload a previously saved index instead of rebuilding:
# faiss_index = FAISS.load_local("faiss_index", embedding_model)

# Load the LLM through the project's custom loader.
llm = load_model(openai_api_key)

# Wire up the retriever and the retrieval-QA chain used by process_task.
retriever = faiss_index.as_retriever()
qa_chain = RetrievalQA.from_llm(llm=llm, retriever=retriever)
def load_text(file_path: str) -> str:
    """Read a UTF-8 text file and return its contents stripped of surrounding whitespace."""
    with open(file_path, encoding='utf-8') as handle:
        raw = handle.read()
    return raw.strip()
def truncate_text(text: str, max_tokens: int = 10000) -> str:
    """Cap *text* at ``max_tokens`` whitespace-separated words.

    Returns the text unchanged when it is short enough; otherwise returns
    the first ``max_tokens`` words re-joined with single spaces. Logs which
    path was taken. NOTE(review): the limit counts words, not model tokens.
    """
    words = text.split()
    if len(words) <= max_tokens:
        # Short enough: pass through untouched.
        print(f"Text not truncated, contains {len(words)} words")
        return text
    print(f"Text truncated from {len(words)} to {max_tokens} words")
    return ' '.join(words[:max_tokens])
def process_task(llm, input_text: str, general_task: str, specific_task: str, output_parser, qa_chain):
    """Run one analysis task: retrieve relevant knowledge, prompt the LLM, parse the result.

    Args:
        llm: Chat model callable with a list of messages.
        input_text: Raw conversation/text to analyze (truncated to a word cap).
        general_task: Shared task preamble text.
        specific_task: Task-specific instruction text.
        output_parser: Parser exposing get_format_instructions() and parse_object().
        qa_chain: RetrievalQA chain whose retriever backs the knowledge lookup.

    Returns:
        A list of parsed per-speaker results, or None if parsing fails.
    """
    truncated_input = truncate_text(input_text)
    # BUG FIX: RetrievalQA's __call__ returns {"query", "result"} — there is no
    # 'documents' key, so the old relevant_docs['documents'] raised KeyError.
    # Query the chain's retriever directly to get the Document objects.
    relevant_docs = qa_chain.retriever.get_relevant_documents(truncated_input)
    retrieved_knowledge = "\n".join(doc.page_content for doc in relevant_docs)
    # Combine the retrieved knowledge with the original prompt.
    prompt = f"""{general_task}
{specific_task}
Retrieved Knowledge: {retrieved_knowledge}
Input: {truncated_input}
{output_parser.get_format_instructions()}
Analysis:"""
    messages = [HumanMessage(content=prompt)]
    response = llm(messages)
    print(response)
    try:
        # The model is instructed to emit a JSON list; parse it first.
        parsed_json = json.loads(response.content)
        # Validate and convert each item (one per speaker) via the task parser.
        return [output_parser.parse_object(item) for item in parsed_json]
    except Exception as e:
        # Best-effort: surface the parse failure but do not crash the pipeline.
        print(f"Error parsing output: {e}")
        return None
def process_input(input_text: str, llm):
    """Run all three personality-analysis tasks over *input_text*.

    Executes the attachments, Big Five, and personalities tasks in turn and
    groups their per-speaker outputs into a dict keyed "Speaker 1", "Speaker 2", ...
    mapping to {task_name: parsed_result}.
    """
    general_task = load_text("tasks/general_task.txt")
    task_specs = [
        ("attachments", "tasks/Attachments_task.txt", attachment_parser),
        ("bigfive", "tasks/BigFive_task.txt", bigfive_parser),
        ("personalities", "tasks/Personalities_task.txt", personality_parser),
    ]
    results = {}
    for task_name, task_file, parser in task_specs:
        outcome = process_task(llm, input_text, general_task, load_text(task_file), parser, qa_chain)
        if not outcome:
            # Task failed to parse — skip it rather than aborting the rest.
            continue
        # Each list element corresponds to one speaker, in order.
        for position, speaker_result in enumerate(outcome, start=1):
            speaker_key = f"Speaker {position}"
            results.setdefault(speaker_key, {})[task_name] = speaker_result
    return results
|