reab5555 committed on
Commit
8948925
·
verified ·
1 Parent(s): a9fd016

Update processing.py

Browse files
Files changed (1) hide show
  1. processing.py +11 -47
processing.py CHANGED
@@ -8,10 +8,8 @@ from langchain.chains import RetrievalQA
8
  import os
9
  import json
10
 
11
- # Initialize embeddings and FAISS index
12
  embedding_model = OpenAIEmbeddings(openai_api_key=openai_api_key)
13
 
14
- # Load the content of knowledge files
15
  knowledge_files = {
16
  "attachments": "knowledge/bartholomew_attachments_definitions.txt",
17
  "bigfive": "knowledge/bigfive_definitions.txt",
@@ -24,13 +22,10 @@ for key, file_path in knowledge_files.items():
24
  content = file.read().strip()
25
  documents.append(content)
26
 
27
- # Create a FAISS index from the knowledge documents
28
  faiss_index = FAISS.from_texts(documents, embedding_model)
29
 
30
- # Load the LLM
31
  llm = load_model(openai_api_key)
32
 
33
- # Initialize the retrieval chain
34
  qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=faiss_index.as_retriever())
35
 
36
  def load_text(file_path: str) -> str:
@@ -40,10 +35,7 @@ def load_text(file_path: str) -> str:
40
  def truncate_text(text: str, max_tokens: int = 10000) -> str:
41
  words = text.split()
42
  if len(words) > max_tokens:
43
- truncated_text = ' '.join(words[:max_tokens])
44
- print(f"Text truncated from {len(words)} to {max_tokens} words")
45
- return truncated_text
46
- print(f"Text not truncated, contains {len(words)} words")
47
  return text
48
 
49
  def process_input(input_text: str, llm):
@@ -54,16 +46,13 @@ def process_input(input_text: str, llm):
54
 
55
  truncated_input = truncate_text(input_text)
56
 
57
- # Perform retrieval to get the most relevant context
58
  relevant_docs = qa_chain.invoke({"query": truncated_input})
59
 
60
- # Extract the retrieved knowledge
61
  if isinstance(relevant_docs, dict) and 'result' in relevant_docs:
62
  retrieved_knowledge = relevant_docs['result']
63
  else:
64
  retrieved_knowledge = str(relevant_docs)
65
 
66
- # Combine all tasks and knowledge into a single prompt
67
  prompt = f"""{general_task}
68
 
69
  Attachment Styles Task:
@@ -86,45 +75,28 @@ Please provide a comprehensive analysis for each speaker, including:
86
 
87
  Respond with a JSON object containing an array of speaker analyses under the key 'speaker_analyses'. Each speaker analysis should include all three aspects mentioned above.
88
 
89
- The speaker format should be "Speaker [number]", for example:
90
- Speaker [1]
91
- Secured: [probability]
92
- Anxious-Preoccupied: [probability]
93
- Dismissive-Avoidant: [probability]
94
- Fearful-Avoidant: [probability]
95
- Self: [rating]
96
- Others: [rating]
97
- Anxiety: [rating]
98
- Avoidance: [rating]
99
- Explanation: [very brief explanation]
100
-
101
  Analysis:"""
102
 
103
  messages = [HumanMessage(content=prompt)]
104
  response = llm.invoke(messages)
105
 
 
 
 
106
  try:
107
- # Remove code block markers if present
108
  content = response.content
109
  if content.startswith("```json"):
110
  content = content.split("```json", 1)[1]
111
  if content.endswith("```"):
112
  content = content.rsplit("```", 1)[0]
113
 
114
- # Parse the JSON
115
  parsed_json = json.loads(content.strip())
116
 
117
- # Process the parsed JSON
118
  results = {}
119
  speaker_analyses = parsed_json.get('speaker_analyses', [])
120
  for speaker_analysis in speaker_analyses:
121
  speaker_id = speaker_analysis.get('speaker', 'Unknown Speaker')
122
-
123
- # Extract speaker number
124
- speaker_number = speaker_id.split('[')[-1].split(']')[0]
125
- speaker_key = f"Speaker {speaker_number}"
126
-
127
- results[speaker_key] = {
128
  'attachments': attachment_parser.parse_object(speaker_analysis.get('attachment_styles', {})),
129
  'bigfive': bigfive_parser.parse_object(speaker_analysis.get('big_five_traits', {})),
130
  'personalities': personality_parser.parse_object(speaker_analysis.get('personality_disorders', {}))
@@ -139,18 +111,10 @@ Analysis:"""
139
  }}
140
 
141
  return results
142
- except json.JSONDecodeError as e:
143
- print(f"Error parsing JSON: {e}")
144
- print("Raw content causing the error:")
145
- print(response.content)
146
  except Exception as e:
147
- print(f"Unexpected error: {e}")
148
- print("Raw content:")
149
- print(response.content)
150
-
151
- # If any error occurs, return a default result
152
- return {"Unknown Speaker": {
153
- 'attachments': attachment_parser.parse_object({}),
154
- 'bigfive': bigfive_parser.parse_object({}),
155
- 'personalities': personality_parser.parse_object({})
156
- }}
 
8
  import os
9
  import json
10
 
 
11
  embedding_model = OpenAIEmbeddings(openai_api_key=openai_api_key)
12
 
 
13
  knowledge_files = {
14
  "attachments": "knowledge/bartholomew_attachments_definitions.txt",
15
  "bigfive": "knowledge/bigfive_definitions.txt",
 
22
  content = file.read().strip()
23
  documents.append(content)
24
 
 
25
  faiss_index = FAISS.from_texts(documents, embedding_model)
26
 
 
27
  llm = load_model(openai_api_key)
28
 
 
29
  qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=faiss_index.as_retriever())
30
 
31
  def load_text(file_path: str) -> str:
 
35
  def truncate_text(text: str, max_tokens: int = 10000) -> str:
36
  words = text.split()
37
  if len(words) > max_tokens:
38
+ return ' '.join(words[:max_tokens])
 
 
 
39
  return text
40
 
41
  def process_input(input_text: str, llm):
 
46
 
47
  truncated_input = truncate_text(input_text)
48
 
 
49
  relevant_docs = qa_chain.invoke({"query": truncated_input})
50
 
 
51
  if isinstance(relevant_docs, dict) and 'result' in relevant_docs:
52
  retrieved_knowledge = relevant_docs['result']
53
  else:
54
  retrieved_knowledge = str(relevant_docs)
55
 
 
56
  prompt = f"""{general_task}
57
 
58
  Attachment Styles Task:
 
75
 
76
  Respond with a JSON object containing an array of speaker analyses under the key 'speaker_analyses'. Each speaker analysis should include all three aspects mentioned above.
77
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  Analysis:"""
79
 
80
  messages = [HumanMessage(content=prompt)]
81
  response = llm.invoke(messages)
82
 
83
+ print("Raw LLM Model Output:")
84
+ print(response.content)
85
+
86
  try:
 
87
  content = response.content
88
  if content.startswith("```json"):
89
  content = content.split("```json", 1)[1]
90
  if content.endswith("```"):
91
  content = content.rsplit("```", 1)[0]
92
 
 
93
  parsed_json = json.loads(content.strip())
94
 
 
95
  results = {}
96
  speaker_analyses = parsed_json.get('speaker_analyses', [])
97
  for speaker_analysis in speaker_analyses:
98
  speaker_id = speaker_analysis.get('speaker', 'Unknown Speaker')
99
+ results[speaker_id] = {
 
 
 
 
 
100
  'attachments': attachment_parser.parse_object(speaker_analysis.get('attachment_styles', {})),
101
  'bigfive': bigfive_parser.parse_object(speaker_analysis.get('big_five_traits', {})),
102
  'personalities': personality_parser.parse_object(speaker_analysis.get('personality_disorders', {}))
 
111
  }}
112
 
113
  return results
 
 
 
 
114
  except Exception as e:
115
+ print(f"Error processing input: {e}")
116
+ return {"Unknown Speaker": {
117
+ 'attachments': attachment_parser.parse_object({}),
118
+ 'bigfive': bigfive_parser.parse_object({}),
119
+ 'personalities': personality_parser.parse_object({})
120
+ }}