reab5555 committed
Commit c07391a · verified · 1 Parent(s): f497377

Update processing.py

Files changed (1)
  1. processing.py +15 -41
processing.py CHANGED
@@ -7,7 +7,6 @@ from config import openai_api_key
7
  from langchain.chains import RetrievalQA
8
  from langchain.prompts import PromptTemplate
9
  from langchain_core.runnables import RunnablePassthrough, RunnableLambda
10
- from langchain.schema.runnable import RunnablePassthrough
11
  from typing import List, Any, Optional
12
  from pydantic import Field
13
  from langchain_core.callbacks import CallbackManagerForRetrieverRun
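The retained Field and CallbackManagerForRetrieverRun imports serve the custom CombinedRetriever that the next hunk's header line instantiates. Its definition sits outside this diff, so the following is only a minimal sketch of what such a class typically looks like, not the file's actual code:

    from typing import Any, List

    from langchain_core.callbacks import CallbackManagerForRetrieverRun
    from langchain_core.documents import Document
    from langchain_core.retrievers import BaseRetriever
    from pydantic import Field

    class CombinedRetriever(BaseRetriever):
        # Sketch only: concatenate results from several child retrievers.
        # The real CombinedRetriever in processing.py may differ.
        retrievers: List[Any] = Field(default_factory=list)

        def _get_relevant_documents(
            self, query: str, *, run_manager: CallbackManagerForRetrieverRun
        ) -> List[Document]:
            docs: List[Document] = []
            for retriever in self.retrievers:
                docs.extend(retriever.get_relevant_documents(query))
            return docs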
@@ -75,15 +74,12 @@ combined_retriever = CombinedRetriever(retrievers=[text_retriever, attachments_r
 
 # Create prompt template for query generation
 prompt_template = PromptTemplate(
-    input_variables=["question", "context"],
-    template="Use the following context to answer the question: {context}\n\nQuestion: {question}\nAnswer:"
+    input_variables=["question"],
+    template="Generate multiple search queries for the following question: {question}"
 )
 
 # Create query generation chain
-query_generation_chain = PromptTemplate(
-    input_variables=["question"],
-    template="Generate multiple search queries for the following question: {question}"
-) | llm
+query_generation_chain = prompt_template | llm
 
 # Create multi-query retrieval chain
 def generate_queries(input):
  def generate_queries(input):
@@ -96,19 +92,13 @@ def multi_query_retrieve(input):
96
  for query in queries:
97
  docs = combined_retriever.get_relevant_documents(query)
98
  all_docs.extend(docs)
99
- return {"queries": queries, "documents": all_docs}
100
 
101
  multi_query_retriever = RunnableLambda(multi_query_retrieve)
102
 
103
- def format_docs(docs):
104
- return "\n\n".join(doc.page_content for doc in docs["documents"])
105
-
106
  # Create QA chain with multi-query retriever
107
  qa_chain = (
108
- {
109
- "context": multi_query_retriever | format_docs,
110
- "question": RunnablePassthrough()
111
- }
112
  | prompt_template
113
  | llm
114
  )
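The simplified qa_chain head relies on LCEL coercion: a plain dict piped into a Runnable becomes a RunnableParallel, every branch receives the same chain input, and whatever object is passed to invoke is what each branch sees. So RunnablePassthrough() forwards the question unchanged while the retriever branch fetches documents. A self-contained sketch of that mechanic with stub components (names here are illustrative, not from processing.py):

    from langchain_core.runnables import RunnableLambda, RunnablePassthrough

    # Stub retriever: in the real file this role is played by multi_query_retrieve
    # running over combined_retriever.
    stub_retriever = RunnableLambda(lambda question: [f"doc about {question}"])

    # A dict literal at the head of a pipe is coerced into a RunnableParallel:
    # both branches receive the same input string.
    chain = (
        {"context": stub_retriever, "question": RunnablePassthrough()}
        | RunnableLambda(lambda d: f"context={d['context']} question={d['question']}")
    )

    print(chain.invoke("attachment styles"))
    # context=['doc about attachment styles'] question=attachment styles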
@@ -132,33 +122,13 @@ def process_input(input_text: str, llm):
 
     truncated_input = truncate_text(input_text)
 
-    # Get the retrieval results and LLM output
-    retrieval_result = multi_query_retrieve(truncated_input)
-
-    # Print the generated queries
-    print("Generated Queries:")
-    for query in retrieval_result["queries"]:
-        print(f"- {query}")
-
-    # Print the retrieved documents
-    print("\nRetrieved Documents:")
-    for i, doc in enumerate(retrieval_result["documents"], 1):
-        print(f"Document {i}:")
-        print(f"Content: {doc.page_content}")
-        print("-" * 50)
-
-    # Format the retrieved documents
-    formatted_docs = format_docs(retrieval_result)
-
-    # Generate the LLM response
-    llm_input = prompt_template.format(question=truncated_input, context=formatted_docs)
-    llm_output = llm.invoke(llm_input)
-
-    retrieved_knowledge = str(llm_output.content)
+    relevant_docs = qa_chain.invoke({"query": truncated_input})
+
+    retrieved_knowledge = str(relevant_docs)
 
     prompt = f"""
 {general_task}
 General Impression Task:
 {general_impression_task}
 Attachment Styles Task:
 {attachments_task}
@@ -176,6 +146,11 @@ Please provide a comprehensive analysis for each speaker, including:
 Respond with a JSON object containing an array of speaker analyses under the key 'speaker_analyses'. Each speaker analysis should include all four aspects mentioned above, however, General impressions must not be in json or dict format.
 Analysis:"""
 
+    #truncated_input_tokents_count = count_tokens(truncated_input)
+    #print('truncated_input_tokents_count:', truncated_input_tokents_count)
+    #input_tokens_count = count_tokens(prompt)
+    #print('input_tokens_count', input_tokens_count)
+
     response = llm.invoke(prompt)
 
     print("Raw LLM Model Output:")
@@ -196,7 +171,7 @@ Analysis:"""
         speaker_id = f"Speaker {i}"
         parsed_analysis = output_parser.parse_speaker_analysis(speaker_analysis)
 
-        # Convert general_impression to string if it's a dict or JSON object
+        # Convert general_impression to string if it's a dict or JSON object
         general_impression = parsed_analysis.general_impression
         if isinstance(general_impression, dict):
             general_impression = json.dumps(general_impression)
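The prompt earlier in this file asks the model for a JSON object keyed by 'speaker_analyses', and this hunk's guard stringifies general_impression when the model returns a dict anyway. A small sketch of that parse-and-guard flow with an illustrative payload (the real code routes parsing through output_parser.parse_speaker_analysis; this sketch inlines the equivalent steps):

    import json

    # Shape the prompt requests from the model (values are illustrative).
    raw_output = '''{
      "speaker_analyses": [
        {
          "general_impression": "Calm and articulate throughout.",
          "attachment_style": {"secure": 0.7},
          "big_five_traits": {"openness": 0.6},
          "personality_disorder": {"none_indicated": true}
        }
      ]
    }'''

    for i, speaker_analysis in enumerate(json.loads(raw_output)["speaker_analyses"], 1):
        speaker_id = f"Speaker {i}"
        general_impression = speaker_analysis["general_impression"]
        # Same guard as the diff: stringify if the model ignored the
        # "must not be in json or dict format" instruction.
        if isinstance(general_impression, dict):
            general_impression = json.dumps(general_impression)
        print(speaker_id, general_impression)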
@@ -235,5 +210,4 @@ Analysis:"""
         'attachments': empty_analysis.attachment_style,
         'bigfive': empty_analysis.big_five_traits,
         'personalities': empty_analysis.personality_disorder
-    }}
-
+    }}