reab5555 committed on
Commit
6724fb5
·
verified ·
1 Parent(s): 09abe1d

Update processing.py

Browse files
Files changed (1) hide show
  1. processing.py +11 -12
processing.py CHANGED
@@ -2,9 +2,10 @@ from langchain.schema import HumanMessage
2
  from output_parser import attachment_parser, bigfive_parser, personality_parser
3
  from langchain_openai import OpenAIEmbeddings
4
  from langchain_community.vectorstores import FAISS
5
- from llm_loader import load_model # Import the function to load the model
6
- from config import openai_api_key # Import the API key from config.py
7
  from langchain.chains import RetrievalQA
 
8
  import os
9
 
10
  # Initialize embeddings and FAISS index
@@ -57,8 +58,8 @@ def process_task(llm, input_text: str, general_task: str, specific_task: str, ou
57
  truncated_input = truncate_text(input_text)
58
 
59
  # Perform retrieval to get the most relevant context
60
- query = f"{general_task}\n\n{specific_task}\n\n{truncated_input}"
61
- retrieved_knowledge = qa_chain.run(query)
62
 
63
  # Combine the retrieved knowledge with the original prompt
64
  prompt = f"""{general_task}
@@ -78,7 +79,11 @@ Analysis:"""
78
  print(response)
79
 
80
  try:
81
- parsed_output = output_parser.parse(response.content)
 
 
 
 
82
  return parsed_output
83
  except Exception as e:
84
  print(f"Error parsing output: {e}")
@@ -99,17 +104,11 @@ def process_input(input_text: str, llm):
99
  specific_task = load_text(task_file)
100
  task_result = process_task(llm, input_text, general_task, specific_task, parser, qa_chain)
101
 
102
- # If the result is a list, assume it's for multiple speakers
103
- if isinstance(task_result, list):
104
  for i, speaker_result in enumerate(task_result):
105
  speaker_id = f"Speaker {i+1}"
106
  if speaker_id not in results:
107
  results[speaker_id] = {}
108
  results[speaker_id][task_name] = speaker_result
109
- else:
110
- # If it's not a list, assume it's for a single speaker
111
- if "Speaker 1" not in results:
112
- results["Speaker 1"] = {}
113
- results["Speaker 1"][task_name] = task_result
114
 
115
  return results
 
2
  from output_parser import attachment_parser, bigfive_parser, personality_parser
3
  from langchain_openai import OpenAIEmbeddings
4
  from langchain_community.vectorstores import FAISS
5
+ from llm_loader import load_model
6
+ from config import openai_api_key
7
  from langchain.chains import RetrievalQA
8
+ import json
9
  import os
10
 
11
  # Initialize embeddings and FAISS index
 
58
  truncated_input = truncate_text(input_text)
59
 
60
  # Perform retrieval to get the most relevant context
61
+ relevant_docs = qa_chain({"query": truncated_input})
62
+ retrieved_knowledge = "\n".join([doc.page_content for doc in relevant_docs['documents']])
63
 
64
  # Combine the retrieved knowledge with the original prompt
65
  prompt = f"""{general_task}
 
79
  print(response)
80
 
81
  try:
82
+ # Parse the response as JSON
83
+ parsed_json = json.loads(response.content)
84
+
85
+ # Validate and convert each item in the list
86
+ parsed_output = [output_parser.parse_object(item) for item in parsed_json]
87
  return parsed_output
88
  except Exception as e:
89
  print(f"Error parsing output: {e}")
 
104
  specific_task = load_text(task_file)
105
  task_result = process_task(llm, input_text, general_task, specific_task, parser, qa_chain)
106
 
107
+ if task_result:
 
108
  for i, speaker_result in enumerate(task_result):
109
  speaker_id = f"Speaker {i+1}"
110
  if speaker_id not in results:
111
  results[speaker_id] = {}
112
  results[speaker_id][task_name] = speaker_result
 
 
 
 
 
113
 
114
  return results