ciyidogan commited on
Commit
53db15d
·
verified ·
1 Parent(s): c22dfdb

Update config_provider.py

Browse files
Files changed (1) hide show
  1. config_provider.py +224 -16
config_provider.py CHANGED
@@ -3,14 +3,14 @@ Flare – ConfigProvider (with Provider Abstraction and Multi-language Support)
3
  """
4
 
5
  from __future__ import annotations
6
- import json, os, threading
 
7
  from pathlib import Path
8
- from typing import Any, Dict, List, Optional
9
  from datetime import datetime
10
  import commentjson
11
-
12
- from pydantic import BaseModel, Field, HttpUrl, ValidationError
13
  from utils import log
 
14
  from encryption_utils import decrypt
15
 
16
  # ===================== New Provider Classes =====================
@@ -41,12 +41,47 @@ class LocalizedExample(BaseModel):
41
  locale_code: str
42
  example: str
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  # ===================== Global Configuration =====================
45
  class GlobalConfig(BaseModel):
46
  # Provider settings (replaces work_mode, cloud_token, spark_endpoint)
47
  llm_provider: ProviderSettings
48
- tts_provider: ProviderSettings = ProviderSettings(name="no_tts")
49
- stt_provider: ProviderSettings = ProviderSettings(name="no_stt")
50
 
51
  # Available providers
52
  providers: List[ProviderConfig] = []
@@ -75,7 +110,9 @@ class GlobalConfig(BaseModel):
75
  }
76
  provider = provider_map.get(provider_type)
77
  if provider and provider.api_key:
78
- return decrypt(provider.api_key) if provider.api_key else None
 
 
79
  return None
80
 
81
  # Backward compatibility helpers
@@ -85,14 +122,23 @@ class GlobalConfig(BaseModel):
85
 
86
  def is_gpt_mode(self) -> bool:
87
  """Check if using GPT provider"""
88
- return self.llm_provider.name.startswith("gpt4o")
 
 
 
 
 
 
 
 
89
 
90
- # ===================== Other Config Classes =====================
91
  class UserConfig(BaseModel):
92
  username: str
93
  password_hash: str
94
  salt: str
95
 
 
96
  class RetryConfig(BaseModel):
97
  retry_count: int = Field(3, alias="max_attempts")
98
  backoff_seconds: int = 2
@@ -102,6 +148,7 @@ class ProxyConfig(BaseModel):
102
  enabled: bool = True
103
  url: HttpUrl
104
 
 
105
  class APIAuthConfig(BaseModel):
106
  enabled: bool = False
107
  token_endpoint: Optional[HttpUrl] = None
@@ -122,7 +169,7 @@ class APIConfig(BaseModel):
122
  body_template: Dict[str, Any] = {}
123
  timeout_seconds: int = 10
124
  retry: RetryConfig = RetryConfig()
125
- proxy: Optional[str | ProxyConfig] = None
126
  auth: Optional[APIAuthConfig] = None
127
  response_prompt: Optional[str] = None
128
  response_mappings: List[Dict[str, Any]] = []
@@ -136,7 +183,7 @@ class APIConfig(BaseModel):
136
  extra = "allow"
137
  populate_by_name = True
138
 
139
- # ===================== Intent / Parameter =====================
140
  class ParameterConfig(BaseModel):
141
  name: str
142
  caption: List[LocalizedCaption] = [] # Multi-language captions
@@ -207,7 +254,7 @@ class IntentConfig(BaseModel):
207
  # Return all examples if no locale match
208
  return [e.example for e in self.examples]
209
 
210
- # ===================== Version / Project =====================
211
  class LLMConfig(BaseModel):
212
  repo_id: str
213
  generation_config: Dict[str, Any] = {}
@@ -259,7 +306,7 @@ class ProjectConfig(BaseModel):
259
  class Config:
260
  extra = "allow"
261
 
262
- # ===================== Activity Log =====================
263
  class ActivityLogEntry(BaseModel):
264
  timestamp: str
265
  username: str
@@ -269,7 +316,7 @@ class ActivityLogEntry(BaseModel):
269
  entity_name: Optional[str] = None
270
  details: Optional[str] = None
271
 
272
- # ===================== Service Config =====================
273
  class ServiceConfig(BaseModel):
274
  global_config: GlobalConfig = Field(..., alias="config")
275
  projects: List[ProjectConfig]
@@ -285,7 +332,7 @@ class ServiceConfig(BaseModel):
285
  _api_by_name: Dict[str, APIConfig] = {}
286
 
287
  def build_index(self):
288
- self._api_by_name = {a.name: a for a in self.apis}
289
 
290
  def get_api(self, name: str) -> Optional[APIConfig]:
291
  return self._api_by_name.get(name)
@@ -325,7 +372,7 @@ class ServiceConfig(BaseModel):
325
 
326
  log("✅ Configuration saved to service_config.jsonc")
327
 
328
- # ===================== Provider Singleton =====================
329
  class ConfigProvider:
330
  _instance: Optional[ServiceConfig] = None
331
  _CONFIG_PATH = Path(__file__).parent / "service_config.jsonc"
@@ -371,6 +418,10 @@ class ConfigProvider:
371
  if 'config' not in config_data:
372
  config_data['config'] = {}
373
 
 
 
 
 
374
  # Parse API configs specially
375
  if 'apis' in config_data:
376
  for api in config_data['apis']:
@@ -404,6 +455,163 @@ class ConfigProvider:
404
  log(f"❌ Error loading config: {e}")
405
  raise
406
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
407
  @classmethod
408
  def _check_environment_setup(cls):
409
  """Check if environment is properly configured"""
 
3
  """
4
 
5
  from __future__ import annotations
6
+ import json, os
7
+ import threading
8
  from pathlib import Path
9
+ from typing import Any, Dict, List, Optional, Union
10
  from datetime import datetime
11
  import commentjson
 
 
12
  from utils import log
13
+ from pydantic import BaseModel, Field, HttpUrl, ValidationError, field_validator, validator
14
  from encryption_utils import decrypt
15
 
16
  # ===================== New Provider Classes =====================
 
41
  locale_code: str
42
  example: str
43
 
44
+ # ---------------- Parameter Collection Config ---------
45
+ class ParameterCollectionConfig(BaseModel):
46
+ """Configuration for smart parameter collection"""
47
+ max_params_per_question: int = Field(default=2, ge=1, le=5)
48
+ smart_grouping: bool = Field(default=True)
49
+ retry_unanswered: bool = Field(default=True)
50
+ collection_prompt: str = Field(default="""
51
+ You are a helpful assistant collecting information from the user.
52
+
53
+ Conversation context:
54
+ {{conversation_history}}
55
+
56
+ Intent: {{intent_name}} - {{intent_caption}}
57
+
58
+ Already collected:
59
+ {{collected_params}}
60
+
61
+ Still needed:
62
+ {{missing_params}}
63
+
64
+ Previously asked but not answered:
65
+ {{unanswered_params}}
66
+
67
+ Rules:
68
+ 1. Ask for maximum {{max_params}} parameters in one question
69
+ 2. Group parameters that naturally go together (like from/to cities, dates)
70
+ 3. If some parameters were asked before but not answered, include them again
71
+ 4. Be natural and conversational in {{project_language}}
72
+ 5. Use context from the conversation to make the question flow naturally
73
+
74
+ Generate ONLY the question, nothing else.""")
75
+
76
+ class Config:
77
+ extra = "allow"
78
+
79
  # ===================== Global Configuration =====================
80
  class GlobalConfig(BaseModel):
81
  # Provider settings (replaces work_mode, cloud_token, spark_endpoint)
82
  llm_provider: ProviderSettings
83
+ tts_provider: ProviderSettings = Field(default_factory=lambda: ProviderSettings(name="no_tts"))
84
+ stt_provider: ProviderSettings = Field(default_factory=lambda: ProviderSettings(name="no_stt"))
85
 
86
  # Available providers
87
  providers: List[ProviderConfig] = []
 
110
  }
111
  provider = provider_map.get(provider_type)
112
  if provider and provider.api_key:
113
+ if provider.api_key.startswith("enc:"):
114
+ return decrypt(provider.api_key)
115
+ return provider.api_key
116
  return None
117
 
118
  # Backward compatibility helpers
 
122
 
123
  def is_gpt_mode(self) -> bool:
124
  """Check if using GPT provider"""
125
+ return self.llm_provider.name.startswith("gpt4o") if self.llm_provider else False
126
+
127
+ def get_gpt_model(self) -> str:
128
+ """Get the GPT model name for OpenAI API"""
129
+ if self.llm_provider.name == "gpt4o":
130
+ return "gpt-4o"
131
+ elif self.llm_provider.name == "gpt4o-mini":
132
+ return "gpt-4o-mini"
133
+ return None
134
 
135
+ # ---------------- Global -----------------
136
  class UserConfig(BaseModel):
137
  username: str
138
  password_hash: str
139
  salt: str
140
 
141
+ # ---------------- Retry / Proxy ----------
142
  class RetryConfig(BaseModel):
143
  retry_count: int = Field(3, alias="max_attempts")
144
  backoff_seconds: int = 2
 
148
  enabled: bool = True
149
  url: HttpUrl
150
 
151
+ # ---------------- API & Auth -------------
152
  class APIAuthConfig(BaseModel):
153
  enabled: bool = False
154
  token_endpoint: Optional[HttpUrl] = None
 
169
  body_template: Dict[str, Any] = {}
170
  timeout_seconds: int = 10
171
  retry: RetryConfig = RetryConfig()
172
+ proxy: Optional[Union[str, ProxyConfig]] = None
173
  auth: Optional[APIAuthConfig] = None
174
  response_prompt: Optional[str] = None
175
  response_mappings: List[Dict[str, Any]] = []
 
183
  extra = "allow"
184
  populate_by_name = True
185
 
186
+ # ---------------- Intent / Param ---------
187
  class ParameterConfig(BaseModel):
188
  name: str
189
  caption: List[LocalizedCaption] = [] # Multi-language captions
 
254
  # Return all examples if no locale match
255
  return [e.example for e in self.examples]
256
 
257
+ # ---------------- Version / Project ------
258
  class LLMConfig(BaseModel):
259
  repo_id: str
260
  generation_config: Dict[str, Any] = {}
 
306
  class Config:
307
  extra = "allow"
308
 
309
+ # ---------------- Activity Log -----------
310
  class ActivityLogEntry(BaseModel):
311
  timestamp: str
312
  username: str
 
316
  entity_name: Optional[str] = None
317
  details: Optional[str] = None
318
 
319
+ # ---------------- Service Config ---------
320
  class ServiceConfig(BaseModel):
321
  global_config: GlobalConfig = Field(..., alias="config")
322
  projects: List[ProjectConfig]
 
332
  _api_by_name: Dict[str, APIConfig] = {}
333
 
334
  def build_index(self):
335
+ self._api_by_name = {a.name: a for a in self.apis if not a.deleted}
336
 
337
  def get_api(self, name: str) -> Optional[APIConfig]:
338
  return self._api_by_name.get(name)
 
372
 
373
  log("✅ Configuration saved to service_config.jsonc")
374
 
375
+ # ---------------- Provider Singleton -----
376
  class ConfigProvider:
377
  _instance: Optional[ServiceConfig] = None
378
  _CONFIG_PATH = Path(__file__).parent / "service_config.jsonc"
 
418
  if 'config' not in config_data:
419
  config_data['config'] = {}
420
 
421
+ # Handle backward compatibility - convert old format to new
422
+ if 'work_mode' in config_data.get('config', {}):
423
+ cls._migrate_old_config(config_data)
424
+
425
  # Parse API configs specially
426
  if 'apis' in config_data:
427
  for api in config_data['apis']:
 
455
  log(f"❌ Error loading config: {e}")
456
  raise
457
 
458
+ @classmethod
459
+ def _migrate_old_config(cls, config_data: dict):
460
+ """Migrate old config format to new provider-based format"""
461
+ log("🔄 Migrating old config format to new provider format...")
462
+
463
+ old_config = config_data.get('config', {})
464
+
465
+ # Create default providers if not exists
466
+ if 'providers' not in old_config:
467
+ old_config['providers'] = [
468
+ {
469
+ "type": "llm",
470
+ "name": "spark",
471
+ "display_name": "Spark (HuggingFace)",
472
+ "requires_endpoint": True,
473
+ "requires_api_key": True,
474
+ "requires_repo_info": True
475
+ },
476
+ {
477
+ "type": "llm",
478
+ "name": "gpt4o",
479
+ "display_name": "GPT-4o",
480
+ "requires_endpoint": False,
481
+ "requires_api_key": True,
482
+ "requires_repo_info": False
483
+ },
484
+ {
485
+ "type": "llm",
486
+ "name": "gpt4o-mini",
487
+ "display_name": "GPT-4o Mini",
488
+ "requires_endpoint": False,
489
+ "requires_api_key": True,
490
+ "requires_repo_info": False
491
+ },
492
+ {
493
+ "type": "tts",
494
+ "name": "elevenlabs",
495
+ "display_name": "ElevenLabs",
496
+ "requires_endpoint": False,
497
+ "requires_api_key": True
498
+ },
499
+ {
500
+ "type": "stt",
501
+ "name": "google",
502
+ "display_name": "Google Speech-to-Text",
503
+ "requires_endpoint": False,
504
+ "requires_api_key": True
505
+ }
506
+ ]
507
+
508
+ # Migrate LLM provider
509
+ work_mode = old_config.get('work_mode', 'hfcloud')
510
+ if work_mode in ['gpt4o', 'gpt4o-mini']:
511
+ provider_name = work_mode
512
+ api_key = old_config.get('cloud_token', '')
513
+ endpoint = None
514
+ else:
515
+ provider_name = 'spark'
516
+ api_key = old_config.get('cloud_token', '')
517
+ endpoint = old_config.get('spark_endpoint', '')
518
+
519
+ old_config['llm_provider'] = {
520
+ "name": provider_name,
521
+ "api_key": api_key,
522
+ "endpoint": endpoint,
523
+ "settings": {
524
+ "internal_prompt": old_config.get('internal_prompt', ''),
525
+ "parameter_collection_config": old_config.get('parameter_collection_config', {
526
+ "max_params_per_question": 2,
527
+ "smart_grouping": True,
528
+ "retry_unanswered": True,
529
+ "collection_prompt": ParameterCollectionConfig().collection_prompt
530
+ })
531
+ }
532
+ }
533
+
534
+ # Migrate TTS provider
535
+ tts_engine = old_config.get('tts_engine', 'no_tts')
536
+ old_config['tts_provider'] = {
537
+ "name": tts_engine,
538
+ "api_key": old_config.get('tts_engine_api_key', ''),
539
+ "settings": old_config.get('tts_settings', {})
540
+ }
541
+
542
+ # Migrate STT provider
543
+ stt_engine = old_config.get('stt_engine', 'no_stt')
544
+ old_config['stt_provider'] = {
545
+ "name": stt_engine,
546
+ "api_key": old_config.get('stt_engine_api_key', ''),
547
+ "settings": old_config.get('stt_settings', {})
548
+ }
549
+
550
+ # Migrate projects - language settings
551
+ for project in config_data.get('projects', []):
552
+ if 'default_language' in project:
553
+ # Map language names to locale codes
554
+ lang_to_locale = {
555
+ 'Türkçe': 'tr',
556
+ 'Turkish': 'tr',
557
+ 'English': 'en',
558
+ 'Deutsch': 'de',
559
+ 'German': 'de'
560
+ }
561
+ project['default_locale'] = lang_to_locale.get(project['default_language'], 'tr')
562
+ del project['default_language']
563
+
564
+ if 'supported_languages' in project:
565
+ # Convert to locale codes
566
+ supported_locales = []
567
+ for lang in project['supported_languages']:
568
+ locale = lang_to_locale.get(lang, lang)
569
+ if locale not in supported_locales:
570
+ supported_locales.append(locale)
571
+ project['supported_locales'] = supported_locales
572
+ del project['supported_languages']
573
+
574
+ # Migrate intent examples and parameter captions
575
+ for version in project.get('versions', []):
576
+ for intent in version.get('intents', []):
577
+ # Migrate examples
578
+ if 'examples' in intent and isinstance(intent['examples'], list):
579
+ new_examples = []
580
+ for example in intent['examples']:
581
+ if isinstance(example, str):
582
+ # Old format - use project default locale
583
+ new_examples.append({
584
+ "locale_code": project.get('default_locale', 'tr'),
585
+ "example": example
586
+ })
587
+ elif isinstance(example, dict) and 'locale_code' in example:
588
+ # Already new format
589
+ new_examples.append(example)
590
+ intent['examples'] = new_examples
591
+
592
+ # Migrate parameter captions
593
+ for param in intent.get('parameters', []):
594
+ if 'caption' in param:
595
+ if isinstance(param['caption'], str):
596
+ # Old format - convert to multi-language
597
+ param['caption'] = [{
598
+ "locale_code": project.get('default_locale', 'tr'),
599
+ "caption": param['caption']
600
+ }]
601
+ elif isinstance(param['caption'], list) and param['caption'] and isinstance(param['caption'][0], dict):
602
+ # Already new format
603
+ pass
604
+
605
+ # Remove old fields
606
+ fields_to_remove = ['work_mode', 'cloud_token', 'spark_endpoint', 'internal_prompt',
607
+ 'tts_engine', 'tts_engine_api_key', 'tts_settings',
608
+ 'stt_engine', 'stt_engine_api_key', 'stt_settings',
609
+ 'parameter_collection_config']
610
+ for field in fields_to_remove:
611
+ old_config.pop(field, None)
612
+
613
+ log("✅ Config migration completed")
614
+
615
  @classmethod
616
  def _check_environment_setup(cls):
617
  """Check if environment is properly configured"""