Spaces:
Runtime error
Runtime error
File size: 12,850 Bytes
ea1e6bd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 |
"""
Utility functions for query processing and rewriting.
"""
import logging
import os
import time

from openai import OpenAI

from prompt_template import (
    Prompt_template_translation,
    Prompt_template_relevance,
    Prompt_template_autism_confidence,
    Prompt_template_autism_rewriter,
    Prompt_template_answer_autism_relevance
)
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize the OpenAI-compatible client pointed at the DeepInfra endpoint.
# The key is read from the DEEPINFRA_API_KEY environment variable when that is
# set to a non-empty value, so a deployment can supply its own credential
# without editing this file.
# SECURITY NOTE(review): the literal below is a credential committed to source
# control. It is kept only as a backward-compatible fallback; rotate it and
# drop the fallback once the environment variable is configured everywhere.
DEEPINFRA_API_KEY = (
    os.environ.get("DEEPINFRA_API_KEY") or "285LUJulGIprqT6hcPhiXtcrphU04FG4"
)
openai = OpenAI(
    api_key=DEEPINFRA_API_KEY,
    base_url="https://api.deepinfra.com/v1/openai",
)
def call_llm(model: str, messages: list[dict], temperature: float = 0.0, timeout: int = 30, **kwargs) -> str:
    """Send a chat-completion request and return the stripped reply text.

    Extra keyword arguments are forwarded unchanged to
    ``openai.chat.completions.create``.

    This function does not raise on request failure. Any exception from the
    request (or from reading the reply) is logged and replaced by a sentinel:

    * if the text ``"translation"`` occurs anywhere in ``str(messages)``
      (case-insensitive): whatever follows the last ``"Query: "`` in the
      first message's content, or ``"Error"`` when that marker is absent;
    * otherwise ``"0"``, which callers read as "not related".
    """
    try:
        logger.info(f"Making API call to {model} with timeout {timeout}s")
        began = time.time()
        completion = openai.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
            timeout=timeout,
            **kwargs
        )
        logger.info(f"API call completed in {time.time() - began:.2f}s")
        reply = completion.choices[0].message.content
        return reply.strip()
    except Exception as e:
        logger.error(f"API call failed: {e}")

        # Non-translation requests: assume "not related".
        if "translation" not in str(messages).lower():
            return "0"

        # Translation requests: hand back the untranslated query if the
        # prompt exposes it after the "Query: " marker.
        marker = "Query: "
        first_content = messages[0]["content"]
        if marker in first_content:
            return first_content.split(marker)[-1]
        return "Error"
# Relevance tiers, highest threshold first:
# (minimum score, category, action, reasoning).
# The final tier has a floor of 0 so every clamped score matches some tier.
_AUTISM_RELEVANCE_TIERS = (
    (85, "directly_autism_related", "accept_as_is",
     "Directly mentions autism or autism-specific topics"),
    (70, "highly_autism_relevant", "accept_as_is",
     "Core autism symptoms or characteristics"),
    (55, "significantly_autism_relevant", "rewrite_for_autism",
     "Common comorbidity or autism-related issue"),
    (40, "moderately_autism_relevant", "rewrite_for_autism",
     "Broader developmental or family concern related to autism"),
    (25, "somewhat_autism_relevant", "conditional_rewrite",
     "General topic with potential autism applications"),
    (0, "not_autism_relevant", "reject",
     "Not related to autism or autism care"),
)


def enhanced_autism_relevance_check(query: str) -> dict:
    """Score a query's autism relevance and map the score to an action.

    The query is inserted into ``Prompt_template_autism_confidence`` and sent
    through ``call_llm``. The first run of digits in the reply is taken as
    the score and clamped to 0-100; a reply without digits scores 0.

    Returns:
        A dict with keys ``score`` (int), ``category``, ``action`` and
        ``reasoning`` (all str). ``action`` is one of ``"accept_as_is"``,
        ``"rewrite_for_autism"``, ``"conditional_rewrite"`` or ``"reject"``.
        If anything unexpected fails, the result has score 0, category
        ``"error"`` and action ``"reject"``.
    """
    import re

    try:
        logger.info(f"Enhanced autism relevance check for: '{query[:50]}...'")

        # Use the enhanced confidence prompt
        confidence_prompt = Prompt_template_autism_confidence.format(query=query)
        response = call_llm(
            model="Qwen/Qwen3-32B",
            messages=[{"role": "user", "content": confidence_prompt}],
            reasoning_effort="none",
            timeout=15
        )

        # Extract numeric score. Only the failures parsing can actually
        # produce are caught here (non-text reply, oversized digit string),
        # so KeyboardInterrupt/SystemExit are no longer swallowed.
        confidence_score = 0
        try:
            match = re.search(r'\d+', response)
            if match:
                confidence_score = max(0, min(100, int(match.group())))
        except (TypeError, ValueError):
            confidence_score = 0

        # First tier whose floor the score reaches wins.
        category, action, reasoning = next(
            (tier_category, tier_action, tier_reasoning)
            for floor, tier_category, tier_action, tier_reasoning
            in _AUTISM_RELEVANCE_TIERS
            if confidence_score >= floor
        )

        result = {
            "score": confidence_score,
            "category": category,
            "action": action,
            "reasoning": reasoning
        }
        logger.info(f"Enhanced relevance result: {result}")
        return result
    except Exception as e:
        logger.error(f"Error in enhanced_autism_relevance_check: {e}")
        return {
            "score": 0,
            "category": "error",
            "action": "reject",
            "reasoning": "Error during processing"
        }
def check_autism_confidence(query: str) -> int:
    """Return the autism-relevance confidence score (0-100) for a query.

    The query is inserted into ``Prompt_template_autism_confidence`` and sent
    through ``call_llm``. The first run of digits in the reply is taken as
    the score and clamped to 0-100. A reply without digits, an unparsable
    reply, or any other failure yields 0.
    """
    import re

    try:
        logger.info(f"Checking autism confidence for query: '{query[:50]}...'")
        confidence_prompt = Prompt_template_autism_confidence.format(query=query)
        response = call_llm(
            model="Qwen/Qwen3-32B",
            messages=[{"role": "user", "content": confidence_prompt}],
            reasoning_effort="none",
            timeout=15
        )

        # Extract numeric score from response. Only parsing failures are
        # caught (non-text reply, oversized digit string), so
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        confidence_score = 0
        try:
            match = re.search(r'\d+', response)
            if match:
                # Ensure it's within valid range
                confidence_score = max(0, min(100, int(match.group())))
            else:
                logger.warning(f"No numeric score found in response: {response}")
        except (TypeError, ValueError):
            logger.error(f"Failed to parse confidence score from: {response}")
            confidence_score = 0

        logger.info(f"Autism confidence score: {confidence_score}")
        return confidence_score
    except Exception as e:
        logger.error(f"Error in check_autism_confidence: {e}")
        return 0
def rewrite_query_for_autism(query: str) -> str:
    """Rewrite a query so that it is explicitly about autism.

    The query is inserted into ``Prompt_template_autism_rewriter`` and sent
    through ``call_llm``. If rewriting fails, the generic question
    ``"How does autism relate to <lower-cased query>?"`` is returned instead.
    """
    try:
        logger.info(f"Rewriting query for autism: '{query[:50]}...'")
        rewrite_prompt = Prompt_template_autism_rewriter.format(query=query)
        rewritten_query = call_llm(
            model="Qwen/Qwen3-32B",
            messages=[{"role": "user", "content": rewrite_prompt}],
            reasoning_effort="none",
            timeout=15
        ).strip()

        # call_llm signals failure through sentinel strings, not exceptions:
        # "Error" for translation-style prompts and "0" for everything else.
        # Both (and an empty reply) must trigger the fallback, otherwise a
        # failed call would return the literal "0" as the rewritten query.
        if rewritten_query in ("Error", "0") or not rewritten_query:
            logger.warning("Rewriting failed, using fallback")
            rewritten_query = f"How does autism relate to {query.lower()}?"

        logger.info(f"Query rewritten to: '{rewritten_query[:50]}...'")
        return rewritten_query
    except Exception as e:
        logger.error(f"Error in rewrite_query_for_autism: {e}")
        return f"How does autism relate to {query.lower()}?"
def check_answer_autism_relevance(answer: str) -> int:
    """Return how strongly an answer relates to autism, as a 0-100 score.

    Used for document-based queries to filter non-autism answers. The answer
    is inserted into ``Prompt_template_answer_autism_relevance`` and sent
    through ``call_llm``. The first run of digits in the reply is taken as
    the score and clamped to 0-100. A reply without digits, an unparsable
    reply, or any other failure yields 0.
    """
    import re

    try:
        logger.info(f"Checking answer autism relevance for: '{answer[:50]}...'")
        relevance_prompt = Prompt_template_answer_autism_relevance.format(answer=answer)
        response = call_llm(
            model="Qwen/Qwen3-32B",
            messages=[{"role": "user", "content": relevance_prompt}],
            reasoning_effort="none",
            timeout=15
        )

        # Extract numeric score from response. Only parsing failures are
        # caught (non-text reply, oversized digit string), so
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        relevance_score = 0
        try:
            match = re.search(r'\d+', response)
            if match:
                relevance_score = max(0, min(100, int(match.group())))
            else:
                logger.warning(f"No numeric score found in response: {response}")
        except (TypeError, ValueError):
            logger.error(f"Failed to parse relevance score from: {response}")
            relevance_score = 0

        logger.info(f"Answer autism relevance score: {relevance_score}")
        return relevance_score
    except Exception as e:
        logger.error(f"Error in check_answer_autism_relevance: {e}")
        return 0
def process_query_for_rewrite(query: str) -> tuple[str, bool, str]:
    """Translate/correct a query, score its autism relevance, and act on it.

    Decision table (score comes from ``enhanced_autism_relevance_check``):

    1. Score 85-100 -> directly autism-related, use as-is
    2. Score 70-84  -> highly autism-relevant (core symptoms), use as-is
    3. Score 55-69  -> significantly relevant (comorbidities), rewrite
    4. Score 40-54  -> moderately relevant, rewrite
    5. Score 25-39  -> somewhat relevant, rewritten automatically
    6. Score 0-24   -> not autism-related, reject

    Returns:
        ``(processed_query, is_autism_related, rewritten_query_if_needed)``.
        The third element is always ``""`` in the current implementation;
        a rewritten query is returned as the first element. On any
        unexpected error the original query is returned as not
        autism-related.
    """
    try:
        logger.info(f"Processing query with enhanced confidence logic: '{query[:50]}...'")
        start_time = time.time()

        # Step 1: Translate and correct the query
        logger.info("Step 1: Translating/correcting query")
        corrected_query = call_llm(
            model="Qwen/Qwen3-32B",
            messages=[{"role": "user", "content": Prompt_template_translation.format(query=query)}],
            reasoning_effort="none",
            timeout=15
        )
        # "Error" is call_llm's failure sentinel for translation prompts.
        # An empty reply is equally unusable, so it is treated the same way.
        if corrected_query == "Error" or not corrected_query.strip():
            logger.warning("Translation failed, using original query")
            corrected_query = query

        # Step 2: Get enhanced autism relevance analysis
        logger.info("Step 2: Enhanced autism relevance checking")
        relevance_result = enhanced_autism_relevance_check(corrected_query)
        confidence_score = relevance_result["score"]
        action = relevance_result["action"]
        reasoning = relevance_result["reasoning"]
        logger.info(f"Relevance analysis: {confidence_score}% - {reasoning}")

        # Step 3: Take action based on enhanced analysis. Each branch only
        # builds the result; the single return below keeps the timing log
        # reachable (it used to sit after the returns and never ran).
        if action == "accept_as_is":
            logger.info(f"High relevance ({confidence_score}%) - accepting as-is: {reasoning}")
            result = (corrected_query, True, "")
        elif action == "rewrite_for_autism":
            logger.info(f"Moderate relevance ({confidence_score}%) - rewriting for autism: {reasoning}")
            result = (rewrite_query_for_autism(corrected_query), True, "")
        elif action == "conditional_rewrite":
            # For somewhat relevant queries, automatically rewrite
            # (could be enhanced with user confirmation)
            logger.info(f"Low-moderate relevance ({confidence_score}%) - conditionally rewriting: {reasoning}")
            result = (rewrite_query_for_autism(corrected_query), True, "")
        else:  # action == "reject"
            logger.info(f"Low relevance ({confidence_score}%) - rejecting: {reasoning}")
            result = (corrected_query, False, "")

        elapsed = time.time() - start_time
        logger.info(f"Enhanced query processing completed in {elapsed:.2f}s")
        return result
    except Exception as e:
        logger.error(f"Error in process_query_for_rewrite: {e}")
        # Fallback: return original query as not autism-related
        return query, False, ""
def get_non_autism_response() -> str:
    """Build the friendly reply shown when a query is not about autism."""
    # First paragraph: introduce Wisal and explain the limitation.
    introduction = (
        "Hi there! I appreciate you reaching out to me. I'm Wisal, and I specialize specifically in autism and Autism Spectrum Disorders. "
        "I noticed your question isn't quite related to autism topics. I'd love to help you, but I'm most effective when answering "
        "questions about autism, ASD, autism support strategies, therapies, or related concerns."
    )
    # Second paragraph: invite an autism-related follow-up question.
    invitation = "Could you try asking me something about autism instead? I'm here and ready to help with any autism-related questions you might have! 😊"
    return "\n\n".join((introduction, invitation))
def get_non_autism_answer_response() -> str:
    """Return the reply shown when a document-based answer is not autism-related.

    The message is a single paragraph: an apology, a reminder that Wisal is
    an autism specialist, and an invitation to ask an autism-specific
    question.
    """
    # The stray trailing "|" that followed this expression was a syntax
    # error and has been removed.
    return ("I'm sorry, but the information I found in the document doesn't seem to be related to autism or Autism Spectrum Disorders. "
            "Since I'm Wisal, your autism specialist, I want to make sure I'm providing you with relevant, autism-focused information. "
            "Could you try asking a question that's more specifically about autism? I'm here to help with any autism-related topics! 😊")