konieshadow committed on
Commit
d709cdc
·
1 Parent(s): 814fa89

更新模型初始化,添加设备参数支持,并将device_map默认值修改为None,以提高灵活性和兼容性。

Browse files
examples/simple_llm.py CHANGED
@@ -16,15 +16,16 @@ if __name__ == "__main__":
16
  try:
17
  model_name = "google/gemma-3-4b-it"
18
  use_4bit_quantization = False
 
19
 
20
  # gemma_chat = GemmaMLXChatCompletion(model_name="mlx-community/gemma-3-12b-it-4bit-DWQ")
21
  # 或者,如果您有更小、更快的模型,可以尝试使用,例如:"mlx-community/gemma-2b-it-8bit"
22
  if model_name.startswith("mlx-community"):
23
  gemma_chat = GemmaMLXChatCompletion(model_name=model_name)
24
  elif model_name.startswith("microsoft"):
25
- gemma_chat = Phi4TransformersChatCompletion(model_name=model_name, use_4bit_quantization=use_4bit_quantization)
26
  else:
27
- gemma_chat = GemmaTransformersChatCompletion(model_name=model_name, use_4bit_quantization=use_4bit_quantization)
28
 
29
  print("\n--- 示例 1: 简单用户查询 ---")
30
  messages_example1 = [
 
16
  try:
17
  model_name = "google/gemma-3-4b-it"
18
  use_4bit_quantization = False
19
+ device = "mps"
20
 
21
  # gemma_chat = GemmaMLXChatCompletion(model_name="mlx-community/gemma-3-12b-it-4bit-DWQ")
22
  # 或者,如果您有更小、更快的模型,可以尝试使用,例如:"mlx-community/gemma-2b-it-8bit"
23
  if model_name.startswith("mlx-community"):
24
  gemma_chat = GemmaMLXChatCompletion(model_name=model_name)
25
  elif model_name.startswith("microsoft"):
26
+ gemma_chat = Phi4TransformersChatCompletion(model_name=model_name, use_4bit_quantization=use_4bit_quantization, device=device)
27
  else:
28
+ gemma_chat = GemmaTransformersChatCompletion(model_name=model_name, use_4bit_quantization=use_4bit_quantization, device=device)
29
 
30
  print("\n--- 示例 1: 简单用户查询 ---")
31
  messages_example1 = [
src/podcast_transcribe/llm/llm_base.py CHANGED
@@ -182,7 +182,7 @@ class TransformersBaseChatCompletion(BaseChatCompletion):
182
  self,
183
  model_name: str,
184
  use_4bit_quantization: bool = False,
185
- device_map: Optional[str] = "auto",
186
  device: Optional[str] = None,
187
  trust_remote_code: bool = True,
188
  torch_dtype: Optional[torch.dtype] = None
 
182
  self,
183
  model_name: str,
184
  use_4bit_quantization: bool = False,
185
+ device_map: Optional[str] = None,
186
  device: Optional[str] = None,
187
  trust_remote_code: bool = True,
188
  torch_dtype: Optional[torch.dtype] = None
src/podcast_transcribe/llm/llm_gemma_transfomers.py CHANGED
@@ -11,7 +11,7 @@ class GemmaTransformersChatCompletion(TransformersBaseChatCompletion):
11
  self,
12
  model_name: str = "google/gemma-3-4b-it",
13
  use_4bit_quantization: bool = False,
14
- device_map: Optional[str] = "auto",
15
  device: Optional[str] = None,
16
  trust_remote_code: bool = True
17
  ):
 
11
  self,
12
  model_name: str = "google/gemma-3-4b-it",
13
  use_4bit_quantization: bool = False,
14
+ device_map: Optional[str] = None,
15
  device: Optional[str] = None,
16
  trust_remote_code: bool = True
17
  ):
src/podcast_transcribe/llm/llm_phi4_transfomers.py CHANGED
@@ -11,7 +11,7 @@ class Phi4TransformersChatCompletion(TransformersBaseChatCompletion):
11
  self,
12
  model_name: str = "microsoft/Phi-4-mini-reasoning",
13
  use_4bit_quantization: bool = False,
14
- device_map: Optional[str] = "auto",
15
  device: Optional[str] = None,
16
  trust_remote_code: bool = True
17
  ):
 
11
  self,
12
  model_name: str = "microsoft/Phi-4-mini-reasoning",
13
  use_4bit_quantization: bool = False,
14
+ device_map: Optional[str] = None,
15
  device: Optional[str] = None,
16
  trust_remote_code: bool = True
17
  ):
src/podcast_transcribe/llm/llm_router.py CHANGED
@@ -379,7 +379,7 @@ def chat_completion(
379
  model: Optional[str] = None,
380
  device: Optional[str] = None,
381
  use_4bit_quantization: bool = False,
382
- device_map: Optional[str] = "auto",
383
  trust_remote_code: bool = True,
384
  **kwargs
385
  ) -> Dict[str, Any]:
@@ -448,7 +448,7 @@ def chat_completion(
448
  params["device"] = device
449
  if use_4bit_quantization:
450
  params["use_4bit_quantization"] = use_4bit_quantization
451
- if device_map != "auto":
452
  params["device_map"] = device_map
453
  if not trust_remote_code:
454
  params["trust_remote_code"] = trust_remote_code
@@ -473,7 +473,7 @@ def reasoning_completion(
473
  model: Optional[str] = None,
474
  device: Optional[str] = None,
475
  use_4bit_quantization: bool = False,
476
- device_map: Optional[str] = "auto",
477
  trust_remote_code: bool = True,
478
  extract_reasoning_steps: bool = True,
479
  **kwargs
@@ -521,7 +521,7 @@ def reasoning_completion(
521
  params["device"] = device
522
  if use_4bit_quantization:
523
  params["use_4bit_quantization"] = use_4bit_quantization
524
- if device_map != "auto":
525
  params["device_map"] = device_map
526
  if not trust_remote_code:
527
  params["trust_remote_code"] = trust_remote_code
 
379
  model: Optional[str] = None,
380
  device: Optional[str] = None,
381
  use_4bit_quantization: bool = False,
382
+ device_map: Optional[str] = None,
383
  trust_remote_code: bool = True,
384
  **kwargs
385
  ) -> Dict[str, Any]:
 
448
  params["device"] = device
449
  if use_4bit_quantization:
450
  params["use_4bit_quantization"] = use_4bit_quantization
451
+ if device_map:
452
  params["device_map"] = device_map
453
  if not trust_remote_code:
454
  params["trust_remote_code"] = trust_remote_code
 
473
  model: Optional[str] = None,
474
  device: Optional[str] = None,
475
  use_4bit_quantization: bool = False,
476
+ device_map: Optional[str] = None,
477
  trust_remote_code: bool = True,
478
  extract_reasoning_steps: bool = True,
479
  **kwargs
 
521
  params["device"] = device
522
  if use_4bit_quantization:
523
  params["use_4bit_quantization"] = use_4bit_quantization
524
+ if device_map:
525
  params["device_map"] = device_map
526
  if not trust_remote_code:
527
  params["trust_remote_code"] = trust_remote_code