import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from typing import List, Dict, Optional, Union, Literal
import os

# Disable PyTorch compilation to avoid compatibility issues in Gradio Spaces
os.environ["PYTORCH_DISABLE_DYNAMO"] = "1"
os.environ["TORCH_COMPILE_DISABLE"] = "1"

# Disable torch._dynamo if it is available
try:
    import torch._dynamo
    torch._dynamo.config.disable = True
    torch._dynamo.config.suppress_errors = True
except ImportError:
    pass

from .llm_base import TransformersBaseChatCompletion


class GemmaTransformersChatCompletion(TransformersBaseChatCompletion):
    """Gemma chat-completion implementation built on the Transformers library."""

    def __init__(
        self,
        model_name: str = "google/gemma-3-4b-it",
        device_map: Optional[str] = None,
        device: Optional[str] = None,
    ):
        # Gemma defaults to float16; dtype selection is left to the base class here
        super().__init__(
            model_name=model_name,
            device_map=device_map,
            device=device,
        )

    def _print_error_hints(self):
        """Print Gemma-specific error hints."""
        super()._print_error_hints()
        print("Gemma-specific requirements:")
        print("- Use a Transformers release with Gemma 3 support (4.50.0 or later)")
        print("- float16 is the recommended dtype")
        print("- Make sure enough GPU memory is available")


# A simplified factory function is also provided for backward compatibility
def create_gemma_transformers_client(
    model_name: str = "google/gemma-3-4b-it",
    device: Optional[str] = None,
    **kwargs
) -> GemmaTransformersChatCompletion:
    """
    Factory function that creates a Gemma Transformers client.

    Args:
        model_name: Model name.
        device: Target device ("cpu", "cuda", "mps", etc.).
        **kwargs: Additional arguments forwarded to the constructor.

    Returns:
        A GemmaTransformersChatCompletion instance.
    """
    return GemmaTransformersChatCompletion(
        model_name=model_name,
        device=device,
        **kwargs
    )
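

# Minimal usage sketch. Assumption: the chat-completion call itself is exposed by
# TransformersBaseChatCompletion; its method name and message format are not shown
# in this file, so only client construction is demonstrated here.
if __name__ == "__main__":
    client = create_gemma_transformers_client(
        model_name="google/gemma-3-4b-it",
        device="cuda",  # or "cpu" / "mps", depending on the available hardware
    )
    print(f"Created {client.__class__.__name__} client")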