Spaces:
Sleeping
Sleeping
YanBoChen
committed on
Commit
·
9b1dc9a
1
Parent(s):
278c9ff
feat: implement flexible condition extraction and regex matching in user queries, to make fallback userprompt more robust
Browse files- src/medical_conditions.py +52 -5
- src/user_prompt.py +106 -25
src/medical_conditions.py
CHANGED
|
@@ -5,13 +5,26 @@ This module provides centralized configuration for:
|
|
| 5 |
1. Predefined medical conditions
|
| 6 |
2. Condition-to-keyword mappings
|
| 7 |
3. Fallback condition keywords
|
|
|
|
| 8 |
|
| 9 |
Author: OnCall.ai Team
|
| 10 |
Date: 2025-07-29
|
| 11 |
"""
|
| 12 |
|
| 13 |
from typing import Dict, Optional
|
|
|
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
# Comprehensive Condition-to-Keyword Mapping
|
| 16 |
CONDITION_KEYWORD_MAPPING: Dict[str, Dict[str, str]] = {
|
| 17 |
"acute myocardial infarction": {
|
|
@@ -72,7 +85,7 @@ def get_condition_keywords(specific_condition: str) -> Optional[str]:
|
|
| 72 |
|
| 73 |
def validate_condition(condition: str) -> bool:
|
| 74 |
"""
|
| 75 |
-
Check if a condition exists in our predefined mapping
|
| 76 |
|
| 77 |
Args:
|
| 78 |
condition: Medical condition to validate
|
|
@@ -80,11 +93,31 @@ def validate_condition(condition: str) -> bool:
|
|
| 80 |
Returns:
|
| 81 |
Boolean indicating condition validity
|
| 82 |
"""
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
def get_condition_details(condition: str) -> Optional[Dict[str, str]]:
|
| 86 |
"""
|
| 87 |
-
Retrieve detailed information for a specific condition
|
| 88 |
|
| 89 |
Args:
|
| 90 |
condition: Medical condition name
|
|
@@ -92,8 +125,22 @@ def get_condition_details(condition: str) -> Optional[Dict[str, str]]:
|
|
| 92 |
Returns:
|
| 93 |
Dict with emergency and treatment keywords, or None
|
| 94 |
"""
|
| 95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
for key, value in CONDITION_KEYWORD_MAPPING.items():
|
| 97 |
-
if key.lower() ==
|
| 98 |
return value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
return None
|
|
|
|
| 5 |
1. Predefined medical conditions
|
| 6 |
2. Condition-to-keyword mappings
|
| 7 |
3. Fallback condition keywords
|
| 8 |
+
4. Regular expression matching for flexible condition recognition
|
| 9 |
|
| 10 |
Author: OnCall.ai Team
|
| 11 |
Date: 2025-07-29
|
| 12 |
"""
|
| 13 |
|
| 14 |
from typing import Dict, Optional
|
| 15 |
+
import re
|
| 16 |
|
| 17 |
+
# Regular Expression Mapping for Flexible Condition Recognition
|
| 18 |
+
CONDITION_REGEX_MAPPING: Dict[str, str] = {
|
| 19 |
+
r"acute[\s_-]*coronary[\s_-]*syndrome": "acute_coronary_syndrome",
|
| 20 |
+
r"acute[\s_-]*myocardial[\s_-]*infarction": "acute myocardial infarction",
|
| 21 |
+
r"acute[\s_-]*ischemic[\s_-]*stroke": "acute_ischemic_stroke",
|
| 22 |
+
r"hemorrhagic[\s_-]*stroke": "hemorrhagic_stroke",
|
| 23 |
+
r"transient[\s_-]*ischemic[\s_-]*attack": "transient_ischemic_attack",
|
| 24 |
+
r"pulmonary[\s_-]*embolism": "pulmonary embolism",
|
| 25 |
+
# Handles variants like:
|
| 26 |
+
# "Acute Coronary Syndrome", "acute_coronary_syndrome", "acute-coronary-syndrome"
|
| 27 |
+
}
|
| 28 |
# Comprehensive Condition-to-Keyword Mapping
|
| 29 |
CONDITION_KEYWORD_MAPPING: Dict[str, Dict[str, str]] = {
|
| 30 |
"acute myocardial infarction": {
|
|
|
|
| 85 |
|
| 86 |
def validate_condition(condition: str) -> bool:
|
| 87 |
"""
|
| 88 |
+
Check if a condition exists in our predefined mapping with flexible regex matching
|
| 89 |
|
| 90 |
Args:
|
| 91 |
condition: Medical condition to validate
|
|
|
|
| 93 |
Returns:
|
| 94 |
Boolean indicating condition validity
|
| 95 |
"""
|
| 96 |
+
if not condition:
|
| 97 |
+
return False
|
| 98 |
+
|
| 99 |
+
condition_lower = condition.lower().strip()
|
| 100 |
+
|
| 101 |
+
# Level 1: Direct exact match (fastest)
|
| 102 |
+
for key in CONDITION_KEYWORD_MAPPING.keys():
|
| 103 |
+
if key.lower() == condition_lower:
|
| 104 |
+
return True
|
| 105 |
+
|
| 106 |
+
# Level 2: Regular expression matching (flexible)
|
| 107 |
+
for regex_pattern, mapped_condition in CONDITION_REGEX_MAPPING.items():
|
| 108 |
+
if re.search(regex_pattern, condition_lower, re.IGNORECASE):
|
| 109 |
+
return True
|
| 110 |
+
|
| 111 |
+
# Level 3: Partial matching for key medical terms (fallback)
|
| 112 |
+
medical_keywords = ['coronary', 'syndrome', 'stroke', 'myocardial', 'embolism', 'ischemic']
|
| 113 |
+
if any(keyword in condition_lower for keyword in medical_keywords):
|
| 114 |
+
return True
|
| 115 |
+
|
| 116 |
+
return False
|
| 117 |
|
| 118 |
def get_condition_details(condition: str) -> Optional[Dict[str, str]]:
|
| 119 |
"""
|
| 120 |
+
Retrieve detailed information for a specific condition with flexible matching
|
| 121 |
|
| 122 |
Args:
|
| 123 |
condition: Medical condition name
|
|
|
|
| 125 |
Returns:
|
| 126 |
Dict with emergency and treatment keywords, or None
|
| 127 |
"""
|
| 128 |
+
if not condition:
|
| 129 |
+
return None
|
| 130 |
+
|
| 131 |
+
condition_lower = condition.lower().strip()
|
| 132 |
+
|
| 133 |
+
# Level 1: Direct exact match
|
| 134 |
for key, value in CONDITION_KEYWORD_MAPPING.items():
|
| 135 |
+
if key.lower() == condition_lower:
|
| 136 |
return value
|
| 137 |
+
|
| 138 |
+
# Level 2: Regular expression matching
|
| 139 |
+
for regex_pattern, mapped_condition in CONDITION_REGEX_MAPPING.items():
|
| 140 |
+
if re.search(regex_pattern, condition_lower, re.IGNORECASE):
|
| 141 |
+
# Find the mapped condition in the keyword mapping
|
| 142 |
+
for key, value in CONDITION_KEYWORD_MAPPING.items():
|
| 143 |
+
if key.lower() == mapped_condition.lower():
|
| 144 |
+
return value
|
| 145 |
+
|
| 146 |
return None
|
src/user_prompt.py
CHANGED
|
@@ -22,6 +22,7 @@ import re # Added missing import for re
|
|
| 22 |
# Import our centralized medical conditions configuration
|
| 23 |
from medical_conditions import (
|
| 24 |
CONDITION_KEYWORD_MAPPING,
|
|
|
|
| 25 |
get_condition_details,
|
| 26 |
validate_condition
|
| 27 |
)
|
|
@@ -51,6 +52,48 @@ class UserPromptProcessor:
|
|
| 51 |
|
| 52 |
logger.info("UserPromptProcessor initialized")
|
| 53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
def extract_condition_keywords(self, user_query: str) -> Dict[str, str]:
|
| 55 |
"""
|
| 56 |
Extract condition keywords with multi-level fallback
|
|
@@ -61,36 +104,54 @@ class UserPromptProcessor:
|
|
| 61 |
Returns:
|
| 62 |
Dict with condition and keywords
|
| 63 |
"""
|
|
|
|
| 64 |
|
| 65 |
# Level 1: Predefined Mapping (Fast Path)
|
|
|
|
| 66 |
predefined_result = self._predefined_mapping(user_query)
|
| 67 |
if predefined_result:
|
|
|
|
| 68 |
return predefined_result
|
|
|
|
| 69 |
|
| 70 |
# Level 2: Llama3-Med42-70B Extraction (if available)
|
|
|
|
| 71 |
if self.llm_client:
|
| 72 |
llm_result = self._extract_with_llm(user_query)
|
| 73 |
if llm_result:
|
|
|
|
| 74 |
return llm_result
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
# Level 3: Semantic Search Fallback
|
|
|
|
| 77 |
semantic_result = self._semantic_search_fallback(user_query)
|
| 78 |
if semantic_result:
|
|
|
|
| 79 |
return semantic_result
|
|
|
|
| 80 |
|
| 81 |
# Level 4: Medical Query Validation
|
|
|
|
| 82 |
# Only validate if previous levels failed - speed optimization
|
| 83 |
validation_result = self.validate_medical_query(user_query)
|
| 84 |
if validation_result: # If validation fails (returns non-None)
|
|
|
|
| 85 |
return validation_result
|
|
|
|
| 86 |
|
| 87 |
# Level 5: Generic Medical Search (after validation passes)
|
|
|
|
| 88 |
generic_result = self._generic_medical_search(user_query)
|
| 89 |
if generic_result:
|
|
|
|
| 90 |
return generic_result
|
|
|
|
| 91 |
|
| 92 |
# No match found
|
| 93 |
-
|
| 94 |
return {
|
| 95 |
'condition': '',
|
| 96 |
'emergency_keywords': '',
|
|
@@ -99,7 +160,7 @@ class UserPromptProcessor:
|
|
| 99 |
|
| 100 |
def _predefined_mapping(self, user_query: str) -> Optional[Dict[str, str]]:
|
| 101 |
"""
|
| 102 |
-
Fast predefined condition mapping
|
| 103 |
|
| 104 |
Args:
|
| 105 |
user_query: User's medical query
|
|
@@ -107,15 +168,18 @@ class UserPromptProcessor:
|
|
| 107 |
Returns:
|
| 108 |
Mapped condition keywords or None
|
| 109 |
"""
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
| 115 |
return {
|
| 116 |
'condition': condition,
|
| 117 |
-
'emergency_keywords':
|
| 118 |
-
'treatment_keywords':
|
| 119 |
}
|
| 120 |
|
| 121 |
return None
|
|
@@ -140,16 +204,22 @@ class UserPromptProcessor:
|
|
| 140 |
timeout=2.0
|
| 141 |
)
|
| 142 |
|
| 143 |
-
|
|
|
|
| 144 |
|
| 145 |
-
if
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
|
| 154 |
return None
|
| 155 |
|
|
@@ -241,8 +311,7 @@ class UserPromptProcessor:
|
|
| 241 |
generic_results = self.retrieval_system.search_generic_medical_content(generic_query)
|
| 242 |
|
| 243 |
if generic_results:
|
| 244 |
-
return
|
| 245 |
-
{
|
| 246 |
'condition': 'generic medical query',
|
| 247 |
'emergency_keywords': 'medical|emergency',
|
| 248 |
'treatment_keywords': 'treatment|management',
|
|
@@ -256,7 +325,7 @@ class UserPromptProcessor:
|
|
| 256 |
|
| 257 |
def _infer_condition_from_text(self, text: str) -> Optional[str]:
|
| 258 |
"""
|
| 259 |
-
Infer medical condition from text using
|
| 260 |
|
| 261 |
Args:
|
| 262 |
text: Input medical text
|
|
@@ -264,20 +333,32 @@ class UserPromptProcessor:
|
|
| 264 |
Returns:
|
| 265 |
Inferred condition or None
|
| 266 |
"""
|
| 267 |
-
# Implement
|
| 268 |
-
# This is a placeholder and would need more sophisticated implementation
|
| 269 |
conditions = list(CONDITION_KEYWORD_MAPPING.keys())
|
| 270 |
text_embedding = self.embedding_model.encode(text)
|
| 271 |
condition_embeddings = [self.embedding_model.encode(condition) for condition in conditions]
|
| 272 |
|
|
|
|
| 273 |
similarities = [
|
| 274 |
np.dot(text_embedding, condition_emb) /
|
| 275 |
(np.linalg.norm(text_embedding) * np.linalg.norm(condition_emb))
|
| 276 |
for condition_emb in condition_embeddings
|
| 277 |
]
|
| 278 |
|
| 279 |
-
|
| 280 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 281 |
|
| 282 |
def validate_keywords(self, keywords: Dict[str, str]) -> bool:
|
| 283 |
"""
|
|
|
|
| 22 |
# Import our centralized medical conditions configuration
|
| 23 |
from medical_conditions import (
|
| 24 |
CONDITION_KEYWORD_MAPPING,
|
| 25 |
+
CONDITION_REGEX_MAPPING,
|
| 26 |
get_condition_details,
|
| 27 |
validate_condition
|
| 28 |
)
|
|
|
|
| 52 |
|
| 53 |
logger.info("UserPromptProcessor initialized")
|
| 54 |
|
| 55 |
+
def _extract_condition_from_query(self, user_query: str) -> Optional[str]:
|
| 56 |
+
"""
|
| 57 |
+
Unified condition extraction with flexible matching
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
user_query: User's medical query
|
| 61 |
+
|
| 62 |
+
Returns:
|
| 63 |
+
Standardized condition name or None
|
| 64 |
+
"""
|
| 65 |
+
if not user_query:
|
| 66 |
+
return None
|
| 67 |
+
|
| 68 |
+
query_lower = user_query.lower().strip()
|
| 69 |
+
|
| 70 |
+
# Level 1: Direct exact matching (fastest)
|
| 71 |
+
for condition in CONDITION_KEYWORD_MAPPING.keys():
|
| 72 |
+
if condition.lower() in query_lower:
|
| 73 |
+
logger.info(f"🎯 Direct match found: {condition}")
|
| 74 |
+
return condition
|
| 75 |
+
|
| 76 |
+
# Level 2: Regular expression matching (flexible)
|
| 77 |
+
for regex_pattern, mapped_condition in CONDITION_REGEX_MAPPING.items():
|
| 78 |
+
if re.search(regex_pattern, query_lower, re.IGNORECASE):
|
| 79 |
+
logger.info(f"🎯 Regex match found: {regex_pattern} → {mapped_condition}")
|
| 80 |
+
return mapped_condition
|
| 81 |
+
|
| 82 |
+
# Level 3: Partial keyword matching (fallback)
|
| 83 |
+
medical_keywords_mapping = {
|
| 84 |
+
'coronary': 'acute_coronary_syndrome',
|
| 85 |
+
'myocardial': 'acute myocardial infarction',
|
| 86 |
+
'stroke': 'acute stroke',
|
| 87 |
+
'embolism': 'pulmonary embolism'
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
for keyword, condition in medical_keywords_mapping.items():
|
| 91 |
+
if keyword in query_lower:
|
| 92 |
+
logger.info(f"🎯 Keyword match found: {keyword} → {condition}")
|
| 93 |
+
return condition
|
| 94 |
+
|
| 95 |
+
return None
|
| 96 |
+
|
| 97 |
def extract_condition_keywords(self, user_query: str) -> Dict[str, str]:
|
| 98 |
"""
|
| 99 |
Extract condition keywords with multi-level fallback
|
|
|
|
| 104 |
Returns:
|
| 105 |
Dict with condition and keywords
|
| 106 |
"""
|
| 107 |
+
logger.info(f"π Starting condition extraction for query: '{user_query}'")
|
| 108 |
|
| 109 |
# Level 1: Predefined Mapping (Fast Path)
|
| 110 |
+
logger.info("π LEVEL 1: Attempting predefined mapping...")
|
| 111 |
predefined_result = self._predefined_mapping(user_query)
|
| 112 |
if predefined_result:
|
| 113 |
+
logger.info("✅ LEVEL 1: SUCCESS - Found predefined mapping")
|
| 114 |
return predefined_result
|
| 115 |
+
logger.info("❌ LEVEL 1: FAILED - No predefined mapping found")
|
| 116 |
|
| 117 |
# Level 2: Llama3-Med42-70B Extraction (if available)
|
| 118 |
+
logger.info("π LEVEL 2: Attempting LLM extraction...")
|
| 119 |
if self.llm_client:
|
| 120 |
llm_result = self._extract_with_llm(user_query)
|
| 121 |
if llm_result:
|
| 122 |
+
logger.info("✅ LEVEL 2: SUCCESS - LLM extraction successful")
|
| 123 |
return llm_result
|
| 124 |
+
logger.info("❌ LEVEL 2: FAILED - LLM extraction failed")
|
| 125 |
+
else:
|
| 126 |
+
logger.info("⏭️ LEVEL 2: SKIPPED - No LLM client available")
|
| 127 |
|
| 128 |
# Level 3: Semantic Search Fallback
|
| 129 |
+
logger.info("π LEVEL 3: Attempting semantic search...")
|
| 130 |
semantic_result = self._semantic_search_fallback(user_query)
|
| 131 |
if semantic_result:
|
| 132 |
+
logger.info("✅ LEVEL 3: SUCCESS - Semantic search successful")
|
| 133 |
return semantic_result
|
| 134 |
+
logger.info("❌ LEVEL 3: FAILED - Semantic search failed")
|
| 135 |
|
| 136 |
# Level 4: Medical Query Validation
|
| 137 |
+
logger.info("π LEVEL 4: Validating medical query...")
|
| 138 |
# Only validate if previous levels failed - speed optimization
|
| 139 |
validation_result = self.validate_medical_query(user_query)
|
| 140 |
if validation_result: # If validation fails (returns non-None)
|
| 141 |
+
logger.info("❌ LEVEL 4: FAILED - Query identified as non-medical")
|
| 142 |
return validation_result
|
| 143 |
+
logger.info("✅ LEVEL 4: PASSED - Query validated as medical, continuing...")
|
| 144 |
|
| 145 |
# Level 5: Generic Medical Search (after validation passes)
|
| 146 |
+
logger.info("π LEVEL 5: Attempting generic medical search...")
|
| 147 |
generic_result = self._generic_medical_search(user_query)
|
| 148 |
if generic_result:
|
| 149 |
+
logger.info("✅ LEVEL 5: SUCCESS - Generic medical search successful")
|
| 150 |
return generic_result
|
| 151 |
+
logger.info("❌ LEVEL 5: FAILED - Generic medical search failed")
|
| 152 |
|
| 153 |
# No match found
|
| 154 |
+
logger.warning("🚫 ALL LEVELS FAILED - Returning empty result")
|
| 155 |
return {
|
| 156 |
'condition': '',
|
| 157 |
'emergency_keywords': '',
|
|
|
|
| 160 |
|
| 161 |
def _predefined_mapping(self, user_query: str) -> Optional[Dict[str, str]]:
|
| 162 |
"""
|
| 163 |
+
Fast predefined condition mapping using unified extraction
|
| 164 |
|
| 165 |
Args:
|
| 166 |
user_query: User's medical query
|
|
|
|
| 168 |
Returns:
|
| 169 |
Mapped condition keywords or None
|
| 170 |
"""
|
| 171 |
+
# Use unified condition extraction
|
| 172 |
+
condition = self._extract_condition_from_query(user_query)
|
| 173 |
+
|
| 174 |
+
if condition:
|
| 175 |
+
# Get condition details using the flexible matching
|
| 176 |
+
condition_details = get_condition_details(condition)
|
| 177 |
+
if condition_details:
|
| 178 |
+
logger.info(f"✅ Level 1 matched condition: {condition}")
|
| 179 |
return {
|
| 180 |
'condition': condition,
|
| 181 |
+
'emergency_keywords': condition_details['emergency'],
|
| 182 |
+
'treatment_keywords': condition_details['treatment']
|
| 183 |
}
|
| 184 |
|
| 185 |
return None
|
|
|
|
| 204 |
timeout=2.0
|
| 205 |
)
|
| 206 |
|
| 207 |
+
llm_extracted_condition = llama_response.get('extracted_condition', '')
|
| 208 |
+
logger.info(f"🤖 LLM extracted condition: {llm_extracted_condition}")
|
| 209 |
|
| 210 |
+
if llm_extracted_condition:
|
| 211 |
+
# Use unified condition extraction for validation and standardization
|
| 212 |
+
standardized_condition = self._extract_condition_from_query(llm_extracted_condition)
|
| 213 |
+
|
| 214 |
+
if standardized_condition:
|
| 215 |
+
condition_details = get_condition_details(standardized_condition)
|
| 216 |
+
if condition_details:
|
| 217 |
+
logger.info(f"✅ Level 2 standardized condition: {standardized_condition}")
|
| 218 |
+
return {
|
| 219 |
+
'condition': standardized_condition,
|
| 220 |
+
'emergency_keywords': condition_details['emergency'],
|
| 221 |
+
'treatment_keywords': condition_details['treatment']
|
| 222 |
+
}
|
| 223 |
|
| 224 |
return None
|
| 225 |
|
|
|
|
| 311 |
generic_results = self.retrieval_system.search_generic_medical_content(generic_query)
|
| 312 |
|
| 313 |
if generic_results:
|
| 314 |
+
return {
|
|
|
|
| 315 |
'condition': 'generic medical query',
|
| 316 |
'emergency_keywords': 'medical|emergency',
|
| 317 |
'treatment_keywords': 'treatment|management',
|
|
|
|
| 325 |
|
| 326 |
def _infer_condition_from_text(self, text: str) -> Optional[str]:
|
| 327 |
"""
|
| 328 |
+
Infer medical condition from text using angular distance
|
| 329 |
|
| 330 |
Args:
|
| 331 |
text: Input medical text
|
|
|
|
| 333 |
Returns:
|
| 334 |
Inferred condition or None
|
| 335 |
"""
|
| 336 |
+
# Implement condition inference using angular distance (consistent with retrieval system)
|
|
|
|
| 337 |
conditions = list(CONDITION_KEYWORD_MAPPING.keys())
|
| 338 |
text_embedding = self.embedding_model.encode(text)
|
| 339 |
condition_embeddings = [self.embedding_model.encode(condition) for condition in conditions]
|
| 340 |
|
| 341 |
+
# Calculate cosine similarities first
|
| 342 |
similarities = [
|
| 343 |
np.dot(text_embedding, condition_emb) /
|
| 344 |
(np.linalg.norm(text_embedding) * np.linalg.norm(condition_emb))
|
| 345 |
for condition_emb in condition_embeddings
|
| 346 |
]
|
| 347 |
|
| 348 |
+
# Convert to angular distances
|
| 349 |
+
angular_distances = [np.arccos(np.clip(sim, -1, 1)) for sim in similarities]
|
| 350 |
+
|
| 351 |
+
# Find minimum angular distance (most similar)
|
| 352 |
+
min_distance_index = np.argmin(angular_distances)
|
| 353 |
+
min_distance = angular_distances[min_distance_index]
|
| 354 |
+
|
| 355 |
+
# Use angular distance threshold of 1.0 (approximately 57 degrees)
|
| 356 |
+
if min_distance < 1.0:
|
| 357 |
+
logger.info(f"Condition inferred: {conditions[min_distance_index]}, angular distance: {min_distance:.3f}")
|
| 358 |
+
return conditions[min_distance_index]
|
| 359 |
+
else:
|
| 360 |
+
logger.info(f"No condition found within angular distance threshold. Min distance: {min_distance:.3f}")
|
| 361 |
+
return None
|
| 362 |
|
| 363 |
def validate_keywords(self, keywords: Dict[str, str]) -> bool:
|
| 364 |
"""
|