# output_parser.py — author: reab5555
# Last commit: 8df2e34 "Update output_parser.py" (verified); file size 4.9 kB
from langchain.output_parsers import StructuredOutputParser, ResponseSchema
from langchain.prompts import PromptTemplate
from pydantic import BaseModel
from typing import Dict
class AttachmentStyle(BaseModel):
    """Attachment-style assessment for a single speaker.

    Instances are built by ``parse_analysis_output`` from the JSON returned
    for the ``"attachments"`` analysis. The value ranges noted below are the
    ones requested in ``attachment_response_schemas``; this model does not
    enforce them.
    """

    speaker: str  # name or number identifying the speaker
    secured: float  # probability of secured attachment style (0-1 requested)
    anxious_preoccupied: float  # probability of anxious-preoccupied style (0-1 requested)
    dismissive_avoidant: float  # probability of dismissive-avoidant style (0-1 requested)
    fearful_avoidant: float  # probability of fearful-avoidant style (0-1 requested)
    self_rating: int  # self rating (0-10 requested)
    others_rating: int  # others rating (0-10 requested)
    anxiety: int  # anxiety rating (0-10 requested)
    avoidance: int  # avoidance rating (0-10 requested)
    explanation: str  # brief free-text explanation of the attachment style
class BigFiveTraits(BaseModel):
    """Big Five personality-trait ratings for a single speaker.

    Instances are built by ``parse_analysis_output`` from the JSON returned
    for the ``"bigfive"`` analysis. Each trait is requested on a -10 to 10
    scale in ``bigfive_response_schemas``; this model does not enforce the
    range.
    """

    speaker: str  # name or number identifying the speaker
    extraversion: int  # -10 to 10 requested
    agreeableness: int  # -10 to 10 requested
    conscientiousness: int  # -10 to 10 requested
    neuroticism: int  # -10 to 10 requested
    openness: int  # -10 to 10 requested
    explanation: str  # brief free-text explanation of the ratings
class PersonalityDisorder(BaseModel):
    """Personality-disorder ratings for a single speaker.

    Instances are built by ``parse_analysis_output`` from the JSON returned
    for the ``"personalities"`` analysis. Each rating is requested on a 0-4
    scale in ``personality_response_schemas``; this model does not enforce
    the range.
    """

    speaker: str  # name or number identifying the speaker
    depressed: int  # 0-4 requested
    paranoid: int  # 0-4 requested
    schizoid_schizotypal: int  # 0-4 requested
    antisocial_psychopathic: int  # 0-4 requested
    borderline_dysregulated: int  # 0-4 requested
    narcissistic: int  # 0-4 requested
    anxious_avoidant: int  # 0-4 requested
    dependent_victimized: int  # 0-4 requested
    obsessional: int  # 0-4 requested
    explanation: str  # brief free-text explanation of the ratings
# Output schema for the "attachments" analysis: one entry per key of the JSON
# the LLM is asked to return. ``ResponseSchema.type`` (default "string") is the
# type shown to the model in the format instructions, so numeric fields declare
# a numeric type matching the corresponding ``AttachmentStyle`` field (float for
# the probabilities, int for the ratings) instead of asking for strings.
attachment_response_schemas = [
    ResponseSchema(name="speaker", description="The name or number of the speaker"),
    ResponseSchema(
        name="secured",
        description="Probability of secured attachment style (0-1)",
        type="number",
    ),
    ResponseSchema(
        name="anxious_preoccupied",
        description="Probability of anxious-preoccupied attachment style (0-1)",
        type="number",
    ),
    ResponseSchema(
        name="dismissive_avoidant",
        description="Probability of dismissive-avoidant attachment style (0-1)",
        type="number",
    ),
    ResponseSchema(
        name="fearful_avoidant",
        description="Probability of fearful-avoidant attachment style (0-1)",
        type="number",
    ),
    ResponseSchema(name="self_rating", description="Self rating (0-10)", type="integer"),
    ResponseSchema(name="others_rating", description="Others rating (0-10)", type="integer"),
    ResponseSchema(name="anxiety", description="Anxiety rating (0-10)", type="integer"),
    ResponseSchema(name="avoidance", description="Avoidance rating (0-10)", type="integer"),
    ResponseSchema(name="explanation", description="Brief explanation of the attachment style"),
]
# Output schema for the "bigfive" analysis. Trait ratings declare an integer
# type so the format instructions match the int fields of ``BigFiveTraits``
# rather than the ResponseSchema default of "string".
bigfive_response_schemas = [
    ResponseSchema(name="speaker", description="The name or number of the speaker"),
    ResponseSchema(name="extraversion", description="Extraversion rating (-10 to 10)", type="integer"),
    ResponseSchema(name="agreeableness", description="Agreeableness rating (-10 to 10)", type="integer"),
    ResponseSchema(
        name="conscientiousness",
        description="Conscientiousness rating (-10 to 10)",
        type="integer",
    ),
    ResponseSchema(name="neuroticism", description="Neuroticism rating (-10 to 10)", type="integer"),
    ResponseSchema(name="openness", description="Openness rating (-10 to 10)", type="integer"),
    ResponseSchema(name="explanation", description="Brief explanation of the Big Five traits"),
]
# Output schema for the "personalities" analysis. Ratings declare an integer
# type so the format instructions match the int fields of
# ``PersonalityDisorder`` rather than the ResponseSchema default of "string".
personality_response_schemas = [
    ResponseSchema(name="speaker", description="The name or number of the speaker"),
    ResponseSchema(name="depressed", description="Depressed rating (0-4)", type="integer"),
    ResponseSchema(name="paranoid", description="Paranoid rating (0-4)", type="integer"),
    ResponseSchema(
        name="schizoid_schizotypal",
        description="Schizoid-Schizotypal rating (0-4)",
        type="integer",
    ),
    ResponseSchema(
        name="antisocial_psychopathic",
        description="Antisocial-Psychopathic rating (0-4)",
        type="integer",
    ),
    ResponseSchema(
        name="borderline_dysregulated",
        description="Borderline-Dysregulated rating (0-4)",
        type="integer",
    ),
    ResponseSchema(name="narcissistic", description="Narcissistic rating (0-4)", type="integer"),
    ResponseSchema(name="anxious_avoidant", description="Anxious-Avoidant rating (0-4)", type="integer"),
    ResponseSchema(
        name="dependent_victimized",
        description="Dependent-Victimized rating (0-4)",
        type="integer",
    ),
    ResponseSchema(name="obsessional", description="Obsessional rating (0-4)", type="integer"),
    ResponseSchema(name="explanation", description="Brief explanation of the personality disorders"),
]
# One structured-output parser per analysis type, each built from its
# matching response-schema list (attachments, Big Five, personalities).
attachment_parser, bigfive_parser, personality_parser = map(
    StructuredOutputParser.from_response_schemas,
    (
        attachment_response_schemas,
        bigfive_response_schemas,
        personality_response_schemas,
    ),
)
def get_prompt_template(task: str, parser: StructuredOutputParser) -> PromptTemplate:
    """Build a prompt that applies ``task`` to a caller-supplied text.

    ``task`` and the parser's format instructions are bound up front as
    partial variables, so the returned template needs only ``text`` when it
    is formatted.

    Args:
        task: Description of the analysis to perform.
        parser: Parser whose format instructions are embedded in the prompt.

    Returns:
        A ``PromptTemplate`` with a single remaining input variable, ``text``.
    """
    instructions = parser.get_format_instructions()
    layout = (
        "Analyze the following text according to the given task:\n\n"
        "{task}\n\n"
        "{format_instructions}\n\n"
        "Text: {text}\n\n"
        "Analysis:"
    )
    return PromptTemplate(
        template=layout,
        input_variables=["text"],
        partial_variables=dict(task=task, format_instructions=instructions),
    )
def parse_analysis_output(output: str, analysis_type: str) -> Dict[str, BaseModel]:
    """Parse one raw LLM response into a validated model keyed by speaker.

    Args:
        output: Raw LLM text expected to contain the JSON described by the
            matching parser's format instructions.
        analysis_type: One of ``"attachments"``, ``"bigfive"`` or
            ``"personalities"``.

    Returns:
        A single-entry dict ``{speaker: model}``, where ``model`` is an
        ``AttachmentStyle``, ``BigFiveTraits`` or ``PersonalityDisorder``.

    Raises:
        ValueError: If ``analysis_type`` is not one of the values above.
        Errors raised by the output parser or by model validation are
        propagated unchanged.
    """
    # Built per call so it always refers to the module-level parsers/models.
    handlers = {
        "attachments": (attachment_parser, AttachmentStyle),
        "bigfive": (bigfive_parser, BigFiveTraits),
        "personalities": (personality_parser, PersonalityDisorder),
    }
    try:
        parser, model_cls = handlers[analysis_type]
    except KeyError:
        raise ValueError(f"Unknown analysis type: {analysis_type}") from None

    parsed = parser.parse(output)
    speaker = parsed["speaker"]
    # The schema describes the speaker as a "name or number", so the LLM may
    # emit a JSON number (e.g. 1). Normalize it to str so that it satisfies the
    # ``speaker: str`` field (Pydantic v2 does not coerce int -> str) and so
    # that the returned dict is keyed by str, as annotated.
    if speaker is not None and not isinstance(speaker, str):
        speaker = str(speaker)
    model = model_cls(**{**parsed, "speaker": speaker})
    return {speaker: model}