Spaces: Running
Commit · 089bc3b
Parent(s): 5a751c2
Remove the reasoning-completion interface and related logic, set the GPU duration to 180 seconds, and simplify the code structure to improve maintainability.
src/podcast_transcribe/llm/llm_router.py (CHANGED)
@@ -226,76 +226,6 @@ class LLMRouter:
             logger.error(f"Chat completion with provider '{provider}' failed: {str(e)}", exc_info=True)
             raise RuntimeError(f"Chat completion failed: {str(e)}")
 
-    def reasoning_completion(
-        self,
-        messages: List[Dict[str, str]],
-        provider: str = "gemma-transformers",
-        temperature: float = 0.3,
-        max_tokens: int = 2048,
-        top_p: float = 0.9,
-        model: Optional[str] = None,
-        extract_reasoning_steps: bool = True,
-        **kwargs
-    ) -> Dict[str, Any]:
-        """
-        Chat completion interface dedicated to reasoning tasks
-
-        Parameters:
-            messages: list of messages, each containing role and content
-            provider: LLM provider name, defaults to gemma-transformers
-            temperature: temperature parameter (a lower value is recommended for reasoning tasks)
-            max_tokens: maximum number of tokens to generate
-            top_p: nucleus sampling parameter
-            model: optional model name
-            extract_reasoning_steps: whether to extract reasoning steps
-            **kwargs: additional parameters
-
-        Returns:
-            A response dict containing the reasoning steps
-        """
-        logger.info(f"Running reasoning completion with provider '{provider}', message count: {len(messages)}")
-
-        # Ensure the provider supports reasoning
-        if provider not in ["gemma-transformers"]:
-            logger.warning(f"Provider '{provider}' may not support reasoning; 'gemma-transformers' is recommended")
-
-        try:
-            # If a model argument was provided, add it to kwargs
-            if model is not None:
-                kwargs["model_name"] = model
-
-            # Get or create the LLM instance
-            llm_instance = self._get_or_create_instance(provider, **kwargs)
-
-            # Check whether the instance supports reasoning completion
-            if hasattr(llm_instance, 'reasoning_completion'):
-                result = llm_instance.reasoning_completion(
-                    messages=messages,
-                    temperature=temperature,
-                    max_tokens=max_tokens,
-                    top_p=top_p,
-                    extract_reasoning_steps=extract_reasoning_steps,
-                    **kwargs
-                )
-            else:
-                # Fall back to plain chat completion
-                logger.warning(f"Provider '{provider}' does not support reasoning completion, falling back to plain chat completion")
-                result = llm_instance.create(
-                    messages=messages,
-                    temperature=temperature,
-                    max_tokens=max_tokens,
-                    top_p=top_p,
-                    model=model,
-                    **kwargs
-                )
-
-            logger.info(f"Reasoning completion succeeded, tokens used: {result.get('usage', {}).get('total_tokens', 'unknown')}")
-            return result
-
-        except Exception as e:
-            logger.error(f"Reasoning completion with provider '{provider}' failed: {str(e)}", exc_info=True)
-            raise RuntimeError(f"Reasoning completion failed: {str(e)}")
-
     def get_model_info(self, provider: str, **kwargs) -> Dict[str, Any]:
         """
         Get model information
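The removed method resolves the reasoning capability by duck typing: it probes the instance with hasattr and degrades to a plain completion when the interface is absent. A minimal, self-contained sketch of that pattern; DummyLLM and complete are hypothetical stand-ins, not names from the repository:

    from typing import Any, Dict, List

    class DummyLLM:
        # A hypothetical backend that only implements plain completion.
        def create(self, messages: List[Dict[str, str]], **kwargs) -> Dict[str, Any]:
            return {"content": "plain completion", "usage": {"total_tokens": 3}}

    def complete(llm: Any, messages: List[Dict[str, str]], **kwargs) -> Dict[str, Any]:
        # Probe for the optional capability, then fall back gracefully.
        if hasattr(llm, "reasoning_completion"):
            return llm.reasoning_completion(messages=messages, **kwargs)
        return llm.create(messages=messages, **kwargs)

    print(complete(DummyLLM(), [{"role": "user", "content": "hi"}]))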
@@ -356,7 +286,7 @@ class LLMRouter:
 # Create the global router instance
 _router = LLMRouter()
 
-@spaces.GPU(duration=
+@spaces.GPU(duration=180)
 def chat_completion(
     messages: List[Dict[str, str]],
     provider: str = "gemma-transformers",
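On Hugging Face ZeroGPU Spaces, the spaces.GPU decorator requests a GPU for at most duration seconds per call, so the 180-second window above gives chat_completion more headroom for long generations. A minimal sketch of the same usage, assuming the ZeroGPU spaces package is available; generate_reply and its body are hypothetical, only the decorator mirrors the diff:

    import spaces

    @spaces.GPU(duration=180)  # hold the allocated GPU for at most 180 s per call
    def generate_reply(prompt: str) -> str:
        # Model loading and inference would run here on the allocated GPU.
        return f"echo: {prompt}"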
@@ -432,72 +362,6 @@ def chat_completion(
         **params
     )
 
-@spaces.GPU(duration=60)
-def reasoning_completion(
-    messages: List[Dict[str, str]],
-    provider: str = "gemma-transformers",
-    temperature: float = 0.3,
-    max_tokens: int = 2048,
-    top_p: float = 0.9,
-    model: Optional[str] = None,
-    device: Optional[str] = None,
-    device_map: Optional[str] = None,
-    extract_reasoning_steps: bool = True,
-    **kwargs
-) -> Dict[str, Any]:
-    """
-    Chat completion function dedicated to reasoning tasks
-
-    Parameters:
-        messages: list of messages, each containing role and content fields
-        provider: LLM provider, defaults to gemma-transformers
-        temperature: temperature parameter (a lower value is recommended for reasoning tasks)
-        max_tokens: maximum number of tokens to generate
-        top_p: nucleus sampling parameter
-        model: model name; the default model is used if unspecified
-        device: inference device
-        device_map: device mapping configuration
-        extract_reasoning_steps: whether to extract reasoning steps
-        **kwargs: additional parameters
-
-    Returns:
-        A response dict containing the reasoning steps
-
-    Examples:
-        # Mathematical reasoning task
-        response = reasoning_completion(
-            messages=[{"role": "user", "content": "Solve this equation: 3x + 7 = 22"}],
-            provider="gemma-transformers",
-            extract_reasoning_steps=True
-        )
-
-        # Logical reasoning task
-        response = reasoning_completion(
-            messages=[{"role": "user", "content": "If all cats are animals and Xiaohua is a cat, then what is Xiaohua?"}],
-            provider="gemma-transformers",
-            temperature=0.2
-        )
-    """
-    # Prepare the parameters
-    params = kwargs.copy()
-    if model is not None:
-        params["model_name"] = model
-    if device is not None:
-        params["device"] = device
-    if device_map:
-        params["device_map"] = device_map
-
-    return _router.reasoning_completion(
-        messages=messages,
-        provider=provider,
-        temperature=temperature,
-        max_tokens=max_tokens,
-        top_p=top_p,
-        model=model,
-        extract_reasoning_steps=extract_reasoning_steps,
-        **params
-    )
-
 
 def get_model_info(provider: str = "gemma-mlx", **kwargs) -> Dict[str, Any]:
     """