Commit d709cdc
Parent(s): 814fa89
Update model initialization: add device parameter support and change the default value of device_map to None, for greater flexibility and compatibility.
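Why the None default matters: with a truthy default such as "auto", model loading would always route through accelerate's automatic placement, and an explicitly requested device like "mps" could never take effect. Below is a minimal sketch of the intended resolution order, using a hypothetical load_causal_lm helper rather than the repository's actual loading code:

from typing import Optional

from transformers import AutoModelForCausalLM

def load_causal_lm(model_name: str,
                   device: Optional[str] = None,
                   device_map: Optional[str] = None):
    """Hypothetical helper: use accelerate placement only when asked for."""
    if device_map is not None:
        # Opt-in path: let accelerate place/shard the weights (e.g. "auto").
        return AutoModelForCausalLM.from_pretrained(model_name, device_map=device_map)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    if device is not None:
        # Explicit path: move the whole model to one device, e.g. "mps".
        model = model.to(device)
    return model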
examples/simple_llm.py
CHANGED
@@ -16,15 +16,16 @@ if __name__ == "__main__":
     try:
         model_name = "google/gemma-3-4b-it"
         use_4bit_quantization = False
+        device = "mps"
 
         # gemma_chat = GemmaMLXChatCompletion(model_name="mlx-community/gemma-3-12b-it-4bit-DWQ")
         # Alternatively, if you have a smaller, faster model, you can try it, e.g. "mlx-community/gemma-2b-it-8bit"
         if model_name.startswith("mlx-community"):
             gemma_chat = GemmaMLXChatCompletion(model_name=model_name)
         elif model_name.startswith("microsoft"):
-            gemma_chat = Phi4TransformersChatCompletion(model_name=model_name, use_4bit_quantization=use_4bit_quantization)
+            gemma_chat = Phi4TransformersChatCompletion(model_name=model_name, use_4bit_quantization=use_4bit_quantization, device=device)
         else:
-            gemma_chat = GemmaTransformersChatCompletion(model_name=model_name, use_4bit_quantization=use_4bit_quantization)
+            gemma_chat = GemmaTransformersChatCompletion(model_name=model_name, use_4bit_quantization=use_4bit_quantization, device=device)
 
         print("\n--- Example 1: simple user query ---")
         messages_example1 = [
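Note that the example hard-codes device = "mps", which only exists on Apple Silicon. A small fallback chain (a sketch, not part of this commit) keeps the script portable across machines:

import torch

# Prefer Apple's Metal backend, then CUDA, then plain CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"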
src/podcast_transcribe/llm/llm_base.py
CHANGED
@@ -182,7 +182,7 @@ class TransformersBaseChatCompletion(BaseChatCompletion):
         self,
         model_name: str,
         use_4bit_quantization: bool = False,
-        device_map: Optional[str] = "auto",
+        device_map: Optional[str] = None,
         device: Optional[str] = None,
         trust_remote_code: bool = True,
         torch_dtype: Optional[torch.dtype] = None
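For orientation, constructor arguments like these are typically forwarded to transformers' from_pretrained. The following is a hedged sketch under that assumption; build_from_pretrained_kwargs is a hypothetical name, not code from llm_base.py:

from typing import Any, Dict, Optional

import torch
from transformers import BitsAndBytesConfig

def build_from_pretrained_kwargs(
    use_4bit_quantization: bool = False,
    device_map: Optional[str] = None,
    torch_dtype: Optional[torch.dtype] = None,
    trust_remote_code: bool = True,
) -> Dict[str, Any]:
    kwargs: Dict[str, Any] = {"trust_remote_code": trust_remote_code}
    if torch_dtype is not None:
        kwargs["torch_dtype"] = torch_dtype
    if use_4bit_quantization:
        # bitsandbytes 4-bit loading targets CUDA; it will not pair with device="mps".
        kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
    if device_map is not None:
        # Forward device_map only when requested; the new default is None.
        kwargs["device_map"] = device_map
    return kwargs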
src/podcast_transcribe/llm/llm_gemma_transfomers.py
CHANGED
@@ -11,7 +11,7 @@ class GemmaTransformersChatCompletion(TransformersBaseChatCompletion):
         self,
         model_name: str = "google/gemma-3-4b-it",
         use_4bit_quantization: bool = False,
-        device_map: Optional[str] = "auto",
+        device_map: Optional[str] = None,
         device: Optional[str] = None,
         trust_remote_code: bool = True
     ):
src/podcast_transcribe/llm/llm_phi4_transfomers.py
CHANGED
@@ -11,7 +11,7 @@ class Phi4TransformersChatCompletion(TransformersBaseChatCompletion):
         self,
         model_name: str = "microsoft/Phi-4-mini-reasoning",
         use_4bit_quantization: bool = False,
-        device_map: Optional[str] = "auto",
+        device_map: Optional[str] = None,
         device: Optional[str] = None,
         trust_remote_code: bool = True
     ):
src/podcast_transcribe/llm/llm_router.py
CHANGED
@@ -379,7 +379,7 @@ def chat_completion(
     model: Optional[str] = None,
     device: Optional[str] = None,
     use_4bit_quantization: bool = False,
-    device_map: Optional[str] = "auto",
+    device_map: Optional[str] = None,
     trust_remote_code: bool = True,
     **kwargs
 ) -> Dict[str, Any]:
@@ -448,7 +448,7 @@ def chat_completion(
         params["device"] = device
     if use_4bit_quantization:
         params["use_4bit_quantization"] = use_4bit_quantization
-    if device_map != "auto":
+    if device_map:
         params["device_map"] = device_map
     if not trust_remote_code:
         params["trust_remote_code"] = trust_remote_code
@@ -473,7 +473,7 @@ def reasoning_completion(
     model: Optional[str] = None,
     device: Optional[str] = None,
     use_4bit_quantization: bool = False,
-    device_map: Optional[str] = "auto",
+    device_map: Optional[str] = None,
     trust_remote_code: bool = True,
     extract_reasoning_steps: bool = True,
     **kwargs
@@ -521,7 +521,7 @@ def reasoning_completion(
         params["device"] = device
     if use_4bit_quantization:
         params["use_4bit_quantization"] = use_4bit_quantization
-    if device_map != "auto":
+    if device_map:
         params["device_map"] = device_map
     if not trust_remote_code:
         params["trust_remote_code"] = trust_remote_code
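With the guard changed to a plain truthiness check, the routers only forward device_map when a caller actually sets it. A hypothetical invocation follows; only model, device, use_4bit_quantization, device_map, and trust_remote_code are visible in this diff, so the messages argument below is an assumption about the unshown part of the signature:

from podcast_transcribe.llm.llm_router import chat_completion

response = chat_completion(
    messages=[{"role": "user", "content": "Summarize this episode."}],  # assumed parameter
    model="google/gemma-3-4b-it",
    device="mps",  # wired through to the backend via params["device"]
    # device_map omitted: with the new None default it is simply not forwarded
)
print(response)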