File size: 13,524 Bytes
8289369
 
 
 
 
 
924aa01
8289369
ae20fe2
 
 
 
 
 
 
 
 
 
 
 
 
7803eb5
 
8289369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c8eae1a
8289369
b88a2d6
8289369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
642af4d
8289369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
089bc3b
8289369
 
924aa01
8289369
 
 
 
 
d709cdc
8289369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c8eae1a
8289369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d709cdc
8289369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
"""
LLM模型调用路由器
根据传递的provider参数调用不同的LLM实现,支持延迟加载
"""

import logging
import torch
from typing import Dict, Any, Optional, List, Union
import os

# 禁用 PyTorch 编译以避免在 Gradio Spaces 中的兼容性问题
os.environ["PYTORCH_DISABLE_DYNAMO"] = "1"
os.environ["TORCH_COMPILE_DISABLE"] = "1"

# 如果 torch._dynamo 可用,禁用它
try:
    import torch._dynamo
    torch._dynamo.config.disable = True
    torch._dynamo.config.suppress_errors = True
except ImportError:
    pass

import spaces
from .llm_base import BaseChatCompletion
from . import llm_gemma_mlx
from . import llm_gemma_transfomers

# 配置日志
logger = logging.getLogger("llm")


class LLMRouter:
    """LLM模型调用路由器,支持多种实现的统一调用"""
    
    def __init__(self):
        """初始化路由器"""
        self._loaded_modules = {}  # 用于缓存已加载的模块
        self._llm_instances = {}   # 用于缓存已实例化的LLM实例
        
        # 定义支持的provider配置
        self._provider_configs = {
            "gemma-mlx": {
                "module_path": "llm_gemma_mlx",
                "class_name": "GemmaMLXChatCompletion",
                "default_model": "mlx-community/gemma-3-12b-it-4bit-DWQ",
                "supported_params": ["model_name"],
                "description": "基于MLX库的Gemma聊天完成实现"
            },
            "gemma-transformers": {
                "module_path": "llm_gemma_transfomers",
                "class_name": "GemmaTransformersChatCompletion",
                "default_model": "google/gemma-3-4b-it",
                "supported_params": [
                    "model_name", "device_map", "device"
                ],
                "description": "基于Transformers库的Gemma聊天完成实现"
            }
        }
    
    def _lazy_load_module(self, provider: str):
        """
        获取指定provider的模块
        
        参数:
            provider: provider名称
            
        返回:
            对应的模块
        """
        if provider not in self._provider_configs:
            raise ValueError(f"不支持的provider: {provider}")
            
        if provider not in self._loaded_modules:
            module_path = self._provider_configs[provider]["module_path"]
            logger.info(f"获取模块: {module_path}")
            
            # 根据module_path返回对应的模块
            if module_path == "llm_gemma_mlx":
                module = llm_gemma_mlx
            elif module_path == "llm_gemma_transfomers":
                module = llm_gemma_transfomers
            else:
                raise ImportError(f"未找到模块: {module_path}")
            
            self._loaded_modules[provider] = module
            logger.info(f"模块 {module_path} 获取成功")
        
        return self._loaded_modules[provider]
    
    def _get_llm_class(self, provider: str):
        """
        获取指定provider的LLM类
        
        参数:
            provider: provider名称
            
        返回:
            LLM类
        """
        module = self._lazy_load_module(provider)
        class_name = self._provider_configs[provider]["class_name"]
        
        if not hasattr(module, class_name):
            raise AttributeError(f"模块中未找到类: {class_name}")
            
        return getattr(module, class_name)
    
    def _filter_params(self, provider: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """
        过滤参数,只保留指定provider支持的参数
        
        参数:
            provider: provider名称
            params: 原始参数字典
            
        返回:
            过滤后的参数字典
        """
        supported_params = self._provider_configs[provider]["supported_params"]
        filtered_params = {}
        
        for param in supported_params:
            if param in params:
                filtered_params[param] = params[param]
        
        # 如果没有指定model_name,使用默认模型
        if "model_name" not in filtered_params and "model_name" in supported_params:
            filtered_params["model_name"] = self._provider_configs[provider]["default_model"]
        
        return filtered_params
    
    def _get_instance_key(self, provider: str, params: Dict[str, Any]) -> str:
        """
        生成LLM实例的缓存键
        
        参数:
            provider: provider名称
            params: 参数字典
            
        返回:
            实例缓存键
        """
        # 将参数转换为可哈希的字符串
        param_str = "_".join([f"{k}={v}" for k, v in sorted(params.items())])
        return f"{provider}_{param_str}"
    
    def _get_or_create_instance(self, provider: str, **kwargs) -> BaseChatCompletion:
        """
        获取或创建LLM实例(支持缓存复用)
        
        参数:
            provider: provider名称
            **kwargs: 构造函数参数
            
        返回:
            LLM实例
        """
        # 过滤并准备参数
        filtered_kwargs = self._filter_params(provider, kwargs)
        
        # 生成实例缓存键
        instance_key = self._get_instance_key(provider, filtered_kwargs)
        
        # 检查是否已有缓存实例
        if instance_key not in self._llm_instances:
            try:
                # 获取LLM类
                llm_class = self._get_llm_class(provider)
                
                logger.debug(f"创建 {provider} LLM实例,参数: {filtered_kwargs}")
                
                # 创建实例
                instance = llm_class(**filtered_kwargs)
                
                # 缓存实例
                self._llm_instances[instance_key] = instance
                
                logger.info(f"LLM实例创建成功: {provider} ({instance.model_name})")
                
            except Exception as e:
                logger.error(f"创建 {provider} LLM实例失败: {str(e)}", exc_info=True)
                raise RuntimeError(f"创建LLM实例失败: {str(e)}")
        
        return self._llm_instances[instance_key]
    
    def chat_completion(
        self,
        messages: List[Dict[str, str]],
        provider: str,
        temperature: float = 0.7,
        max_tokens: int = 2048,
        top_p: float = 1.0,
        model: Optional[str] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        统一的聊天完成接口
        
        参数:
            messages: 消息列表,每个消息包含role和content
            provider: LLM提供者名称
            temperature: 温度参数,控制生成的随机性
            max_tokens: 最大生成token数
            top_p: nucleus采样参数
            model: 可选的模型名称,如果提供则覆盖默认model_name
            **kwargs: 其他参数,如device等
            
        返回:
            聊天完成响应字典
        """
        logger.info(f"使用provider '{provider}' 进行聊天完成,消息数量: {len(messages)}")
        
        if provider not in self._provider_configs:
            available_providers = list(self._provider_configs.keys())
            raise ValueError(f"不支持的provider: {provider}。支持的provider: {available_providers}")
        
        try:
            # 如果提供了model参数,添加到kwargs中
            if model is not None:
                kwargs["model_name"] = model
            
            # 获取或创建LLM实例
            llm_instance = self._get_or_create_instance(provider, **kwargs)
            
            # 调用聊天完成
            result = llm_instance.create(
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                model=model,
                **kwargs
            )
            
            logger.info(f"聊天完成成功,使用tokens: {result.get('usage', {}).get('total_tokens', 'unknown')}")
            return result
            
        except Exception as e:
            logger.error(f"使用provider '{provider}' 进行聊天完成失败: {str(e)}", exc_info=True)
            raise RuntimeError(f"聊天完成失败: {str(e)}")
    
    def get_model_info(self, provider: str, **kwargs) -> Dict[str, Any]:
        """
        获取模型信息
        
        参数:
            provider: provider名称
            **kwargs: 构造函数参数
            
        返回:
            模型信息字典
        """
        try:
            llm_instance = self._get_or_create_instance(provider, **kwargs)
            return llm_instance.get_model_info()
        except Exception as e:
            logger.error(f"获取模型信息失败: {str(e)}")
            raise RuntimeError(f"获取模型信息失败: {str(e)}")
    
    def get_available_providers(self) -> Dict[str, str]:
        """
        获取所有可用的provider及其描述
        
        返回:
            provider名称到描述的映射
        """
        return {
            provider: config["description"] 
            for provider, config in self._provider_configs.items()
        }
    
    def get_provider_info(self, provider: str) -> Dict[str, Any]:
        """
        获取指定provider的详细信息
        
        参数:
            provider: provider名称
            
        返回:
            provider的配置信息
        """
        if provider not in self._provider_configs:
            raise ValueError(f"不支持的provider: {provider}")
            
        return self._provider_configs[provider].copy()
    
    def clear_cache(self):
        """清理缓存的实例"""
        # 清理每个实例的GPU缓存
        for instance in self._llm_instances.values():
            if hasattr(instance, 'clear_cache'):
                instance.clear_cache()
        
        # 清理实例缓存
        self._llm_instances.clear()
        logger.info("LLM实例缓存已清理")


# 创建全局路由器实例
_router = LLMRouter()

@spaces.GPU(duration=180)
def chat_completion(
    messages: List[Dict[str, str]],
    provider: str = "gemma-transformers",
    temperature: float = 0.7,
    max_tokens: int = 2048,
    top_p: float = 1.0,
    model: Optional[str] = None,
    device: Optional[str] = None,
    device_map: Optional[str] = None,
    **kwargs
) -> Dict[str, Any]:
    """
    统一的聊天完成接口函数
    
    参数:
        messages: 消息列表,每个消息包含role和content字段
        provider: LLM提供者,可选值:
            - "gemma-mlx": 基于MLX库的Gemma聊天完成实现
            - "gemma-transformers": 基于Transformers库的Gemma聊天完成实现
        temperature: 温度参数,控制生成的随机性 (0.0-2.0)
        max_tokens: 最大生成token数
        top_p: nucleus采样参数 (0.0-1.0)
        model: 模型名称,如果不指定则使用默认模型
        device: 推理设备,'cpu'、'cuda'、'mps'(仅transformers provider支持)
        device_map: 设备映射配置(仅transformers provider支持)
        **kwargs: 其他参数
        
    返回:
        聊天完成响应字典,包含生成的消息和使用统计
        
    示例:
        # 使用默认MLX实现
        response = chat_completion(
            messages=[{"role": "user", "content": "你好"}],
            provider="gemma-mlx"
        )
        
        # 使用Gemma transformers实现
        response = chat_completion(
            messages=[{"role": "user", "content": "你好"}],
            provider="gemma-transformers",
            model="google/gemma-3-4b-it",
            device="cuda",
        )
        
        # 自定义参数
        response = chat_completion(
            messages=[
                {"role": "system", "content": "你是一个有用的助手"},
                {"role": "user", "content": "请介绍自己"}
            ],
            provider="gemma-mlx",
            temperature=0.8,
            max_tokens=1024
        )
    """
    # 准备参数
    params = kwargs.copy()
    if model is not None:
        params["model_name"] = model
    if device is not None:
        params["device"] = device
    if device_map:
        params["device_map"] = device_map
    
    return _router.chat_completion(
        messages=messages,
        provider=provider,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        model=model,
        **params
    )


def get_model_info(provider: str = "gemma-mlx", **kwargs) -> Dict[str, Any]:
    """
    获取模型信息
    
    参数:
        provider: provider名称
        **kwargs: 构造函数参数
        
    返回:
        模型信息字典
    """
    return _router.get_model_info(provider, **kwargs)


def get_available_providers() -> Dict[str, str]:
    """
    获取所有可用的LLM提供者
    
    返回:
        provider名称到描述的映射
    """
    return _router.get_available_providers()


def get_provider_info(provider: str) -> Dict[str, Any]:
    """
    获取指定provider的详细信息
    
    参数:
        provider: provider名称
        
    返回:
        provider的配置信息
    """
    return _router.get_provider_info(provider)


def clear_cache():
    """清理缓存的LLM实例"""
    _router.clear_cache()