podcast-transcriber / src /podcast_transcribe /llm /llm_gemma_transfomers.py
konieshadow's picture
添加对PyTorch编译的禁用支持,以解决Gradio Spaces中的兼容性问题,并在多个文件中统一配置日志记录。
ae20fe2
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from typing import List, Dict, Optional, Union, Literal
import os
# 禁用 PyTorch 编译以避免在 Gradio Spaces 中的兼容性问题
os.environ["PYTORCH_DISABLE_DYNAMO"] = "1"
os.environ["TORCH_COMPILE_DISABLE"] = "1"
# 如果 torch._dynamo 可用,禁用它
try:
import torch._dynamo
torch._dynamo.config.disable = True
torch._dynamo.config.suppress_errors = True
except ImportError:
pass
from .llm_base import TransformersBaseChatCompletion
class GemmaTransformersChatCompletion(TransformersBaseChatCompletion):
"""基于 Transformers 库的 Gemma 聊天完成实现"""
def __init__(
self,
model_name: str = "google/gemma-3-4b-it",
device_map: Optional[str] = None,
device: Optional[str] = None,
):
# Gemma 使用 float16 作为默认数据类型
super().__init__(
model_name=model_name,
device_map=device_map,
device=device,
)
def _print_error_hints(self):
"""打印Gemma特定的错误提示信息"""
super()._print_error_hints()
print("Gemma 特殊要求:")
print("- 建议使用 Transformers >= 4.21.0")
print("- 推荐使用 float16 数据类型")
print("- 确保有足够的GPU内存")
# 为了保持向后兼容性,也可以提供一个简化的工厂函数
def create_gemma_transformers_client(
model_name: str = "google/gemma-3-4b-it",
device: Optional[str] = None,
**kwargs
) -> GemmaTransformersChatCompletion:
"""
创建 Gemma Transformers 客户端的工厂函数
Args:
model_name: 模型名称
device: 指定设备 ("cpu", "cuda", "mps", 等)
**kwargs: 其他传递给构造函数的参数
Returns:
GemmaTransformersChatCompletion 实例
"""
return GemmaTransformersChatCompletion(
model_name=model_name,
device=device,
**kwargs
)