from typing import Dict, Union

from .llm_base import BaseChatCompletion


class GemmaMLXChatCompletion(BaseChatCompletion):
    """Gemma chat completion implementation based on the MLX library."""

    def __init__(self, model_name: str = "mlx-community/gemma-3-12b-it-4bit-DWQ"):
        super().__init__(model_name)
        self._load_model_and_tokenizer()

    def _load_model_and_tokenizer(self):
        """Load the MLX model and tokenizer."""
        try:
            from mlx_lm import load

            print(f"Loading MLX model: {self.model_name}")
            self.model, self.tokenizer = load(self.model_name)
            print(f"MLX model {self.model_name} loaded successfully")
        except Exception as e:
            print(f"Error while loading model {self.model_name}: {e}")
            print("Please make sure the model name is correct and accessible.")
            print("You can try 'mlx_lm.utils.get_model_path(model_name)' to look up available models.")
            raise

    def _generate_response(
        self,
        prompt_str: str,
        temperature: float,
        max_tokens: int,
        top_p: float,
        **kwargs
    ) -> str:
        """Generate a response using MLX."""
        from mlx_lm import generate
        from mlx_lm.sample_utils import make_sampler

        # Build a sampler for temperature and top_p; recent mlx_lm versions take
        # sampling parameters via a sampler object rather than as keyword
        # arguments to `generate`.
        sampler = make_sampler(temp=temperature, top_p=top_p)

        # `mlx_lm.generate` accepts the model, tokenizer, prompt, and generation
        # parameters such as `max_tokens` and `sampler`.
        response_text = generate(
            self.model,
            self.tokenizer,
            prompt=prompt_str,
            max_tokens=max_tokens,
            sampler=sampler,
            # verbose=True  # Uncomment to debug the generation process
        )
        return response_text

    def get_model_info(self) -> Dict[str, Union[str, bool, int]]:
        """Return basic information about the loaded model."""
        return {
            "model_name": self.model_name,
            "model_type": "mlx",
            "library": "mlx_lm"
        }
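

if __name__ == "__main__":
    # Usage sketch (assumptions: because of the relative import above, this file must
    # be run as a module from within its package, e.g. `python -m <package>.<module>`,
    # and an MLX-capable machine with enough memory for the 4-bit Gemma 3 12B model
    # is available). The public chat interface lives on BaseChatCompletion and is not
    # shown in this file, so the internal helper is called directly for illustration.
    llm = GemmaMLXChatCompletion()
    print(llm.get_model_info())
    print(llm._generate_response(
        prompt_str="Hello! Briefly introduce yourself.",
        temperature=0.7,
        max_tokens=256,
        top_p=0.9,
    ))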