reab5555 committed
Commit e913097 · verified · 1 Parent(s): 9720614

Update processing.py

Files changed (1):
  1. processing.py +27 -20
processing.py CHANGED
@@ -6,8 +6,8 @@ from llm_loader import load_model, count_tokens
 from config import openai_api_key
 from langchain.chains import RetrievalQA
 from langchain.prompts import PromptTemplate
+from langchain_core.runnables import RunnablePassthrough, RunnableLambda
 from langchain.schema.runnable import RunnablePassthrough
-from langchain_core.runnables import RunnableLambda
 from typing import List, Any, Optional
 from pydantic import Field
 from langchain_core.callbacks import CallbackManagerForRetrieverRun
@@ -35,7 +35,7 @@ for key, file_path in knowledge_files.items():
 text_faiss_index = FAISS.from_texts(documents, embedding_model)
 
 # Load pre-existing FAISS indexes
-attachments_faiss_index = FAISS.load_local("knowledge/faiss_index_Attachments_db", embedding_model, allow_dangerous_deserialization=True)
+attachments_faiss_index = FAISS.load_local("knowledge/faiss_index_Attachment_db", embedding_model, allow_dangerous_deserialization=True)
 personalities_faiss_index = FAISS.load_local("knowledge/faiss_index_Personalities_db", embedding_model, allow_dangerous_deserialization=True)
 
 # Initialize LLM
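
Note: the only change in this hunk is the corrected index path (faiss_index_Attachments_db → faiss_index_Attachment_db). For reference, a minimal self-contained sketch of the save/load round trip behind FAISS.load_local; FakeEmbeddings stands in for the app's real embedding_model so the snippet runs offline:

    # Sketch only: FakeEmbeddings is a stub, not the embedding model processing.py uses.
    from langchain_community.embeddings import FakeEmbeddings
    from langchain_community.vectorstores import FAISS

    embedding_model = FakeEmbeddings(size=256)

    # Build and persist a small index once...
    index = FAISS.from_texts(["anxious attachment", "avoidant attachment"], embedding_model)
    index.save_local("knowledge/faiss_index_Attachment_db")

    # ...then reload it. save_local pickles the docstore, so load_local refuses to
    # deserialize unless the caller opts in explicitly.
    reloaded = FAISS.load_local(
        "knowledge/faiss_index_Attachment_db",
        embedding_model,
        allow_dangerous_deserialization=True,
    )
    print(reloaded.similarity_search("attachment style", k=1))

The flag is the opt-in FAISS.load_local requires before unpickling the docstore; it should only be set for index files you created yourself.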
@@ -75,12 +75,15 @@ combined_retriever = CombinedRetriever(retrievers=[text_retriever, attachments_r
 
 # Create prompt template for query generation
 prompt_template = PromptTemplate(
-    input_variables=["question"],
-    template="Generate multiple search queries for the following question: {question}"
+    input_variables=["question", "context"],
+    template="Use the following context to answer the question: {context}\n\nQuestion: {question}\nAnswer:"
 )
 
 # Create query generation chain
-query_generation_chain = prompt_template | llm
+query_generation_chain = PromptTemplate(
+    input_variables=["question"],
+    template="Generate multiple search queries for the following question: {question}"
+) | llm
 
 # Create multi-query retrieval chain
 def generate_queries(input):
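
Note: the query-generation prompt is not dropped here; it moves inline so that prompt_template is free to become the QA prompt used later in process_input. A minimal sketch of the PromptTemplate | llm composition; the stub LLM and the line-splitting in generate_queries are assumptions, since the function body is outside this diff:

    from langchain.prompts import PromptTemplate
    from langchain_core.runnables import RunnableLambda

    # Stub LLM so the sketch runs offline; the committed chain pipes into the real llm.
    fake_llm = RunnableLambda(lambda _: "query one\nquery two\nquery three")

    query_generation_chain = PromptTemplate(
        input_variables=["question"],
        template="Generate multiple search queries for the following question: {question}",
    ) | fake_llm

    def generate_queries(input):
        # One plausible shape: one generated query per output line.
        output = query_generation_chain.invoke({"question": input})
        return [q.strip() for q in str(output).splitlines() if q.strip()]

    print(generate_queries("What characterizes avoidant attachment?"))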
@@ -97,10 +100,10 @@ def multi_query_retrieve(input):
 
 multi_query_retriever = RunnableLambda(multi_query_retrieve)
 
-# Create QA chain with multi-query retriever
 def format_docs(docs):
     return "\n\n".join(doc.page_content for doc in docs["documents"])
 
+# Create QA chain with multi-query retriever
 qa_chain = (
     {
         "context": multi_query_retriever | format_docs,
@@ -129,25 +132,33 @@ def process_input(input_text: str, llm):
129
 
130
  truncated_input = truncate_text(input_text)
131
 
132
- relevant_docs = qa_chain.invoke({"query": truncated_input})
 
133
 
134
  # Print the generated queries
135
  print("Generated Queries:")
136
- for query in relevant_docs["retrieval_results"]["queries"]:
137
  print(f"- {query}")
138
 
139
  # Print the retrieved documents
140
  print("\nRetrieved Documents:")
141
- for i, doc in enumerate(relevant_docs["retrieval_results"]["documents"], 1):
142
  print(f"Document {i}:")
143
- print(f"Content: {doc.page_content}...") # Print first 200 characters
144
  print("-" * 50)
145
-
146
- retrieved_knowledge = str(relevant_docs["llm_output"])
 
 
 
 
 
 
 
147
 
148
  prompt = f"""
149
  {general_task}
150
- Genral Impression Task:
151
  {general_impression_task}
152
  Attachment Styles Task:
153
  {attachments_task}
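
Note: the rewrite above stops routing through qa_chain.invoke and instead calls the retrieval and generation steps explicitly. An offline walk-through of that new path with stub data; multi_query_retrieve's return shape is inferred from the keys the new code reads ("queries" and "documents"):

    from langchain.prompts import PromptTemplate
    from langchain_core.documents import Document

    prompt_template = PromptTemplate(
        input_variables=["question", "context"],
        template="Use the following context to answer the question: {context}\n\nQuestion: {question}\nAnswer:",
    )

    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs["documents"])

    # Stub for what multi_query_retrieve(truncated_input) is expected to return.
    retrieval_result = {
        "queries": ["attachment styles", "secure vs anxious attachment"],
        "documents": [Document(page_content="Secure attachment is marked by ...")],
    }

    formatted_docs = format_docs(retrieval_result)
    llm_input = prompt_template.format(
        question="Describe the speaker's attachment style.",
        context=formatted_docs,
    )
    print(llm_input)  # this string is what llm.invoke(llm_input) receives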
@@ -165,11 +176,6 @@ Please provide a comprehensive analysis for each speaker, including:
 Respond with a JSON object containing an array of speaker analyses under the key 'speaker_analyses'. Each speaker analysis should include all four aspects mentioned above, however, General impressions must not be in json or dict format.
 Analysis:"""
 
-    #truncated_input_tokents_count = count_tokens(truncated_input)
-    #print('truncated_input_tokents_count:', truncated_input_tokents_count)
-    #input_tokens_count = count_tokens(prompt)
-    #print('input_tokens_count', input_tokens_count)
-
     response = llm.invoke(prompt)
 
     print("Raw LLM Model Output:")
@@ -190,7 +196,7 @@ Analysis:"""
         speaker_id = f"Speaker {i}"
         parsed_analysis = output_parser.parse_speaker_analysis(speaker_analysis)
 
-        # Convert general_impression to string if it's a dict or JSON object
+        # Convert general_impression to string if it's a dict or JSON object
         general_impression = parsed_analysis.general_impression
         if isinstance(general_impression, dict):
             general_impression = json.dumps(general_impression)
@@ -229,4 +235,5 @@ Analysis:"""
             'attachments': empty_analysis.attachment_style,
             'bigfive': empty_analysis.big_five_traits,
             'personalities': empty_analysis.personality_disorder
-        }}
+        }}
+