# NOTE(review): removed non-Python web-scrape residue (Hugging Face Spaces
# status banner, commit hashes, and blob line-number gutter) that was pasted
# above the module and would have broken it at import time.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from typing import List, Dict, Optional, Union, Literal
import os
# Disable PyTorch compilation to avoid compatibility problems in Gradio Spaces.
os.environ.update(
    {
        "PYTORCH_DISABLE_DYNAMO": "1",
        "TORCH_COMPILE_DISABLE": "1",
    }
)

# Additionally switch off torch._dynamo when it is importable.
try:
    import torch._dynamo
except ImportError:
    pass
else:
    torch._dynamo.config.disable = True
    torch._dynamo.config.suppress_errors = True
from .llm_base import TransformersBaseChatCompletion
class GemmaTransformersChatCompletion(TransformersBaseChatCompletion):
    """Gemma chat-completion client built on the Transformers library."""

    def __init__(
        self,
        model_name: str = "google/gemma-3-4b-it",
        device_map: Optional[str] = None,
        device: Optional[str] = None,
    ):
        # Gemma conventionally runs in float16; dtype selection is presumably
        # handled by the base class — TODO confirm in llm_base.
        super().__init__(model_name=model_name, device_map=device_map, device=device)

    def _print_error_hints(self):
        """Print Gemma-specific troubleshooting hints after the generic ones."""
        super()._print_error_hints()
        for hint in (
            "Gemma 特殊要求:",
            "- 建议使用 Transformers >= 4.21.0",
            "- 推荐使用 float16 数据类型",
            "- 确保有足够的GPU内存",
        ):
            print(hint)
# A simplified factory function is also provided for backward compatibility.
def create_gemma_transformers_client(
    model_name: str = "google/gemma-3-4b-it",
    device: Optional[str] = None,
    **kwargs,
) -> GemmaTransformersChatCompletion:
    """Build a Gemma Transformers chat-completion client.

    Args:
        model_name: Hugging Face model identifier.
        device: Target device ("cpu", "cuda", "mps", etc.).
        **kwargs: Additional keyword arguments forwarded to the constructor.

    Returns:
        A GemmaTransformersChatCompletion instance.
    """
    client = GemmaTransformersChatCompletion(
        model_name=model_name,
        device=device,
        **kwargs,
    )
    return client