Update config_provider.py
config_provider.py  CHANGED  (+150, -158)
@@ -13,7 +13,7 @@ from utils import log
 from pydantic import BaseModel, Field, HttpUrl, ValidationError, field_validator, validator
 from encryption_utils import decrypt
 
-# =====================
+# ===================== Models =====================
 class ProviderConfig(BaseModel):
     """Provider definition with requirements"""
     type: str = Field(..., pattern=r"^(llm|tts|stt)$")
@@ -41,104 +41,11 @@ class LocalizedExample(BaseModel):
     locale_code: str
     example: str
 
-# ---------------- Parameter Collection Config ---------
-class ParameterCollectionConfig(BaseModel):
-    """Configuration for smart parameter collection"""
-    max_params_per_question: int = Field(default=2, ge=1, le=5)
-    smart_grouping: bool = Field(default=True)
-    retry_unanswered: bool = Field(default=True)
-    collection_prompt: str = Field(default="""
-You are a helpful assistant collecting information from the user.
-
-Conversation context:
-{{conversation_history}}
-
-Intent: {{intent_name}} - {{intent_caption}}
-
-Already collected:
-{{collected_params}}
-
-Still needed:
-{{missing_params}}
-
-Previously asked but not answered:
-{{unanswered_params}}
-
-Rules:
-1. Ask for maximum {{max_params}} parameters in one question
-2. Group parameters that naturally go together (like from/to cities, dates)
-3. If some parameters were asked before but not answered, include them again
-4. Be natural and conversational in {{project_language}}
-5. Use context from the conversation to make the question flow naturally
-
-Generate ONLY the question, nothing else.""")
-
-    class Config:
-        extra = "allow"
-
-# ===================== Global Configuration =====================
-class GlobalConfig(BaseModel):
-    # Provider settings (replaces work_mode, cloud_token, spark_endpoint)
-    llm_provider: ProviderSettings
-    tts_provider: ProviderSettings = Field(default_factory=lambda: ProviderSettings(name="no_tts"))
-    stt_provider: ProviderSettings = Field(default_factory=lambda: ProviderSettings(name="no_stt"))
-
-    # Available providers
-    providers: List[ProviderConfig] = []
-
-    # User management
-    users: List["UserConfig"] = []
-
-    # Helper methods for providers
-    def get_provider_config(self, provider_type: str, provider_name: str) -> Optional[ProviderConfig]:
-        """Get provider configuration by type and name"""
-        return next(
-            (p for p in self.providers if p.type == provider_type and p.name == provider_name),
-            None
-        )
-
-    def get_providers_by_type(self, provider_type: str) -> List[ProviderConfig]:
-        """Get all providers of a specific type"""
-        return [p for p in self.providers if p.type == provider_type]
-
-    def get_plain_api_key(self, provider_type: str) -> Optional[str]:
-        """Get decrypted API key for a provider type"""
-        provider_map = {
-            "llm": self.llm_provider,
-            "tts": self.tts_provider,
-            "stt": self.stt_provider
-        }
-        provider = provider_map.get(provider_type)
-        if provider and provider.api_key:
-            if provider.api_key.startswith("enc:"):
-                return decrypt(provider.api_key)
-            return provider.api_key
-        return None
-
-    # Backward compatibility helpers
-    def is_cloud_mode(self) -> bool:
-        """Check if running in cloud mode (HuggingFace)"""
-        return bool(os.environ.get("SPACE_ID"))
-
-    def is_gpt_mode(self) -> bool:
-        """Check if using GPT provider"""
-        return self.llm_provider.name.startswith("gpt4o") if self.llm_provider else False
-
-    def get_gpt_model(self) -> str:
-        """Get the GPT model name for OpenAI API"""
-        if self.llm_provider.name == "gpt4o":
-            return "gpt-4o"
-        elif self.llm_provider.name == "gpt4o-mini":
-            return "gpt-4o-mini"
-        return None
-
-# ---------------- Global -----------------
 class UserConfig(BaseModel):
     username: str
     password_hash: str
     salt: str
 
-# ---------------- Retry / Proxy ----------
 class RetryConfig(BaseModel):
     retry_count: int = Field(3, alias="max_attempts")
     backoff_seconds: int = 2
@@ -148,7 +55,6 @@ class ProxyConfig(BaseModel):
     enabled: bool = True
     url: HttpUrl
 
-# ---------------- API & Auth -------------
 class APIAuthConfig(BaseModel):
     enabled: bool = False
     token_endpoint: Optional[HttpUrl] = None
@@ -184,8 +90,100 @@ class APIConfig(BaseModel):
         extra = "allow"
         populate_by_name = True
 
+class LLMConfig(BaseModel):
+    repo_id: str
+    generation_config: Dict[str, Any] = {}
+    use_fine_tune: bool = False
+    fine_tune_zip: str = ""
+
+class VersionConfig(BaseModel):
+    id: int = Field(..., alias="version_number")
+    no: Optional[int] = None
+    caption: Optional[str] = ""
+    description: Optional[str] = ""
+    published: bool = False
+    deleted: bool = False
+    created_date: Optional[str] = None
+    created_by: Optional[str] = None
+    last_update_date: Optional[str] = None
+    last_update_user: Optional[str] = None
+    publish_date: Optional[str] = None
+    published_by: Optional[str] = None
+    general_prompt: str
+    welcome_prompt: Optional[str] = None
+    llm: LLMConfig
+    intents: List[IntentConfig]
+
+    class Config:
+        extra = "allow"
+        populate_by_name = True
+
+class ProjectConfig(BaseModel):
+    id: Optional[int] = None
+    name: str
+    caption: Optional[str] = ""
+    icon: Optional[str] = "folder"
+    description: Optional[str] = ""
+    enabled: bool = True
+    last_version_number: Optional[int] = None
+    version_id_counter: int = 1
+    versions: List[VersionConfig]
+    # Language settings - changed from default_language/supported_languages
+    default_locale: str = "tr"
+    supported_locales: List[str] = ["tr"]
+    timezone: Optional[str] = "Europe/Istanbul"
+    region: Optional[str] = "tr-TR"
+    deleted: bool = False
+    created_date: Optional[str] = None
+    created_by: Optional[str] = None
+    last_update_date: Optional[str] = None
+    last_update_user: Optional[str] = None
+
+    class Config:
+        extra = "allow"
+
+class ActivityLogEntry(BaseModel):
+    timestamp: str
+    username: str
+    action: str
+    entity_type: str
+    entity_id: Optional[int] = None
+    entity_name: Optional[str] = None
+    details: Optional[str] = None
+
+class ParameterCollectionConfig(BaseModel):
+    """Configuration for smart parameter collection"""
+    max_params_per_question: int = Field(default=2, ge=1, le=5)
+    retry_unanswered: bool = Field(default=True)
+    collection_prompt: str = Field(default="""
+You are a helpful assistant collecting information from the user.
+
+Conversation context:
+{{conversation_history}}
+
+Intent: {{intent_name}} - {{intent_caption}}
+
+Already collected:
+{{collected_params}}
+
+Still needed:
+{{missing_params}}
+
+Previously asked but not answered:
+{{unanswered_params}}
+
+Rules:
+1. Ask for maximum {{max_params}} parameters in one question
+2. Group parameters that naturally go together (like from/to cities, dates)
+3. If some parameters were asked before but not answered, include them again
+4. Be natural and conversational in {{project_language}}
+5. Use context from the conversation to make the question flow naturally
+
+Generate ONLY the question, nothing else.""")
+
+    class Config:
+        extra = "allow"
 
-# ---------------- Intent / Param ---------
 class ParameterConfig(BaseModel):
     name: str
     caption: List[LocalizedCaption] = []  # Multi-language captions
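
Note (not part of the commit): a minimal sketch of how the new models nest, assuming the module imports cleanly and that `config_provider` exposes these classes; the values ("travel_bot", "org/model") are placeholders, and the empty intent list is only illustrative.

from config_provider import LLMConfig, VersionConfig, ProjectConfig

draft = ProjectConfig(
    name="travel_bot",
    versions=[
        VersionConfig(
            version_number=1,  # populated through the "version_number" alias -> id
            general_prompt="You are a travel assistant.",
            llm=LLMConfig(repo_id="org/model"),
            intents=[],
        )
    ],
)
print(draft.default_locale)   # "tr" by default
print(draft.versions[0].id)   # 1
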
@@ -229,6 +227,7 @@ class IntentConfig(BaseModel):
     name: str
     caption: Optional[str] = ""
     dependencies: List[str] = []
+    requiresApproval: bool = False
     examples: List[LocalizedExample] = []
     detection_prompt: Optional[str] = None
     parameters: List[ParameterConfig] = []
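
One plausible consumption of the new flag, purely as a sketch (the commit only adds the field itself; "version" is assumed to be a VersionConfig instance):

for intent in version.intents:
    if intent.requiresApproval:
        ...  # e.g. defer execution until a human approves the detected intent
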
@@ -256,67 +255,61 @@ class IntentConfig(BaseModel):
         # Return all examples if no locale match
         return [e.example for e in self.examples]
 
-#
-class
(removed lines 261-313 are not rendered in this view)
-    username: str
-    action: str
-    entity_type: str
-    entity_id: Optional[int] = None
-    entity_name: Optional[str] = None
-    details: Optional[str] = None
+# ===================== Global Configuration =====================
+class GlobalConfig(BaseModel):
+    # Provider settings (replaces work_mode, cloud_token, spark_endpoint)
+    llm_provider: ProviderSettings
+    tts_provider: ProviderSettings = Field(default_factory=lambda: ProviderSettings(name="no_tts"))
+    stt_provider: ProviderSettings = Field(default_factory=lambda: ProviderSettings(name="no_stt"))
+
+    # Available providers
+    providers: List[ProviderConfig] = []
+
+    # User management
+    users: List["UserConfig"] = []
+
+    # Helper methods for providers
+    def get_provider_config(self, provider_type: str, provider_name: str) -> Optional[ProviderConfig]:
+        """Get provider configuration by type and name"""
+        return next(
+            (p for p in self.providers if p.type == provider_type and p.name == provider_name),
+            None
+        )
+
+    def get_providers_by_type(self, provider_type: str) -> List[ProviderConfig]:
+        """Get all providers of a specific type"""
+        return [p for p in self.providers if p.type == provider_type]
+
+    def get_plain_api_key(self, provider_type: str) -> Optional[str]:
+        """Get decrypted API key for a provider type"""
+        provider_map = {
+            "llm": self.llm_provider,
+            "tts": self.tts_provider,
+            "stt": self.stt_provider
+        }
+        provider = provider_map.get(provider_type)
+        if provider and provider.api_key:
+            if provider.api_key.startswith("enc:"):
+                return decrypt(provider.api_key)
+            return provider.api_key
+        return None
+
+    # Backward compatibility helpers
+    def is_cloud_mode(self) -> bool:
+        """Check if running in cloud mode (HuggingFace)"""
+        return bool(os.environ.get("SPACE_ID"))
+
+    def is_gpt_mode(self) -> bool:
+        """Check if using GPT provider"""
+        return self.llm_provider.name.startswith("gpt4o") if self.llm_provider else False
+
+    def get_gpt_model(self) -> str:
+        """Get the GPT model name for OpenAI API"""
+        if self.llm_provider.name == "gpt4o":
+            return "gpt-4o"
+        elif self.llm_provider.name == "gpt4o-mini":
+            return "gpt-4o-mini"
+        return None
 
 # ---------------- Service Config ---------
 class ServiceConfig(BaseModel):
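
Note (not part of the commit): a hypothetical sketch of how the relocated GlobalConfig helpers behave. ProviderSettings is not shown in this diff, so the sketch assumes it accepts "name" and "api_key" fields, and that encryption_utils.decrypt reverses whatever wrote the "enc:" prefix.

cfg = GlobalConfig(llm_provider=ProviderSettings(name="gpt4o-mini", api_key="enc:..."))

cfg.get_providers_by_type("llm")   # [] - no ProviderConfig entries registered in this sketch
cfg.get_plain_api_key("llm")       # decrypted, because the stored key starts with "enc:"
cfg.is_gpt_mode()                  # True - the provider name starts with "gpt4o"
cfg.get_gpt_model()                # "gpt-4o-mini"
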
@@ -556,7 +549,6 @@ class ConfigProvider:
             "internal_prompt": old_config.get('internal_prompt', ''),
             "parameter_collection_config": old_config.get('parameter_collection_config', {
                 "max_params_per_question": 2,
-                "smart_grouping": True,
                 "retry_unanswered": True,
                 "collection_prompt": ParameterCollectionConfig().collection_prompt
             })
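
Note (not part of the commit): the migration default above leans on ParameterCollectionConfig's own defaults, and because the model declares extra = "allow", a legacy "smart_grouping" key left in an old config still validates instead of raising. A minimal sketch:

pcc = ParameterCollectionConfig()                        # max_params_per_question=2, retry_unanswered=True
legacy = ParameterCollectionConfig(smart_grouping=True)  # tolerated via extra = "allow"
assert pcc.collection_prompt.strip().startswith("You are a helpful assistant")
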