Spaces:
Runtime error
Runtime error
Update processing.py
Browse files- processing.py +11 -12
processing.py
CHANGED
@@ -2,9 +2,10 @@ from langchain.schema import HumanMessage
|
|
2 |
from output_parser import attachment_parser, bigfive_parser, personality_parser
|
3 |
from langchain_openai import OpenAIEmbeddings
|
4 |
from langchain_community.vectorstores import FAISS
|
5 |
-
from llm_loader import load_model
|
6 |
-
from config import openai_api_key
|
7 |
from langchain.chains import RetrievalQA
|
|
|
8 |
import os
|
9 |
|
10 |
# Initialize embeddings and FAISS index
|
@@ -57,8 +58,8 @@ def process_task(llm, input_text: str, general_task: str, specific_task: str, ou
|
|
57 |
truncated_input = truncate_text(input_text)
|
58 |
|
59 |
# Perform retrieval to get the most relevant context
|
60 |
-
|
61 |
-
retrieved_knowledge =
|
62 |
|
63 |
# Combine the retrieved knowledge with the original prompt
|
64 |
prompt = f"""{general_task}
|
@@ -78,7 +79,11 @@ Analysis:"""
|
|
78 |
print(response)
|
79 |
|
80 |
try:
|
81 |
-
|
|
|
|
|
|
|
|
|
82 |
return parsed_output
|
83 |
except Exception as e:
|
84 |
print(f"Error parsing output: {e}")
|
@@ -99,17 +104,11 @@ def process_input(input_text: str, llm):
|
|
99 |
specific_task = load_text(task_file)
|
100 |
task_result = process_task(llm, input_text, general_task, specific_task, parser, qa_chain)
|
101 |
|
102 |
-
|
103 |
-
if isinstance(task_result, list):
|
104 |
for i, speaker_result in enumerate(task_result):
|
105 |
speaker_id = f"Speaker {i+1}"
|
106 |
if speaker_id not in results:
|
107 |
results[speaker_id] = {}
|
108 |
results[speaker_id][task_name] = speaker_result
|
109 |
-
else:
|
110 |
-
# If it's not a list, assume it's for a single speaker
|
111 |
-
if "Speaker 1" not in results:
|
112 |
-
results["Speaker 1"] = {}
|
113 |
-
results["Speaker 1"][task_name] = task_result
|
114 |
|
115 |
return results
|
|
|
2 |
from output_parser import attachment_parser, bigfive_parser, personality_parser
|
3 |
from langchain_openai import OpenAIEmbeddings
|
4 |
from langchain_community.vectorstores import FAISS
|
5 |
+
from llm_loader import load_model
|
6 |
+
from config import openai_api_key
|
7 |
from langchain.chains import RetrievalQA
|
8 |
+
import json
|
9 |
import os
|
10 |
|
11 |
# Initialize embeddings and FAISS index
|
|
|
58 |
truncated_input = truncate_text(input_text)
|
59 |
|
60 |
# Perform retrieval to get the most relevant context
|
61 |
+
relevant_docs = qa_chain({"query": truncated_input})
|
62 |
+
retrieved_knowledge = "\n".join([doc.page_content for doc in relevant_docs['documents']])
|
63 |
|
64 |
# Combine the retrieved knowledge with the original prompt
|
65 |
prompt = f"""{general_task}
|
|
|
79 |
print(response)
|
80 |
|
81 |
try:
|
82 |
+
# Parse the response as JSON
|
83 |
+
parsed_json = json.loads(response.content)
|
84 |
+
|
85 |
+
# Validate and convert each item in the list
|
86 |
+
parsed_output = [output_parser.parse_object(item) for item in parsed_json]
|
87 |
return parsed_output
|
88 |
except Exception as e:
|
89 |
print(f"Error parsing output: {e}")
|
|
|
104 |
specific_task = load_text(task_file)
|
105 |
task_result = process_task(llm, input_text, general_task, specific_task, parser, qa_chain)
|
106 |
|
107 |
+
if task_result:
|
|
|
108 |
for i, speaker_result in enumerate(task_result):
|
109 |
speaker_id = f"Speaker {i+1}"
|
110 |
if speaker_id not in results:
|
111 |
results[speaker_id] = {}
|
112 |
results[speaker_id][task_name] = speaker_result
|
|
|
|
|
|
|
|
|
|
|
113 |
|
114 |
return results
|