reab5555 committed on
Commit
6afc1e5
·
verified ·
1 Parent(s): e2cdfd7

Update processing.py

Browse files
Files changed (1) hide show
  1. processing.py +26 -27
processing.py CHANGED
@@ -1,12 +1,12 @@
1
  from langchain.schema import HumanMessage, BaseRetriever
2
  from output_parser import output_parser
3
- from langchain_openai import OpenAIEmbeddings
4
  from langchain_community.vectorstores import FAISS
5
  from llm_loader import load_model
6
  from config import openai_api_key
7
- from langchain.chains import RetrievalQA, LLMChain
8
- from langchain.retrievers import MultiQueryRetriever
9
  from langchain.prompts import PromptTemplate
 
10
  from typing import List
11
  from pydantic import Field
12
  import os
@@ -51,17 +51,10 @@ class CombinedRetriever(BaseRetriever):
51
  class Config:
52
  arbitrary_types_allowed = True
53
 
54
- def get_relevant_documents(self, query: str):
55
  combined_docs = []
56
  for retriever in self.retrievers:
57
- docs = retriever.get_relevant_documents(query)
58
- combined_docs.extend(docs)
59
- return combined_docs
60
-
61
- async def aget_relevant_documents(self, query: str):
62
- combined_docs = []
63
- for retriever in self.retrievers:
64
- docs = await retriever.aget_relevant_documents(query)
65
  combined_docs.extend(docs)
66
  return combined_docs
67
 
@@ -74,18 +67,22 @@ prompt_template = PromptTemplate(
74
  template="Generate multiple search queries for the following question: {question}"
75
  )
76
 
77
- # Create LLM chain for query generation
78
- llm_chain = LLMChain(llm=llm, prompt=prompt_template)
79
 
80
- # Initialize MultiQueryRetriever
81
- multi_query_retriever = MultiQueryRetriever(
82
- retriever=combined_retriever,
83
- llm_chain=llm_chain,
84
- parser_key="lines" # Assuming the LLM outputs queries line by line
85
- )
86
 
87
  # Create QA chain with multi-query retriever
88
- qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=multi_query_retriever)
 
 
 
 
89
 
90
  def load_text(file_path: str) -> str:
91
  with open(file_path, 'r', encoding='utf-8') as file:
@@ -107,10 +104,7 @@ def process_input(input_text: str, llm):
107
 
108
  relevant_docs = qa_chain.invoke({"query": truncated_input})
109
 
110
- if isinstance(relevant_docs, dict) and 'result' in relevant_docs:
111
- retrieved_knowledge = relevant_docs['result']
112
- else:
113
- retrieved_knowledge = str(relevant_docs)
114
 
115
  prompt = f"""{general_task}
116
  Attachment Styles Task:
@@ -129,8 +123,7 @@ Please provide a comprehensive analysis for each speaker, including:
129
  Respond with a JSON object containing an array of speaker analyses under the key 'speaker_analyses'. Each speaker analysis should include all four aspects mentioned above.
130
  Analysis:"""
131
 
132
- messages = [HumanMessage(content=prompt)]
133
- response = llm.invoke(messages)
134
 
135
  print("Raw LLM Model Output:")
136
  print(response.content)
@@ -176,3 +169,9 @@ Analysis:"""
176
  'bigfive': empty_analysis.big_five_traits,
177
  'personalities': empty_analysis.personality_disorder
178
  }}
 
 
 
 
 
 
 
1
  from langchain.schema import HumanMessage, BaseRetriever
2
  from output_parser import output_parser
3
+ from langchain_openai import OpenAIEmbeddings, ChatOpenAI
4
  from langchain_community.vectorstores import FAISS
5
  from llm_loader import load_model
6
  from config import openai_api_key
7
+ from langchain.chains import RetrievalQA
 
8
  from langchain.prompts import PromptTemplate
9
+ from langchain_core.runnables import RunnablePassthrough, RunnableLambda
10
  from typing import List
11
  from pydantic import Field
12
  import os
 
51
  class Config:
52
  arbitrary_types_allowed = True
53
 
54
def invoke(self, input):
    """Gather relevant documents from every wrapped retriever.

    Each retriever in ``self.retrievers`` is queried with the same
    input; the per-retriever results are concatenated in retriever
    order and returned as a single flat list.
    """
    return [doc for source in self.retrievers for doc in source.invoke(input)]
60
 
 
67
  template="Generate multiple search queries for the following question: {question}"
68
  )
69
 
70
# Create query generation chain: ask the LLM to propose several
# alternative search queries for the user's question.
query_generation_chain = prompt_template | llm


def _extract_question(chain_input):
    """Return the bare question string from the chain input.

    ``qa_chain`` is invoked with a dict (e.g. ``{"query": ...}``), but
    downstream steps need the plain question text.
    """
    if isinstance(chain_input, dict):
        return chain_input.get("query") or chain_input.get("question") or str(chain_input)
    return chain_input


def generate_queries(chain_input):
    """Generate a list of search-query strings for the input question.

    One LLM call yields newline-separated queries, which are stripped
    and filtered of blank lines.
    """
    question = _extract_question(chain_input)
    raw = query_generation_chain.invoke({"question": question}).content
    return [line.strip() for line in raw.split("\n") if line.strip()]


def _retrieve_for_queries(queries):
    """Run the combined retriever once per generated query and merge results.

    BUGFIX: the previous code piped the whole *list* of queries into the
    retriever in a single call, but retrievers expect one query string
    per invocation.
    """
    docs = []
    for query in queries:
        docs.extend(combined_retriever.invoke(query))
    return docs


multi_query_retriever = RunnableLambda(generate_queries) | RunnableLambda(_retrieve_for_queries)

# Prompt used to answer the question from the retrieved context.
# BUGFIX: the previous code reused `prompt_template` here, which only
# generates search queries and silently drops the retrieved context.
qa_prompt = PromptTemplate(
    input_variables=["context", "question"],
    template=(
        "Answer the question using the context below.\n\n"
        "Context:\n{context}\n\n"
        "Question: {question}\n\n"
        "Answer:"
    ),
)

# Create QA chain with multi-query retriever
qa_chain = (
    {
        "context": multi_query_retriever,
        "question": RunnableLambda(_extract_question),
    }
    | qa_prompt
    | llm
)
86
 
87
  def load_text(file_path: str) -> str:
88
  with open(file_path, 'r', encoding='utf-8') as file:
 
104
 
105
  relevant_docs = qa_chain.invoke({"query": truncated_input})
106
 
107
+ retrieved_knowledge = str(relevant_docs)
 
 
 
108
 
109
  prompt = f"""{general_task}
110
  Attachment Styles Task:
 
123
  Respond with a JSON object containing an array of speaker analyses under the key 'speaker_analyses'. Each speaker analysis should include all four aspects mentioned above.
124
  Analysis:"""
125
 
126
+ response = llm.invoke(prompt)
 
127
 
128
  print("Raw LLM Model Output:")
129
  print(response.content)
 
169
  'bigfive': empty_analysis.big_five_traits,
170
  'personalities': empty_analysis.personality_disorder
171
  }}
172
+
173
# Example usage / smoke-test entry point.
if __name__ == "__main__":
    # Local import: `json` is not among this module's top-level imports,
    # so importing it here keeps the fix self-contained.
    import json

    sample_text = "Your input text here"
    result = process_input(sample_text, llm)
    print(json.dumps(result, indent=2))