Spaces:
Running
Running
Update config_provider.py
Browse files- config_provider.py +224 -16
config_provider.py
CHANGED
@@ -3,14 +3,14 @@ Flare – ConfigProvider (with Provider Abstraction and Multi-language Support)
|
|
3 |
"""
|
4 |
|
5 |
from __future__ import annotations
|
6 |
-
import json, os
|
|
|
7 |
from pathlib import Path
|
8 |
-
from typing import Any, Dict, List, Optional
|
9 |
from datetime import datetime
|
10 |
import commentjson
|
11 |
-
|
12 |
-
from pydantic import BaseModel, Field, HttpUrl, ValidationError
|
13 |
from utils import log
|
|
|
14 |
from encryption_utils import decrypt
|
15 |
|
16 |
# ===================== New Provider Classes =====================
|
@@ -41,12 +41,47 @@ class LocalizedExample(BaseModel):
|
|
41 |
locale_code: str
|
42 |
example: str
|
43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
# ===================== Global Configuration =====================
|
45 |
class GlobalConfig(BaseModel):
|
46 |
# Provider settings (replaces work_mode, cloud_token, spark_endpoint)
|
47 |
llm_provider: ProviderSettings
|
48 |
-
tts_provider: ProviderSettings = ProviderSettings(name="no_tts")
|
49 |
-
stt_provider: ProviderSettings = ProviderSettings(name="no_stt")
|
50 |
|
51 |
# Available providers
|
52 |
providers: List[ProviderConfig] = []
|
@@ -75,7 +110,9 @@ class GlobalConfig(BaseModel):
|
|
75 |
}
|
76 |
provider = provider_map.get(provider_type)
|
77 |
if provider and provider.api_key:
|
78 |
-
|
|
|
|
|
79 |
return None
|
80 |
|
81 |
# Backward compatibility helpers
|
@@ -85,14 +122,23 @@ class GlobalConfig(BaseModel):
|
|
85 |
|
86 |
def is_gpt_mode(self) -> bool:
|
87 |
"""Check if using GPT provider"""
|
88 |
-
return self.llm_provider.name.startswith("gpt4o")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
|
90 |
-
#
|
91 |
class UserConfig(BaseModel):
|
92 |
username: str
|
93 |
password_hash: str
|
94 |
salt: str
|
95 |
|
|
|
96 |
class RetryConfig(BaseModel):
|
97 |
retry_count: int = Field(3, alias="max_attempts")
|
98 |
backoff_seconds: int = 2
|
@@ -102,6 +148,7 @@ class ProxyConfig(BaseModel):
|
|
102 |
enabled: bool = True
|
103 |
url: HttpUrl
|
104 |
|
|
|
105 |
class APIAuthConfig(BaseModel):
|
106 |
enabled: bool = False
|
107 |
token_endpoint: Optional[HttpUrl] = None
|
@@ -122,7 +169,7 @@ class APIConfig(BaseModel):
|
|
122 |
body_template: Dict[str, Any] = {}
|
123 |
timeout_seconds: int = 10
|
124 |
retry: RetryConfig = RetryConfig()
|
125 |
-
proxy: Optional[str
|
126 |
auth: Optional[APIAuthConfig] = None
|
127 |
response_prompt: Optional[str] = None
|
128 |
response_mappings: List[Dict[str, Any]] = []
|
@@ -136,7 +183,7 @@ class APIConfig(BaseModel):
|
|
136 |
extra = "allow"
|
137 |
populate_by_name = True
|
138 |
|
139 |
-
#
|
140 |
class ParameterConfig(BaseModel):
|
141 |
name: str
|
142 |
caption: List[LocalizedCaption] = [] # Multi-language captions
|
@@ -207,7 +254,7 @@ class IntentConfig(BaseModel):
|
|
207 |
# Return all examples if no locale match
|
208 |
return [e.example for e in self.examples]
|
209 |
|
210 |
-
#
|
211 |
class LLMConfig(BaseModel):
|
212 |
repo_id: str
|
213 |
generation_config: Dict[str, Any] = {}
|
@@ -259,7 +306,7 @@ class ProjectConfig(BaseModel):
|
|
259 |
class Config:
|
260 |
extra = "allow"
|
261 |
|
262 |
-
#
|
263 |
class ActivityLogEntry(BaseModel):
|
264 |
timestamp: str
|
265 |
username: str
|
@@ -269,7 +316,7 @@ class ActivityLogEntry(BaseModel):
|
|
269 |
entity_name: Optional[str] = None
|
270 |
details: Optional[str] = None
|
271 |
|
272 |
-
#
|
273 |
class ServiceConfig(BaseModel):
|
274 |
global_config: GlobalConfig = Field(..., alias="config")
|
275 |
projects: List[ProjectConfig]
|
@@ -285,7 +332,7 @@ class ServiceConfig(BaseModel):
|
|
285 |
_api_by_name: Dict[str, APIConfig] = {}
|
286 |
|
287 |
def build_index(self):
|
288 |
-
self._api_by_name = {a.name: a for a in self.apis}
|
289 |
|
290 |
def get_api(self, name: str) -> Optional[APIConfig]:
|
291 |
return self._api_by_name.get(name)
|
@@ -325,7 +372,7 @@ class ServiceConfig(BaseModel):
|
|
325 |
|
326 |
log("✅ Configuration saved to service_config.jsonc")
|
327 |
|
328 |
-
#
|
329 |
class ConfigProvider:
|
330 |
_instance: Optional[ServiceConfig] = None
|
331 |
_CONFIG_PATH = Path(__file__).parent / "service_config.jsonc"
|
@@ -371,6 +418,10 @@ class ConfigProvider:
|
|
371 |
if 'config' not in config_data:
|
372 |
config_data['config'] = {}
|
373 |
|
|
|
|
|
|
|
|
|
374 |
# Parse API configs specially
|
375 |
if 'apis' in config_data:
|
376 |
for api in config_data['apis']:
|
@@ -404,6 +455,163 @@ class ConfigProvider:
|
|
404 |
log(f"❌ Error loading config: {e}")
|
405 |
raise
|
406 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
407 |
@classmethod
|
408 |
def _check_environment_setup(cls):
|
409 |
"""Check if environment is properly configured"""
|
|
|
3 |
"""
|
4 |
|
5 |
from __future__ import annotations
|
6 |
+
import json, os
|
7 |
+
import threading
|
8 |
from pathlib import Path
|
9 |
+
from typing import Any, Dict, List, Optional, Union
|
10 |
from datetime import datetime
|
11 |
import commentjson
|
|
|
|
|
12 |
from utils import log
|
13 |
+
from pydantic import BaseModel, Field, HttpUrl, ValidationError, field_validator, validator
|
14 |
from encryption_utils import decrypt
|
15 |
|
16 |
# ===================== New Provider Classes =====================
|
|
|
41 |
locale_code: str
|
42 |
example: str
|
43 |
|
44 |
+
# ---------------- Parameter Collection Config ---------
|
45 |
+
class ParameterCollectionConfig(BaseModel):
|
46 |
+
"""Configuration for smart parameter collection"""
|
47 |
+
max_params_per_question: int = Field(default=2, ge=1, le=5)
|
48 |
+
smart_grouping: bool = Field(default=True)
|
49 |
+
retry_unanswered: bool = Field(default=True)
|
50 |
+
collection_prompt: str = Field(default="""
|
51 |
+
You are a helpful assistant collecting information from the user.
|
52 |
+
|
53 |
+
Conversation context:
|
54 |
+
{{conversation_history}}
|
55 |
+
|
56 |
+
Intent: {{intent_name}} - {{intent_caption}}
|
57 |
+
|
58 |
+
Already collected:
|
59 |
+
{{collected_params}}
|
60 |
+
|
61 |
+
Still needed:
|
62 |
+
{{missing_params}}
|
63 |
+
|
64 |
+
Previously asked but not answered:
|
65 |
+
{{unanswered_params}}
|
66 |
+
|
67 |
+
Rules:
|
68 |
+
1. Ask for maximum {{max_params}} parameters in one question
|
69 |
+
2. Group parameters that naturally go together (like from/to cities, dates)
|
70 |
+
3. If some parameters were asked before but not answered, include them again
|
71 |
+
4. Be natural and conversational in {{project_language}}
|
72 |
+
5. Use context from the conversation to make the question flow naturally
|
73 |
+
|
74 |
+
Generate ONLY the question, nothing else.""")
|
75 |
+
|
76 |
+
class Config:
|
77 |
+
extra = "allow"
|
78 |
+
|
79 |
# ===================== Global Configuration =====================
|
80 |
class GlobalConfig(BaseModel):
|
81 |
# Provider settings (replaces work_mode, cloud_token, spark_endpoint)
|
82 |
llm_provider: ProviderSettings
|
83 |
+
tts_provider: ProviderSettings = Field(default_factory=lambda: ProviderSettings(name="no_tts"))
|
84 |
+
stt_provider: ProviderSettings = Field(default_factory=lambda: ProviderSettings(name="no_stt"))
|
85 |
|
86 |
# Available providers
|
87 |
providers: List[ProviderConfig] = []
|
|
|
110 |
}
|
111 |
provider = provider_map.get(provider_type)
|
112 |
if provider and provider.api_key:
|
113 |
+
if provider.api_key.startswith("enc:"):
|
114 |
+
return decrypt(provider.api_key)
|
115 |
+
return provider.api_key
|
116 |
return None
|
117 |
|
118 |
# Backward compatibility helpers
|
|
|
122 |
|
123 |
def is_gpt_mode(self) -> bool:
|
124 |
"""Check if using GPT provider"""
|
125 |
+
return self.llm_provider.name.startswith("gpt4o") if self.llm_provider else False
|
126 |
+
|
127 |
+
def get_gpt_model(self) -> str:
|
128 |
+
"""Get the GPT model name for OpenAI API"""
|
129 |
+
if self.llm_provider.name == "gpt4o":
|
130 |
+
return "gpt-4o"
|
131 |
+
elif self.llm_provider.name == "gpt4o-mini":
|
132 |
+
return "gpt-4o-mini"
|
133 |
+
return None
|
134 |
|
135 |
+
# ---------------- Global -----------------
|
136 |
class UserConfig(BaseModel):
|
137 |
username: str
|
138 |
password_hash: str
|
139 |
salt: str
|
140 |
|
141 |
+
# ---------------- Retry / Proxy ----------
|
142 |
class RetryConfig(BaseModel):
|
143 |
retry_count: int = Field(3, alias="max_attempts")
|
144 |
backoff_seconds: int = 2
|
|
|
148 |
enabled: bool = True
|
149 |
url: HttpUrl
|
150 |
|
151 |
+
# ---------------- API & Auth -------------
|
152 |
class APIAuthConfig(BaseModel):
|
153 |
enabled: bool = False
|
154 |
token_endpoint: Optional[HttpUrl] = None
|
|
|
169 |
body_template: Dict[str, Any] = {}
|
170 |
timeout_seconds: int = 10
|
171 |
retry: RetryConfig = RetryConfig()
|
172 |
+
proxy: Optional[Union[str, ProxyConfig]] = None
|
173 |
auth: Optional[APIAuthConfig] = None
|
174 |
response_prompt: Optional[str] = None
|
175 |
response_mappings: List[Dict[str, Any]] = []
|
|
|
183 |
extra = "allow"
|
184 |
populate_by_name = True
|
185 |
|
186 |
+
# ---------------- Intent / Param ---------
|
187 |
class ParameterConfig(BaseModel):
|
188 |
name: str
|
189 |
caption: List[LocalizedCaption] = [] # Multi-language captions
|
|
|
254 |
# Return all examples if no locale match
|
255 |
return [e.example for e in self.examples]
|
256 |
|
257 |
+
# ---------------- Version / Project ------
|
258 |
class LLMConfig(BaseModel):
|
259 |
repo_id: str
|
260 |
generation_config: Dict[str, Any] = {}
|
|
|
306 |
class Config:
|
307 |
extra = "allow"
|
308 |
|
309 |
+
# ---------------- Activity Log -----------
|
310 |
class ActivityLogEntry(BaseModel):
|
311 |
timestamp: str
|
312 |
username: str
|
|
|
316 |
entity_name: Optional[str] = None
|
317 |
details: Optional[str] = None
|
318 |
|
319 |
+
# ---------------- Service Config ---------
|
320 |
class ServiceConfig(BaseModel):
|
321 |
global_config: GlobalConfig = Field(..., alias="config")
|
322 |
projects: List[ProjectConfig]
|
|
|
332 |
_api_by_name: Dict[str, APIConfig] = {}
|
333 |
|
334 |
def build_index(self):
|
335 |
+
self._api_by_name = {a.name: a for a in self.apis if not a.deleted}
|
336 |
|
337 |
def get_api(self, name: str) -> Optional[APIConfig]:
|
338 |
return self._api_by_name.get(name)
|
|
|
372 |
|
373 |
log("✅ Configuration saved to service_config.jsonc")
|
374 |
|
375 |
+
# ---------------- Provider Singleton -----
|
376 |
class ConfigProvider:
|
377 |
_instance: Optional[ServiceConfig] = None
|
378 |
_CONFIG_PATH = Path(__file__).parent / "service_config.jsonc"
|
|
|
418 |
if 'config' not in config_data:
|
419 |
config_data['config'] = {}
|
420 |
|
421 |
+
# Handle backward compatibility - convert old format to new
|
422 |
+
if 'work_mode' in config_data.get('config', {}):
|
423 |
+
cls._migrate_old_config(config_data)
|
424 |
+
|
425 |
# Parse API configs specially
|
426 |
if 'apis' in config_data:
|
427 |
for api in config_data['apis']:
|
|
|
455 |
log(f"❌ Error loading config: {e}")
|
456 |
raise
|
457 |
|
458 |
+
@classmethod
|
459 |
+
def _migrate_old_config(cls, config_data: dict):
|
460 |
+
"""Migrate old config format to new provider-based format"""
|
461 |
+
log("🔄 Migrating old config format to new provider format...")
|
462 |
+
|
463 |
+
old_config = config_data.get('config', {})
|
464 |
+
|
465 |
+
# Create default providers if not exists
|
466 |
+
if 'providers' not in old_config:
|
467 |
+
old_config['providers'] = [
|
468 |
+
{
|
469 |
+
"type": "llm",
|
470 |
+
"name": "spark",
|
471 |
+
"display_name": "Spark (HuggingFace)",
|
472 |
+
"requires_endpoint": True,
|
473 |
+
"requires_api_key": True,
|
474 |
+
"requires_repo_info": True
|
475 |
+
},
|
476 |
+
{
|
477 |
+
"type": "llm",
|
478 |
+
"name": "gpt4o",
|
479 |
+
"display_name": "GPT-4o",
|
480 |
+
"requires_endpoint": False,
|
481 |
+
"requires_api_key": True,
|
482 |
+
"requires_repo_info": False
|
483 |
+
},
|
484 |
+
{
|
485 |
+
"type": "llm",
|
486 |
+
"name": "gpt4o-mini",
|
487 |
+
"display_name": "GPT-4o Mini",
|
488 |
+
"requires_endpoint": False,
|
489 |
+
"requires_api_key": True,
|
490 |
+
"requires_repo_info": False
|
491 |
+
},
|
492 |
+
{
|
493 |
+
"type": "tts",
|
494 |
+
"name": "elevenlabs",
|
495 |
+
"display_name": "ElevenLabs",
|
496 |
+
"requires_endpoint": False,
|
497 |
+
"requires_api_key": True
|
498 |
+
},
|
499 |
+
{
|
500 |
+
"type": "stt",
|
501 |
+
"name": "google",
|
502 |
+
"display_name": "Google Speech-to-Text",
|
503 |
+
"requires_endpoint": False,
|
504 |
+
"requires_api_key": True
|
505 |
+
}
|
506 |
+
]
|
507 |
+
|
508 |
+
# Migrate LLM provider
|
509 |
+
work_mode = old_config.get('work_mode', 'hfcloud')
|
510 |
+
if work_mode in ['gpt4o', 'gpt4o-mini']:
|
511 |
+
provider_name = work_mode
|
512 |
+
api_key = old_config.get('cloud_token', '')
|
513 |
+
endpoint = None
|
514 |
+
else:
|
515 |
+
provider_name = 'spark'
|
516 |
+
api_key = old_config.get('cloud_token', '')
|
517 |
+
endpoint = old_config.get('spark_endpoint', '')
|
518 |
+
|
519 |
+
old_config['llm_provider'] = {
|
520 |
+
"name": provider_name,
|
521 |
+
"api_key": api_key,
|
522 |
+
"endpoint": endpoint,
|
523 |
+
"settings": {
|
524 |
+
"internal_prompt": old_config.get('internal_prompt', ''),
|
525 |
+
"parameter_collection_config": old_config.get('parameter_collection_config', {
|
526 |
+
"max_params_per_question": 2,
|
527 |
+
"smart_grouping": True,
|
528 |
+
"retry_unanswered": True,
|
529 |
+
"collection_prompt": ParameterCollectionConfig().collection_prompt
|
530 |
+
})
|
531 |
+
}
|
532 |
+
}
|
533 |
+
|
534 |
+
# Migrate TTS provider
|
535 |
+
tts_engine = old_config.get('tts_engine', 'no_tts')
|
536 |
+
old_config['tts_provider'] = {
|
537 |
+
"name": tts_engine,
|
538 |
+
"api_key": old_config.get('tts_engine_api_key', ''),
|
539 |
+
"settings": old_config.get('tts_settings', {})
|
540 |
+
}
|
541 |
+
|
542 |
+
# Migrate STT provider
|
543 |
+
stt_engine = old_config.get('stt_engine', 'no_stt')
|
544 |
+
old_config['stt_provider'] = {
|
545 |
+
"name": stt_engine,
|
546 |
+
"api_key": old_config.get('stt_engine_api_key', ''),
|
547 |
+
"settings": old_config.get('stt_settings', {})
|
548 |
+
}
|
549 |
+
|
550 |
+
# Migrate projects - language settings
|
551 |
+
for project in config_data.get('projects', []):
|
552 |
+
if 'default_language' in project:
|
553 |
+
# Map language names to locale codes
|
554 |
+
lang_to_locale = {
|
555 |
+
'Türkçe': 'tr',
|
556 |
+
'Turkish': 'tr',
|
557 |
+
'English': 'en',
|
558 |
+
'Deutsch': 'de',
|
559 |
+
'German': 'de'
|
560 |
+
}
|
561 |
+
project['default_locale'] = lang_to_locale.get(project['default_language'], 'tr')
|
562 |
+
del project['default_language']
|
563 |
+
|
564 |
+
if 'supported_languages' in project:
|
565 |
+
# Convert to locale codes
|
566 |
+
supported_locales = []
|
567 |
+
for lang in project['supported_languages']:
|
568 |
+
locale = lang_to_locale.get(lang, lang)
|
569 |
+
if locale not in supported_locales:
|
570 |
+
supported_locales.append(locale)
|
571 |
+
project['supported_locales'] = supported_locales
|
572 |
+
del project['supported_languages']
|
573 |
+
|
574 |
+
# Migrate intent examples and parameter captions
|
575 |
+
for version in project.get('versions', []):
|
576 |
+
for intent in version.get('intents', []):
|
577 |
+
# Migrate examples
|
578 |
+
if 'examples' in intent and isinstance(intent['examples'], list):
|
579 |
+
new_examples = []
|
580 |
+
for example in intent['examples']:
|
581 |
+
if isinstance(example, str):
|
582 |
+
# Old format - use project default locale
|
583 |
+
new_examples.append({
|
584 |
+
"locale_code": project.get('default_locale', 'tr'),
|
585 |
+
"example": example
|
586 |
+
})
|
587 |
+
elif isinstance(example, dict) and 'locale_code' in example:
|
588 |
+
# Already new format
|
589 |
+
new_examples.append(example)
|
590 |
+
intent['examples'] = new_examples
|
591 |
+
|
592 |
+
# Migrate parameter captions
|
593 |
+
for param in intent.get('parameters', []):
|
594 |
+
if 'caption' in param:
|
595 |
+
if isinstance(param['caption'], str):
|
596 |
+
# Old format - convert to multi-language
|
597 |
+
param['caption'] = [{
|
598 |
+
"locale_code": project.get('default_locale', 'tr'),
|
599 |
+
"caption": param['caption']
|
600 |
+
}]
|
601 |
+
elif isinstance(param['caption'], list) and param['caption'] and isinstance(param['caption'][0], dict):
|
602 |
+
# Already new format
|
603 |
+
pass
|
604 |
+
|
605 |
+
# Remove old fields
|
606 |
+
fields_to_remove = ['work_mode', 'cloud_token', 'spark_endpoint', 'internal_prompt',
|
607 |
+
'tts_engine', 'tts_engine_api_key', 'tts_settings',
|
608 |
+
'stt_engine', 'stt_engine_api_key', 'stt_settings',
|
609 |
+
'parameter_collection_config']
|
610 |
+
for field in fields_to_remove:
|
611 |
+
old_config.pop(field, None)
|
612 |
+
|
613 |
+
log("✅ Config migration completed")
|
614 |
+
|
615 |
@classmethod
|
616 |
def _check_environment_setup(cls):
|
617 |
"""Check if environment is properly configured"""
|