Spaces:
Runtime error
Runtime error
File size: 8,556 Bytes
e6355c1 3524dd9 e6355c1 2355280 9664404 6724fb5 6afc1e5 e6355c1 e913097 e6355c1 b88e12c baa51c6 b88e12c e6355c1 baa51c6 b88e12c e6355c1 b88e12c e6355c1 b88e12c e6355c1 b88e12c e6355c1 f497377 e6355c1 baa51c6 e3551a8 e6355c1 c07391a e6355c1 c07391a e6355c1 c07391a e6355c1 e913097 e6355c1 c07391a 9720614 e6355c1 e3551a8 1a67fa6 3b9f50e 9763663 e3551a8 8948925 e3551a8 cdb5785 e7f06b1 a0cd8f0 cdb5785 e3551a8 c07391a baa51c6 c07391a b88e12c 84b171f c07391a a0cd8f0 cdb5785 b88e12c e3551a8 cdb5785 a0cd8f0 3524dd9 eb58569 e3551a8 c07391a e6355c1 9796cc7 8948925 1a67fa6 0a8dc2e 6724fb5 cdb5785 f33ef48 71b710d 50418db 3050414 c07391a 3050414 8948925 7fc8026 46371a3 cdb5785 f33ef48 50418db 71b710d 46371a3 f33ef48 cdb5785 e3551a8 8948925 50418db 71b710d 46371a3 c07391a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 |
import json
import os
from typing import Any, List, Optional

from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.schema import BaseRetriever, Document, HumanMessage
from langchain_community.vectorstores import FAISS
from langchain_core.callbacks import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from pydantic import Field

from config import openai_api_key
from llm_loader import count_tokens, load_model
from output_parser import output_parser
# Initialize embedding model
embedding_model = OpenAIEmbeddings(openai_api_key=openai_api_key)

# Define knowledge files (label -> path). Only the paths are read below; the labels are
# not used by this module's code.
knowledge_files = {
    "attachments": "knowledge/bartholomew_attachments_definitions.txt",
    "bigfive": "knowledge/bigfive_definitions.txt",
    "personalities": "knowledge/personalities_definitions.txt"
}

# Load text-based knowledge: each file is read whole and becomes ONE document (no chunking),
# so a hit on this index returns an entire definitions file.
documents = []
for file_path in knowledge_files.values():
    with open(file_path, 'r', encoding='utf-8') as file:
        content = file.read().strip()
    documents.append(content)

# Create FAISS index from text documents (embeds them at import time).
text_faiss_index = FAISS.from_texts(documents, embedding_model)

# Load pre-existing FAISS indexes.
# SECURITY NOTE: allow_dangerous_deserialization=True unpickles the stored docstore. That is
# only acceptable because these directories ship with the app; never point it at files that
# could come from an untrusted source.
attachments_faiss_index = FAISS.load_local(
    "knowledge/faiss_index_Attachments_db", embedding_model, allow_dangerous_deserialization=True
)
personalities_faiss_index = FAISS.load_local(
    "knowledge/faiss_index_Personalities_db", embedding_model, allow_dangerous_deserialization=True
)

# Initialize LLM (module-level; the query-generation chain is built on this instance).
llm = load_model(openai_api_key)

# Create retrievers for each index (default search settings).
text_retriever = text_faiss_index.as_retriever()
attachments_retriever = attachments_faiss_index.as_retriever()
personalities_retriever = personalities_faiss_index.as_retriever()
class CombinedRetriever(BaseRetriever):
    """Retriever that queries several retrievers in turn and concatenates their results.

    Documents are returned grouped by retriever, in ``retrievers`` order; no
    de-duplication or re-ranking is performed here.
    """

    # Child retrievers, queried in list order for every query.
    retrievers: List[BaseRetriever] = Field(default_factory=list)

    class Config:
        arbitrary_types_allowed = True

    def _get_relevant_documents(
        self, query: str, *, run_manager: Optional[CallbackManagerForRetrieverRun] = None
    ) -> List[Document]:
        # Pass a child callback manager through ``config`` so each child run nests under
        # this one. Forwarding ``run_manager=`` as a keyword (as before) is either ignored
        # or, on newer langchain_core, collides with the run_manager the child creates.
        config = {"callbacks": run_manager.get_child()} if run_manager is not None else None
        combined_docs: List[Document] = []
        for retriever in self.retrievers:
            combined_docs.extend(retriever.invoke(query, config=config))
        return combined_docs

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
    ) -> List[Document]:
        # Async twin of _get_relevant_documents; children are awaited sequentially so the
        # result order matches the sync path.
        config = {"callbacks": run_manager.get_child()} if run_manager is not None else None
        combined_docs: List[Document] = []
        for retriever in self.retrievers:
            combined_docs.extend(await retriever.ainvoke(query, config=config))
        return combined_docs
# Create an instance of the combined retriever
combined_retriever = CombinedRetriever(
    retrievers=[text_retriever, attachments_retriever, personalities_retriever]
)

# Create prompt template for query generation. generate_queries() splits the reply on
# newlines and treats every non-blank line as a query, so the prompt asks for exactly that
# shape; otherwise preambles ("Here are some queries:") and list numbering become queries.
prompt_template = PromptTemplate(
    input_variables=["question"],
    template=(
        "Generate multiple search queries for the following question: {question}\n"
        "Return only the queries, one per line, without numbering, bullets or any other text."
    )
)

# Create query generation chain
query_generation_chain = prompt_template | llm
# Create multi-query retrieval chain
def generate_queries(input):
    """Ask the query-generation chain for search queries about ``input``.

    The model reply is split on '\\n'; each line is stripped and blank lines are dropped.
    """
    reply = query_generation_chain.invoke({"question": input})
    candidates = (line.strip() for line in reply.content.split('\n'))
    return [candidate for candidate in candidates if candidate]
def multi_query_retrieve(input):
    """Retrieve documents for every LLM-generated variant of ``input``.

    Each query from ``generate_queries`` is sent to ``combined_retriever``. Results are
    merged with duplicates (same ``page_content``) dropped, keeping first-seen order.
    If the LLM yields no usable query lines and ``input`` is non-blank text, ``input``
    itself is used as the only query.

    Returns:
        List of unique Document objects.
    """
    queries = generate_queries(input)
    if not queries and isinstance(input, str) and input.strip():
        # Searching with the original text beats returning no knowledge at all.
        queries = [input]

    unique_docs = {}
    for query in queries:
        for doc in combined_retriever.invoke(query):
            # Several queries x three retrievers overlap heavily (the text index holds whole
            # files), so keep only the first copy of each text.
            unique_docs.setdefault(doc.page_content, doc)
    return list(unique_docs.values())
multi_query_retriever = RunnableLambda(multi_query_retrieve)


def _format_docs(docs):
    """Join retrieved documents' text into one context string, separated by blank lines."""
    return "\n\n".join(doc.page_content for doc in docs)


def _question_text(payload):
    """Accept a plain question string or a ``{"query": ...}`` / ``{"question": ...}`` dict."""
    if isinstance(payload, dict):
        return payload.get("query", payload.get("question", ""))
    return payload


# Prompt that actually consumes the retrieved context. The query-generation prompt used
# previously only had {question}, so the retrieved documents were computed and discarded.
qa_prompt_template = PromptTemplate(
    input_variables=["context", "question"],
    template=(
        "Use the following retrieved context to answer the question.\n\n"
        "Context:\n{context}\n\n"
        "Question: {question}"
    )
)

# Create QA chain with multi-query retriever
qa_chain = (
    RunnableLambda(_question_text)
    | {
        "context": multi_query_retriever | RunnableLambda(_format_docs),
        "question": RunnablePassthrough(),
    }
    | qa_prompt_template
    | llm
)
def load_text(file_path: str) -> str:
    """Read ``file_path`` as UTF-8 text and return it with surrounding whitespace stripped."""
    with open(file_path, encoding='utf-8') as handle:
        raw = handle.read()
    return raw.strip()
def truncate_text(text: str, max_tokens: int = 16000) -> str:
    """Keep at most ``max_tokens`` whitespace-delimited words of ``text``.

    Despite the parameter name, the unit is words (as split by ``str.split``), not model
    tokens. Text within the limit is returned unchanged. Longer text is cut right after
    the last kept word, preserving the original newlines and spacing; the previous
    ``' '.join`` rebuild collapsed line breaks in long transcripts. A negative limit is
    treated as 0.
    """
    import re

    limit = max(max_tokens, 0)
    end = 0
    for count, match in enumerate(re.finditer(r"\S+", text), 1):
        if count > limit:
            # A word beyond the limit exists: cut after the last word we keep.
            return text[:end]
        end = match.end()
    return text
def _retrieve_knowledge(query: str) -> str:
    """Return the unique texts retrieved for ``query``, joined by blank lines.

    Uses ``multi_query_retrieve`` (LLM-generated query variants over the combined
    retriever). Identical page contents are kept once, in first-seen order.
    """
    docs = multi_query_retrieve(query)
    unique_texts = dict.fromkeys(doc.page_content for doc in docs)
    return "\n\n".join(unique_texts)


def _parse_llm_json(content: str) -> Any:
    """Decode the model reply as JSON, tolerating a surrounding Markdown code fence.

    Surrounding whitespace, an opening ```json or ``` fence and a closing ``` are removed
    before ``json.loads``. Raises ``json.JSONDecodeError`` if the remainder is not JSON.
    """
    text = content.strip()
    for fence in ("```json", "```"):
        if text.startswith(fence):
            text = text[len(fence):]
            break
    text = text.strip()
    if text.endswith("```"):
        text = text[:-3]
    return json.loads(text.strip())


def _speaker_entry(analysis) -> dict:
    """Map a parsed speaker analysis onto the result keys returned by process_input."""
    general_impression = analysis.general_impression
    # The prompt asks for plain-text impressions; if the parser still hands back a dict,
    # serialize it to a JSON string. Other values pass through unchanged.
    if isinstance(general_impression, dict):
        general_impression = json.dumps(general_impression)
    return {
        'general_impression': general_impression,
        'attachments': analysis.attachment_style,
        'bigfive': analysis.big_five_traits,
        'personalities': analysis.personality_disorder
    }


def _empty_results() -> dict:
    """Fallback result: a single "Speaker 1" built from parsing an empty analysis."""
    empty_analysis = output_parser.parse_speaker_analysis({})
    return {"Speaker 1": {
        'general_impression': empty_analysis.general_impression,
        'attachments': empty_analysis.attachment_style,
        'bigfive': empty_analysis.big_five_traits,
        'personalities': empty_analysis.personality_disorder
    }}


def process_input(input_text: str, llm):
    """Analyse ``input_text`` per speaker with ``llm``.

    Loads the task descriptions from ``tasks/``, truncates the input with
    ``truncate_text``, retrieves supporting knowledge, builds one combined prompt and
    asks the model for a JSON reply with a ``speaker_analyses`` array (a bare JSON array
    of analyses is accepted too).

    Args:
        input_text: Raw text to analyse.
        llm: Model exposing ``invoke(prompt)`` whose result has a ``.content`` string.
            Note that retrieval still uses the module-level model for query generation.

    Returns:
        Dict mapping "Speaker 1", "Speaker 2", ... (in reply order) to dicts with keys
        'general_impression', 'attachments', 'bigfive' and 'personalities'. If the reply
        cannot be parsed or holds no analyses, a single empty "Speaker 1" entry is returned.

    Raises:
        Errors from reading the task files, from retrieval or from ``llm.invoke``
        propagate; only the parsing of the reply is guarded.
    """
    general_task = load_text("tasks/General_tasks_description.txt")
    general_impression_task = load_text("tasks/General_Impression_task.txt")
    attachments_task = load_text("tasks/Attachments_task.txt")
    bigfive_task = load_text("tasks/BigFive_task.txt")
    personalities_task = load_text("tasks/Personalities_task.txt")

    truncated_input = truncate_text(input_text)
    # Feed actual retrieved document text into the prompt. Previously this was
    # str(qa_chain.invoke(...)), i.e. the repr of an LLM message listing search queries.
    # NOTE(review): the text index stores whole knowledge files, so this can be long;
    # confirm the combined prompt fits the model's context window.
    retrieved_knowledge = _retrieve_knowledge(truncated_input)

    prompt = f"""
{general_task}
General Impression Task:
{general_impression_task}
Attachment Styles Task:
{attachments_task}
Big Five Traits Task:
{bigfive_task}
Personality Disorders Task:
{personalities_task}
Retrieved Knowledge: {retrieved_knowledge}
Input: {truncated_input}
Please provide a comprehensive analysis for each speaker, including:
1. General impressions (answer the sections provided in the General Impression Task.)
2. Attachment styles (use the format from the Attachment Styles Task)
3. Big Five traits (use the format from the Big Five Traits Task)
4. Personality disorders (use the format from the Personality Disorders Task)
Respond with a JSON object containing an array of speaker analyses under the key 'speaker_analyses'. Each speaker analysis should include all four aspects mentioned above, however, General impressions must not be in json or dict format.
Analysis:"""

    response = llm.invoke(prompt)
    print("Raw LLM Model Output:")
    print(response.content)

    try:
        parsed_json = _parse_llm_json(response.content)
        if isinstance(parsed_json, list):
            speaker_analyses = parsed_json
        else:
            speaker_analyses = parsed_json.get('speaker_analyses', [])
        results = {
            f"Speaker {i}": _speaker_entry(output_parser.parse_speaker_analysis(analysis))
            for i, analysis in enumerate(speaker_analyses, 1)
        }
    except Exception as e:
        # Deliberate best-effort boundary: any malformed reply degrades to the empty result.
        print(f"Error processing input: {e}")
        return _empty_results()

    if not results:
        print("Warning: No speaker analyses found in the parsed JSON.")
        return _empty_results()
    return results