Spaces:
Sleeping
Sleeping
File size: 2,537 Bytes
eefbdd1 ca4603b eefbdd1 ca4603b eefbdd1 ca4603b eefbdd1 ca4603b eefbdd1 ca4603b eefbdd1 ca4603b eefbdd1 ca4603b eefbdd1 ca4603b eefbdd1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
import json
import logging
import math
from typing import Dict

from langchain import PromptTemplate, LLMChain

from models import chat_model

logger = logging.getLogger(__name__)

# Prompt text asking the LLM to rate eight wellness themes as 0-100 severity
# percentages and to answer in JSON. "{responses}" and "{internal_report}" are
# the template's input variables; analyze_problems_with_chain supplies them
# when it calls problem_chain.run(...).
_PROBLEM_TEMPLATE_TEXT = (
    "You are a wellness analyst. You have the following user responses to health-related questions:\n"
    "{responses}\n\n"
    "You also have an internal analysis report:\n"
    "{internal_report}\n\n"
    "From these inputs, determine a 'problem severity percentage' for the user in the following areas: "
    "stress_management, low_therapy, balanced_weight, restless_night, lack_of_motivation, gut_health, anxiety, burnout. "
    "Return your answer in JSON format with keys: stress_management, low_therapy, balanced_weight, restless_night, "
    "lack_of_motivation, gut_health, anxiety, burnout.\n"
    "Ensure severity percentages are numbers from 0 to 100.\n\n"
    "JSON Output:"
)

problem_prompt_template = PromptTemplate(
    template=_PROBLEM_TEMPLATE_TEXT,
    input_variables=["responses", "internal_report"],
)

# Chain that fills the prompt above and sends it to chat_model (from models).
problem_chain = LLMChain(prompt=problem_prompt_template, llm=chat_model)
# The eight wellness themes the LLM is asked to score (same keys as the prompt
# text). Every dict returned by analyze_problems_with_chain has exactly these keys.
_PROBLEM_KEYS = (
    "stress_management",
    "low_therapy",
    "balanced_weight",
    "restless_night",
    "lack_of_motivation",
    "gut_health",
    "anxiety",
    "burnout",
)


def _extract_json_object(raw_text: str) -> dict:
    """Parse the JSON object spanning the first '{' to the last '}' in *raw_text*.

    LLM replies often wrap the JSON in prose or code fences, so only the
    brace-delimited span is decoded.

    Raises:
        ValueError: if no brace-delimited span exists or it is not valid JSON
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    start_idx = raw_text.find("{")
    end_idx = raw_text.rfind("}")
    if start_idx == -1 or end_idx < start_idx:
        raise ValueError("no JSON object found in LLM output")
    # A span that starts with '{' and ends with '}' decodes to a dict or fails.
    return json.loads(raw_text[start_idx:end_idx + 1])


def _coerce_percentage(value: object) -> float:
    """Convert one LLM-reported severity to a float clamped to [0.0, 100.0].

    Accepts ints, floats and numeric strings (an optional trailing '%' is
    ignored). Booleans, None, nested objects, non-finite numbers and
    unparsable text yield 0.0, so one bad value does not discard the others.
    """
    if isinstance(value, bool):
        # bool is an int subclass; True/False is not a meaningful severity.
        return 0.0
    if isinstance(value, str):
        value = value.strip().rstrip("%").strip()
    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError):
        return 0.0
    # json.loads accepts NaN/Infinity; min/max would mis-handle NaN.
    if not math.isfinite(number):
        return 0.0
    return min(100.0, max(0.0, number))


def analyze_problems_with_chain(responses: Dict[str, str], internal_report: str) -> Dict[str, float]:
    """Score the user's eight wellness themes using ``problem_chain``.

    Each question/answer pair becomes a "question: answer" line. These lines
    and *internal_report* are sent through ``problem_chain``. The reply's
    JSON object is then read key by key.

    Args:
        responses: Mapping of question text to the user's answer.
        internal_report: Free-text analysis report included in the prompt.

    Returns:
        A dict with exactly the keys in ``_PROBLEM_KEYS``, each a float in
        [0.0, 100.0]. A theme the LLM omitted, or scored with a non-numeric
        value, is 0.0. Keys the LLM added beyond the eight themes are dropped.
        If no JSON object can be parsed, every theme is 0.0 and an error is
        logged.

    Raises:
        Anything raised by ``problem_chain.run``; LLM call failures are not
        caught here.
    """
    responses_str = "\n".join(f"{q}: {a}" for q, a in responses.items())
    raw_text = problem_chain.run(responses=responses_str, internal_report=internal_report)
    try:
        problems = _extract_json_object(raw_text)
    except ValueError as e:
        logger.error("Error parsing problem percentages from LLM: %s", e)
        return {key: 0.0 for key in _PROBLEM_KEYS}
    # Coerce each theme on its own so one malformed value cannot zero the rest.
    return {key: _coerce_percentage(problems.get(key, 0.0)) for key in _PROBLEM_KEYS}
|