from langchain_core.retrievers import BaseRetriever
from langchain_core.documents import Document
from langchain_core.callbacks import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from output_parser import output_parser
from llm_loader import load_model, count_tokens
from config import openai_api_key
from typing import List, Optional
from pydantic import Field
import os
import json
# Initialize embedding model
embedding_model = OpenAIEmbeddings(openai_api_key=openai_api_key)

# Define knowledge files
knowledge_files = {
    "attachments": "knowledge/bartholomew_attachments_definitions.txt",
    "bigfive": "knowledge/bigfive_definitions.txt",
    "personalities": "knowledge/personalities_definitions.txt"
}

# Load text-based knowledge
documents = []
for key, file_path in knowledge_files.items():
    with open(file_path, 'r', encoding='utf-8') as file:
        content = file.read().strip()
        documents.append(content)

# Create FAISS index from text documents
text_faiss_index = FAISS.from_texts(documents, embedding_model)
# Load pre-existing FAISS indexes
attachments_faiss_index = FAISS.load_local(
    "knowledge/faiss_index_Attachments_db", embedding_model,
    allow_dangerous_deserialization=True
)
personalities_faiss_index = FAISS.load_local(
    "knowledge/faiss_index_Personalities_db", embedding_model,
    allow_dangerous_deserialization=True
)

# Initialize LLM
llm = load_model(openai_api_key)

# Create retrievers for each index
text_retriever = text_faiss_index.as_retriever()
attachments_retriever = attachments_faiss_index.as_retriever()
personalities_retriever = personalities_faiss_index.as_retriever()
class CombinedRetriever(BaseRetriever):
    """Fans each query out to several retrievers and concatenates the results."""

    retrievers: List[BaseRetriever] = Field(default_factory=list)

    class Config:
        arbitrary_types_allowed = True

    def _get_relevant_documents(
        self, query: str, *, run_manager: Optional[CallbackManagerForRetrieverRun] = None
    ) -> List[Document]:
        combined_docs = []
        for retriever in self.retrievers:
            # Thread the parent run's callbacks through to each child retriever;
            # passing run_manager itself as a kwarg raises a TypeError upstream
            config = {"callbacks": run_manager.get_child()} if run_manager else None
            combined_docs.extend(retriever.invoke(query, config=config))
        return combined_docs

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None
    ) -> List[Document]:
        combined_docs = []
        for retriever in self.retrievers:
            config = {"callbacks": run_manager.get_child()} if run_manager else None
            combined_docs.extend(await retriever.ainvoke(query, config=config))
        return combined_docs
# Create an instance of the combined retriever
combined_retriever = CombinedRetriever(
    retrievers=[text_retriever, attachments_retriever, personalities_retriever]
)
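# Quick sanity check (hypothetical query, not part of the pipeline): a single
# call fans out to all three indexes and returns the concatenated hits, e.g.
#   docs = combined_retriever.invoke("fearful-avoidant attachment style")
#   print(len(docs), docs[0].page_content[:80])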
# Create prompt template for query generation
prompt_template = PromptTemplate(
    input_variables=["question"],
    template="Generate multiple search queries for the following question: {question}"
)

# Create query generation chain
query_generation_chain = prompt_template | llm
# Create multi-query retrieval chain
def generate_queries(question: str) -> List[str]:
    queries = query_generation_chain.invoke({"question": question}).content.split('\n')
    return [query.strip() for query in queries if query.strip()]
def multi_query_retrieve(question: str) -> List[Document]:
    all_docs = []
    for query in generate_queries(question):
        all_docs.extend(combined_retriever.invoke(query))
    return all_docs

multi_query_retriever = RunnableLambda(multi_query_retrieve)
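# Optional sketch: queries generated from the same question often hit the same
# chunks, so a content-based dedup step could be piped after the retriever,
# e.g. multi_query_retriever | RunnableLambda(dedupe_docs). Not wired in here.
def dedupe_docs(docs: List[Document]) -> List[Document]:
    seen = set()
    unique = []
    for doc in docs:
        if doc.page_content not in seen:
            seen.add(doc.page_content)
            unique.append(doc)
    return unique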
# The query-generation template above only accepts {question}; the QA step
# needs its own template with a {context} slot.
qa_prompt = PromptTemplate(
    input_variables=["context", "question"],
    template="Use the following context to answer the question.\n\nContext:\n{context}\n\nQuestion: {question}"
)

# Create QA chain with multi-query retriever
qa_chain = (
    {"context": multi_query_retriever, "question": RunnablePassthrough()}
    | qa_prompt
    | llm
)
def load_text(file_path: str) -> str:
    with open(file_path, 'r', encoding='utf-8') as file:
        return file.read().strip()

def truncate_text(text: str, max_tokens: int = 16000) -> str:
    # Word count is used as a rough proxy for the token budget
    words = text.split()
    if len(words) > max_tokens:
        return ' '.join(words[:max_tokens])
    return text
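# Optional sketch: a token-accurate variant of truncate_text using tiktoken
# (extra dependency, not used by the pipeline); "cl100k_base" is an assumption
# about the model's encoding.
def truncate_text_exact(text: str, max_tokens: int = 16000) -> str:
    import tiktoken
    encoding = tiktoken.get_encoding("cl100k_base")
    tokens = encoding.encode(text)
    return encoding.decode(tokens[:max_tokens]) if len(tokens) > max_tokens else text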
def process_input(input_text: str, llm):
    general_task = load_text("tasks/General_tasks_description.txt")
    general_impression_task = load_text("tasks/General_Impression_task.txt")
    attachments_task = load_text("tasks/Attachments_task.txt")
    bigfive_task = load_text("tasks/BigFive_task.txt")
    personalities_task = load_text("tasks/Personalities_task.txt")

    truncated_input = truncate_text(input_text)
    # qa_chain takes the bare question string (RunnablePassthrough); its answer
    # becomes the "Retrieved Knowledge" section of the analysis prompt below
    qa_response = qa_chain.invoke(truncated_input)
    retrieved_knowledge = qa_response.content
    prompt = f"""
{general_task}
General Impression Task:
{general_impression_task}
Attachment Styles Task:
{attachments_task}
Big Five Traits Task:
{bigfive_task}
Personality Disorders Task:
{personalities_task}
Retrieved Knowledge: {retrieved_knowledge}
Input: {truncated_input}
Please provide a comprehensive analysis for each speaker, including:
1. General impressions (answer the sections provided in the General Impression Task)
2. Attachment styles (use the format from the Attachment Styles Task)
3. Big Five traits (use the format from the Big Five Traits Task)
4. Personality disorders (use the format from the Personality Disorders Task)
Respond with a JSON object containing an array of speaker analyses under the key 'speaker_analyses'. Each speaker analysis must cover all four aspects above; the general impressions, however, must be plain text, not a JSON object or dict.
Analysis:"""
    #truncated_input_tokens_count = count_tokens(truncated_input)
    #print('truncated_input_tokens_count:', truncated_input_tokens_count)
    #input_tokens_count = count_tokens(prompt)
    #print('input_tokens_count', input_tokens_count)
    response = llm.invoke(prompt)
    print("Raw LLM Model Output:")
    print(response.content)
    try:
        content = response.content
        # Strip a ```json ... ``` fence if the model wrapped its output in one
        if content.startswith("```json"):
            content = content.split("```json", 1)[1]
        if content.endswith("```"):
            content = content.rsplit("```", 1)[0]
        parsed_json = json.loads(content.strip())

        results = {}
        speaker_analyses = parsed_json.get('speaker_analyses', [])
        for i, speaker_analysis in enumerate(speaker_analyses, 1):
            speaker_id = f"Speaker {i}"
            parsed_analysis = output_parser.parse_speaker_analysis(speaker_analysis)
            # Convert general_impression to a plain string if the model returned
            # a dict; strings (including JSON-formatted ones) are kept as-is
            general_impression = parsed_analysis.general_impression
            if isinstance(general_impression, dict):
                general_impression = json.dumps(general_impression)
            results[speaker_id] = {
                'general_impression': general_impression,
                'attachments': parsed_analysis.attachment_style,
                'bigfive': parsed_analysis.big_five_traits,
                'personalities': parsed_analysis.personality_disorder
            }

        if not results:
            print("Warning: No speaker analyses found in the parsed JSON.")
            empty_analysis = output_parser.parse_speaker_analysis({})
            return {"Speaker 1": {
                'general_impression': empty_analysis.general_impression,
                'attachments': empty_analysis.attachment_style,
                'bigfive': empty_analysis.big_five_traits,
                'personalities': empty_analysis.personality_disorder
            }}
        return results
    except Exception as e:
        print(f"Error processing input: {e}")
        empty_analysis = output_parser.parse_speaker_analysis({})
        return {"Speaker 1": {
            'general_impression': empty_analysis.general_impression,
            'attachments': empty_analysis.attachment_style,
            'bigfive': empty_analysis.big_five_traits,
            'personalities': empty_analysis.personality_disorder
        }}
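
# Minimal smoke test (sketch): the transcript path is hypothetical; point it at
# any speaker-labelled text file to exercise the pipeline end to end.
if __name__ == "__main__":
    sample_path = "transcripts/sample_conversation.txt"  # hypothetical example file
    if os.path.exists(sample_path):
        analyses = process_input(load_text(sample_path), llm)
        for speaker, analysis in analyses.items():
            print(speaker, '->', analysis['general_impression'])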