File size: 2,044 Bytes
8289369
 
 
ae20fe2
 
 
 
 
 
 
 
 
 
 
 
 
 
8289369
 
 
 
 
 
 
 
c8eae1a
d709cdc
8289369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c8eae1a
8289369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from typing import List, Dict, Optional, Union, Literal
import os

# 禁用 PyTorch 编译以避免在 Gradio Spaces 中的兼容性问题
os.environ["PYTORCH_DISABLE_DYNAMO"] = "1"
os.environ["TORCH_COMPILE_DISABLE"] = "1"

# 如果 torch._dynamo 可用,禁用它
try:
    import torch._dynamo
    torch._dynamo.config.disable = True
    torch._dynamo.config.suppress_errors = True
except ImportError:
    pass

from .llm_base import TransformersBaseChatCompletion


class GemmaTransformersChatCompletion(TransformersBaseChatCompletion):
    """基于 Transformers 库的 Gemma 聊天完成实现"""
    
    def __init__(
        self,
        model_name: str = "google/gemma-3-4b-it",
        device_map: Optional[str] = None,
        device: Optional[str] = None,
    ):
        # Gemma 使用 float16 作为默认数据类型
        super().__init__(
            model_name=model_name,
            device_map=device_map,
            device=device,
        )
    
    def _print_error_hints(self):
        """打印Gemma特定的错误提示信息"""
        super()._print_error_hints()
        print("Gemma 特殊要求:")
        print("- 建议使用 Transformers >= 4.21.0")
        print("- 推荐使用 float16 数据类型")
        print("- 确保有足够的GPU内存")


# 为了保持向后兼容性,也可以提供一个简化的工厂函数
def create_gemma_transformers_client(
    model_name: str = "google/gemma-3-4b-it",
    device: Optional[str] = None,
    **kwargs
) -> GemmaTransformersChatCompletion:
    """
    创建 Gemma Transformers 客户端的工厂函数
    
    Args:
        model_name: 模型名称
        device: 指定设备 ("cpu", "cuda", "mps", 等)
        **kwargs: 其他传递给构造函数的参数
    
    Returns:
        GemmaTransformersChatCompletion 实例
    """
    return GemmaTransformersChatCompletion(
        model_name=model_name,
        device=device,
        **kwargs
    )