reab5555 committed on
Commit
b88e12c
·
verified ·
1 Parent(s): b01b6c3

Update processing.py

Browse files
Files changed (1) hide show
  1. processing.py +50 -14
processing.py CHANGED
@@ -1,13 +1,48 @@
1
- # processing.py
2
  from langchain.schema import HumanMessage
3
  from output_parser import attachment_parser, bigfive_parser, personality_parser
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
 
 
5
 
6
def load_text(file_path: str) -> str:
    """Read an entire UTF-8 text file and return it stripped of surrounding whitespace."""
    with open(file_path, 'r', encoding='utf-8') as source:
        text = source.read()
    return text.strip()
9
 
10
-
11
  def truncate_text(text: str, max_tokens: int = 10000) -> str:
12
  words = text.split()
13
  if len(words) > max_tokens:
@@ -17,15 +52,19 @@ def truncate_text(text: str, max_tokens: int = 10000) -> str:
17
  print(f"Text not truncated, contains {len(words)} words")
18
  return text
19
 
20
-
21
- def process_task(llm, input_text: str, general_task: str, specific_task: str, knowledge: str, output_parser):
22
  truncated_input = truncate_text(input_text)
23
 
 
 
 
 
 
24
  prompt = f"""{general_task}
25
 
26
  {specific_task}
27
 
28
- Knowledge: {knowledge}
29
 
30
  Input: {truncated_input}
31
 
@@ -44,22 +83,19 @@ Analysis:"""
44
  print(f"Error parsing output: {e}")
45
  return None
46
 
47
-
48
def process_input(input_text: str, llm):
    """Run every analysis task over *input_text* and collect the parsed results.

    Returns a dict mapping task name -> parser output (or None on failure,
    as produced by process_task).
    """
    general_task = load_text("tasks/general_task.txt")

    # (name, task prompt file, knowledge definitions file, output parser)
    task_specs = (
        ("attachments", "tasks/Attachments_task.txt",
         "knowledge/bartholomew_attachments_definitions.txt", attachment_parser),
        ("bigfive", "tasks/BigFive_task.txt",
         "knowledge/bigfive_definitions.txt", bigfive_parser),
        ("personalities", "tasks/Personalities_task.txt",
         "knowledge/personalities_definitions.txt", personality_parser),
    )

    results = {}
    for name, task_file, knowledge_file, parser in task_specs:
        specific_task = load_text(task_file)
        knowledge = load_text(knowledge_file)
        results[name] = process_task(
            llm, input_text, general_task, specific_task, knowledge, parser
        )
    return results
 
 
1
from langchain.schema import HumanMessage
from output_parser import attachment_parser, bigfive_parser, personality_parser
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
# BUG FIX: RetrievalQA is a chain, not a retriever — it is exported from
# langchain.chains; importing it from langchain.retrievers raises ImportError.
from langchain.chains import RetrievalQA
from llm_loader import load_model  # project helper that constructs the LLM client
from config import openai_api_key  # API key kept out of source control
import os

# Embedding model used to vectorize the knowledge base.
embedding_model = OpenAIEmbeddings()

# Knowledge definition files, one per analysis dimension.
knowledge_files = {
    "attachments": "knowledge/bartholomew_attachments_definitions.txt",
    "bigfive": "knowledge/bigfive_definitions.txt",
    "personalities": "knowledge/personalities_definitions.txt"
}

# Load the content of each knowledge file into a flat list of documents.
documents = []
for file_path in knowledge_files.values():
    with open(file_path, 'r', encoding='utf-8') as file:
        documents.append(file.read().strip())

# Build a FAISS vector index over the knowledge documents.
faiss_index = FAISS.from_texts(documents, embedding_model)

# Persist the index so it can be reloaded later without re-embedding (optional).
faiss_index.save_local("faiss_index")
# To reload instead of rebuilding:
# faiss_index = FAISS.load_local("faiss_index", embedding_model)

# Load the LLM via the project loader.
llm = load_model(openai_api_key)

# BUG FIX: RetrievalQA cannot be instantiated directly from an LLM — its
# constructor expects a combine_documents_chain. The from_chain_type factory
# wires the LLM into a default "stuff" documents chain for us.
qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=faiss_index.as_retriever())
41
 
42
def load_text(file_path: str) -> str:
    """Return the whitespace-stripped contents of a UTF-8 text file."""
    with open(file_path, 'r', encoding='utf-8') as handle:
        raw = handle.read()
    return raw.strip()
45
 
 
46
  def truncate_text(text: str, max_tokens: int = 10000) -> str:
47
  words = text.split()
48
  if len(words) > max_tokens:
 
52
  print(f"Text not truncated, contains {len(words)} words")
53
  return text
54
 
55
+ def process_task(llm, input_text: str, general_task: str, specific_task: str, output_parser, qa_chain):
 
56
  truncated_input = truncate_text(input_text)
57
 
58
+ # Perform retrieval to get the most relevant context
59
+ relevant_docs = qa_chain({"query": truncated_input})
60
+ retrieved_knowledge = "\n".join([doc.page_content for doc in relevant_docs['documents']])
61
+
62
+ # Combine the retrieved knowledge with the original prompt
63
  prompt = f"""{general_task}
64
 
65
  {specific_task}
66
 
67
+ Retrieved Knowledge: {retrieved_knowledge}
68
 
69
  Input: {truncated_input}
70
 
 
83
  print(f"Error parsing output: {e}")
84
  return None
85
 
 
86
def process_input(input_text: str, llm):
    """Run each retrieval-augmented analysis task over *input_text*.

    Returns a dict mapping task name -> parser output (or None on failure,
    as produced by process_task). Uses the module-level qa_chain for
    knowledge retrieval.
    """
    general_task = load_text("tasks/general_task.txt")

    # (name, task prompt file, output parser) — knowledge now comes from retrieval.
    task_specs = (
        ("attachments", "tasks/Attachments_task.txt", attachment_parser),
        ("bigfive", "tasks/BigFive_task.txt", bigfive_parser),
        ("personalities", "tasks/Personalities_task.txt", personality_parser),
    )

    results = {}
    for name, task_path, parser in task_specs:
        specific_task = load_text(task_path)
        results[name] = process_task(
            llm, input_text, general_task, specific_task, parser, qa_chain
        )
    return results