from langchain.schema import HumanMessage
from output_parser import attachment_parser, bigfive_parser, personality_parser
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from llm_loader import load_model
from config import openai_api_key
from langchain.chains import RetrievalQA
import json

# Initialize the embedding model used to index the knowledge files
embedding_model = OpenAIEmbeddings(model="text-embedding-3-large", openai_api_key=openai_api_key)

# Path to the knowledge files
knowledge_files = {
    "attachments": "knowledge/bartholomew_attachments_definitions.txt",
    "bigfive": "knowledge/bigfive_definitions.txt",
    "personalities": "knowledge/personalities_definitions.txt"
}

# Load the content of knowledge files and create a list of documents
documents = []
for key, file_path in knowledge_files.items():
    with open(file_path, 'r', encoding='utf-8') as file:
        content = file.read().strip()
        documents.append(content)

# Create a FAISS index from the knowledge documents
faiss_index = FAISS.from_texts(documents, embedding_model)

# Save FAISS index locally (optional, in case you want to persist it)
faiss_index.save_local("faiss_index")

# To reload the persisted index later (recent langchain-community releases
# also require allow_dangerous_deserialization=True, since the index is pickled):
# faiss_index = FAISS.load_local("faiss_index", embedding_model)

# Load the LLM using llm_loader.py
llm = load_model(openai_api_key)  # Load the model using your custom loader

# Initialize the retrieval chain
retriever = faiss_index.as_retriever()
qa_chain = RetrievalQA.from_llm(llm=llm, retriever=retriever)

def load_text(file_path: str) -> str:
    with open(file_path, 'r', encoding='utf-8') as file:
        return file.read().strip()

def truncate_text(text: str, max_words: int = 10000) -> str:
    # Rough word-based truncation; word count only approximates token count.
    words = text.split()
    if len(words) > max_words:
        truncated_text = ' '.join(words[:max_words])
        print(f"Text truncated from {len(words)} to {max_words} words")
        return truncated_text
    print(f"Text not truncated, contains {len(words)} words")
    return text

def process_task(llm, input_text: str, general_task: str, specific_task: str, output_parser, qa_chain):
    truncated_input = truncate_text(input_text)

    # Retrieve the most relevant knowledge snippets for this input.
    # RetrievalQA returns only {"query", "result"} when called directly,
    # so query the chain's retriever to get the documents themselves.
    relevant_docs = qa_chain.retriever.get_relevant_documents(truncated_input)
    retrieved_knowledge = "\n".join(doc.page_content for doc in relevant_docs)

    # Combine the retrieved knowledge with the original prompt
    prompt = f"""{general_task}

{specific_task}

Retrieved Knowledge: {retrieved_knowledge}

Input: {truncated_input}

{output_parser.get_format_instructions()}

Analysis:"""

    messages = [HumanMessage(content=prompt)]
    response = llm.invoke(messages)
    print(response.content)

    try:
        # Parse the response as JSON
        parsed_json = json.loads(response.content)
        
        # Validate and convert each item in the list
        parsed_output = [output_parser.parse_object(item) for item in parsed_json]
        return parsed_output
    except Exception as e:
        print(f"Error parsing output: {e}")
        return None

def process_input(input_text: str, llm):
    general_task = load_text("tasks/general_task.txt")

    tasks = [
        ("attachments", "tasks/Attachments_task.txt", attachment_parser),
        ("bigfive", "tasks/BigFive_task.txt", bigfive_parser),
        ("personalities", "tasks/Personalities_task.txt", personality_parser)
    ]

    results = {}

    for task_name, task_file, parser in tasks:
        specific_task = load_text(task_file)
        task_result = process_task(llm, input_text, general_task, specific_task, parser, qa_chain)
        
        if task_result:
            for i, speaker_result in enumerate(task_result):
                speaker_id = f"Speaker {i+1}"
                if speaker_id not in results:
                    results[speaker_id] = {}
                results[speaker_id][task_name] = speaker_result

    return results
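
# Example usage: a minimal, illustrative sketch. "conversation.txt" is a
# hypothetical transcript path; swap in your own input file.
if __name__ == "__main__":
    input_text = load_text("conversation.txt")
    results = process_input(input_text, llm)
    # Parser outputs may be pydantic objects; default=str keeps them printable.
    print(json.dumps(results, indent=2, default=str))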