konieshadow committed on
Commit
b88a2d6
·
1 Parent(s): 642af4d
src/podcast_transcribe/llm/llm_base.py CHANGED
@@ -178,6 +178,7 @@ class TransformersBaseChatCompletion(BaseChatCompletion):
178
  device: Optional[str] = None,
179
  ):
180
  super().__init__(model_name)
 
181
  self.device_map = device_map
182
  self.device = device
183
 
 
178
  device: Optional[str] = None,
179
  ):
180
  super().__init__(model_name)
181
+ torch.set_float32_matmul_precision('high') # 设置 TensorFloat32 精度
182
  self.device_map = device_map
183
  self.device = device
184
 
src/podcast_transcribe/llm/llm_router.py CHANGED
@@ -38,7 +38,7 @@ class LLMRouter:
38
  "class_name": "GemmaTransformersChatCompletion",
39
  "default_model": "google/gemma-3-4b-it",
40
  "supported_params": [
41
- "model_name", "device_map",
42
  ],
43
  "description": "基于Transformers库的Gemma聊天完成实现"
44
  }
 
38
  "class_name": "GemmaTransformersChatCompletion",
39
  "default_model": "google/gemma-3-4b-it",
40
  "supported_params": [
41
+ "model_name", "device_map", "device"
42
  ],
43
  "description": "基于Transformers库的Gemma聊天完成实现"
44
  }