File size: 2,259 Bytes
8289369
 
 
 
 
 
 
 
 
 
 
 
 
 
5d8d9c8
 
8289369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d8d9c8
 
8289369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from typing import Dict, Union
from .llm_base import BaseChatCompletion


class GemmaMLXChatCompletion(BaseChatCompletion):
    """基于 MLX 库的 Gemma 聊天完成实现"""
    
    def __init__(self, model_name: str = "mlx-community/gemma-3-12b-it-4bit-DWQ"):
        super().__init__(model_name)
        self._load_model_and_tokenizer()

    def _load_model_and_tokenizer(self):
        """加载 MLX 模型和分词器"""
        try:
            from mlx_lm import load

            print(f"正在加载 MLX 模型: {self.model_name}")
            self.model, self.tokenizer = load(self.model_name)
            print(f"MLX 模型 {self.model_name} 加载成功")
        except Exception as e:
            print(f"加载模型 {self.model_name} 时出错: {e}")
            print("请确保模型名称正确且可访问。")
            print("您可以尝试使用 'mlx_lm.utils.get_model_path(model_name)' 搜索可用的模型。")
            raise

    def _generate_response(
        self,
        prompt_str: str,
        temperature: float,
        max_tokens: int,
        top_p: float,
        **kwargs
    ) -> str:
        """使用 MLX 生成响应"""
        from mlx_lm import load, generate
        from mlx_lm.sample_utils import make_sampler
        
        # 为temperature和top_p创建一个采样器
        sampler = make_sampler(temp=temperature, top_p=top_p)

        # 生成响应
        # mlx_lm中的`generate`函数接受模型、分词器、提示和其他生成参数。
        # 我们需要将我们的参数映射到`generate`期望的参数。
        # `mlx_lm.generate` 的 verbose 参数可用于调试。
        # `temperature` 是 `mlx_lm.generate` 中温度的参数名称。
        response_text = generate(
            self.model,
            self.tokenizer,
            prompt=prompt_str,
            max_tokens=max_tokens,
            sampler=sampler,
            # verbose=True # 取消注释以调试生成过程
        )
        
        return response_text

    def get_model_info(self) -> Dict[str, Union[str, bool, int]]:
        """获取模型信息"""
        return {
            "model_name": self.model_name,
            "model_type": "mlx",
            "library": "mlx_lm"
        }