reab5555 committed
Commit e6355c1 · verified · 1 Parent(s): 5074417

Update processing.py

Files changed (1)
  1. processing.py +89 -11
processing.py CHANGED
@@ -1,32 +1,108 @@
-from langchain.schema import HumanMessage
+from langchain.schema import HumanMessage, BaseRetriever, Document
 from output_parser import output_parser
-from langchain_openai import OpenAIEmbeddings
+from langchain_openai import OpenAIEmbeddings, ChatOpenAI
 from langchain_community.vectorstores import FAISS
 from llm_loader import load_model
 from config import openai_api_key
 from langchain.chains import RetrievalQA
+from langchain.prompts import PromptTemplate
+from langchain_core.runnables import RunnablePassthrough, RunnableLambda
+from typing import List, Any, Optional
+from pydantic import Field
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
 import os
 import json
 
+# Initialize embedding model
 embedding_model = OpenAIEmbeddings(openai_api_key=openai_api_key)
 
+# Define knowledge files
 knowledge_files = {
     "attachments": "knowledge/bartholomew_attachments_definitions.txt",
     "bigfive": "knowledge/bigfive_definitions.txt",
     "personalities": "knowledge/personalities_definitions.txt"
 }
 
+# Load text-based knowledge
 documents = []
 for key, file_path in knowledge_files.items():
     with open(file_path, 'r', encoding='utf-8') as file:
         content = file.read().strip()
         documents.append(content)
 
-faiss_index = FAISS.from_texts(documents, embedding_model)
+# Create FAISS index from text documents
+text_faiss_index = FAISS.from_texts(documents, embedding_model)
+
+# Load pre-existing FAISS indexes
+attachments_faiss_index = FAISS.load_local("knowledge/faiss_index_Attachments_db", embedding_model, allow_dangerous_deserialization=True)
+personalities_faiss_index = FAISS.load_local("knowledge/faiss_index_Personalities_db", embedding_model, allow_dangerous_deserialization=True)
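+# NOTE: allow_dangerous_deserialization opts in to pickle-based loading, which
+# can execute arbitrary code; it is only appropriate for local, trusted index
+# files like these, never for user-supplied ones.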
+
+# Initialize LLM
 llm = load_model(openai_api_key)
 
-qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=faiss_index.as_retriever())
+# Create retrievers for each index
+text_retriever = text_faiss_index.as_retriever()
+attachments_retriever = attachments_faiss_index.as_retriever()
+personalities_retriever = personalities_faiss_index.as_retriever()
+
+class CombinedRetriever(BaseRetriever):
+    retrievers: List[BaseRetriever] = Field(default_factory=list)
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    def _get_relevant_documents(
+        self, query: str, *, run_manager: Optional[CallbackManagerForRetrieverRun] = None
+    ) -> List[Document]:
+        combined_docs = []
+        for retriever in self.retrievers:
+            # Go through the public API and forward callbacks; passing run_manager
+            # itself would collide with the argument BaseRetriever injects internally.
+            docs = retriever.get_relevant_documents(
+                query, callbacks=run_manager.get_child() if run_manager else None
+            )
+            combined_docs.extend(docs)
+        return combined_docs
+
+    async def _aget_relevant_documents(
+        self, query: str, *, run_manager: Optional[CallbackManagerForRetrieverRun] = None
+    ) -> List[Document]:
+        combined_docs = []
+        for retriever in self.retrievers:
+            docs = await retriever.aget_relevant_documents(
+                query, callbacks=run_manager.get_child() if run_manager else None
+            )
+            combined_docs.extend(docs)
+        return combined_docs
+
+# Create an instance of the combined retriever
+combined_retriever = CombinedRetriever(retrievers=[text_retriever, attachments_retriever, personalities_retriever])
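+# The combined retriever simply concatenates the hits from all three indexes
+# in order; overlapping results are not deduplicated or re-ranked.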
+
+# Create prompt template for query generation
+prompt_template = PromptTemplate(
+    input_variables=["question"],
+    template="Generate multiple search queries for the following question: {question}"
+)
+
+# Create query generation chain
+query_generation_chain = prompt_template | llm
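+# The pipe operator composes LCEL runnables: invoking this chain formats the
+# prompt and returns the raw model output (an AIMessage for chat models).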
+
+# Create multi-query retrieval chain
+def generate_queries(input):
+    # Assumes the model returns one query per line; blank lines are dropped.
+    queries = query_generation_chain.invoke({"question": input}).content.split('\n')
+    return [query.strip() for query in queries if query.strip()]
+
+def multi_query_retrieve(input):
+    queries = generate_queries(input)
+    all_docs = []
+    for query in queries:
+        docs = combined_retriever.get_relevant_documents(query)
+        all_docs.extend(docs)
+    return all_docs
+
+multi_query_retriever = RunnableLambda(multi_query_retrieve)
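+# Every generated query is run against all three sub-retrievers, so a single
+# input fans out into (queries x indexes) similarity searches whose results
+# are concatenated.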
+
+# Create QA chain with multi-query retriever
+qa_chain = (
+    {"context": multi_query_retriever, "question": RunnablePassthrough()}
+    | prompt_template
+    | llm
+)
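+# NOTE: prompt_template is the query-generation template and only consumes
+# {question}, so the retrieved "context" never reaches the model in this chain;
+# a dedicated answer prompt taking both variables is presumably what was intended.
+# NOTE: RunnablePassthrough forwards the chain input unchanged, so invoking the
+# chain with a dict (as process_input does with {"query": ...}) makes that whole
+# dict the "question" that gets formatted into the prompt.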
+
 
 def load_text(file_path: str) -> str:
     with open(file_path, 'r', encoding='utf-8') as file:
@@ -48,10 +124,7 @@ def process_input(input_text: str, llm):
 
     relevant_docs = qa_chain.invoke({"query": truncated_input})
 
-    if isinstance(relevant_docs, dict) and 'result' in relevant_docs:
-        retrieved_knowledge = relevant_docs['result']
-    else:
-        retrieved_knowledge = str(relevant_docs)
+    retrieved_knowledge = str(relevant_docs)
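+    # NOTE: if load_model returns a chat model, relevant_docs is an AIMessage and
+    # str() captures its full repr (content plus metadata); .content is likely
+    # the intended text.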
 
     prompt = f"""{general_task}
 Attachment Styles Task:
@@ -70,8 +143,7 @@ Please provide a comprehensive analysis for each speaker, including:
 Respond with a JSON object containing an array of speaker analyses under the key 'speaker_analyses'. Each speaker analysis should include all four aspects mentioned above.
 Analysis:"""
 
-    messages = [HumanMessage(content=prompt)]
-    response = llm.invoke(messages)
+    # A chat model accepts a plain string here; it is wrapped in a HumanMessage internally.
+    response = llm.invoke(prompt)
 
     print("Raw LLM Model Output:")
     print(response.content)
@@ -116,4 +188,10 @@ Analysis:"""
         'attachments': empty_analysis.attachment_style,
         'bigfive': empty_analysis.big_five_traits,
         'personalities': empty_analysis.personality_disorder
-    }}
+    }}
+
+# Example usage
+if __name__ == "__main__":
+    input_text = "Your input text here"
+    result = process_input(input_text, llm)
+    print(json.dumps(result, indent=2))