fix model parallel (#816)
src/axolotl/utils/models.py
@@ -442,14 +442,7 @@ def load_model(
     if cfg.ddp and not load_in_8bit:
         model.to(f"cuda:{cfg.local_rank}")
 
-    if (
-        torch.cuda.device_count() > 1
-        and int(os.getenv("WORLD_SIZE", "1")) > 1
-        and (cfg.load_in_4bit)
-    ):
-        # llama is PROBABLY model parallelizable, but the default isn't that it is
-        # so let's only set it for the 4bit, see
-        # https://github.com/johnsmith0031/alpaca_lora_4bit/blob/08b3fca4a4a9e0d3945be1bab4529f100a428636/finetune.py#L130-L133
+    if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
         setattr(model, "is_parallelizable", True)
         setattr(model, "model_parallel", True)
 
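For context, the patch flips the gate for naive model parallelism: the flags are now set when more than one GPU is visible and the run is not distributed (WORLD_SIZE == 1), instead of only for distributed 4-bit runs. Below is a minimal, illustrative sketch of that gating logic; the helper name and standalone structure are assumptions for clarity, not the axolotl implementation itself.

import os

import torch


def maybe_enable_model_parallel(model) -> bool:
    """Illustrative helper mirroring the condition in the patch above.

    Marks the model as model-parallel only when more than one GPU is visible
    and WORLD_SIZE == 1, i.e. a single process that is not a torchrun/DDP launch.
    """
    multi_gpu = torch.cuda.device_count() > 1
    # Distributed launchers (torchrun, accelerate) export WORLD_SIZE > 1;
    # in that case each rank drives one GPU and model parallelism is not wanted.
    single_process = int(os.getenv("WORLD_SIZE", "1")) == 1

    if multi_gpu and single_process:
        setattr(model, "is_parallelizable", True)
        setattr(model, "model_parallel", True)
        return True
    return False

The rationale, as far as the diff shows: under DDP each rank already places the model on its own device via model.to(f"cuda:{cfg.local_rank}"), so flagging the model as model-parallel there is unnecessary, while the old condition left single-process multi-GPU runs without the flags.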