Upload run_cloud_training.py with huggingface_hub
run_cloud_training.py (CHANGED: +3, -30)
@@ -469,9 +469,7 @@ def load_model_safely(model_name, max_seq_length, dtype=None):
             max_seq_length=max_seq_length,
             dtype=dtype,
             quantization_config=bnb_config,
-            attn_implementation="eager"
-            use_flash_attention=False,  # Explicitly disable flash attention
-            use_xformers_attention=False  # Explicitly disable xformers
+            attn_implementation="eager"  # Force eager attention
         )
         logger.info("Model loaded successfully with unsloth")
 
@@ -479,12 +477,6 @@ def load_model_safely(model_name, max_seq_length, dtype=None):
         if hasattr(model, 'config'):
             if hasattr(model.config, 'attn_implementation'):
                 model.config.attn_implementation = "eager"
-            if hasattr(model.config, 'use_flash_attention'):
-                model.config.use_flash_attention = False
-            if hasattr(model.config, 'use_flash_attention_2'):
-                model.config.use_flash_attention_2 = False
-            if hasattr(model.config, 'use_xformers_attention'):
-                model.config.use_xformers_attention = False
 
         return model, tokenizer
 
@@ -494,9 +486,7 @@ def load_model_safely(model_name, max_seq_length, dtype=None):
 
         # We'll try with HF loading
         attn_params = {
-            "attn_implementation": "eager"
-            "use_flash_attention": False,  # Explicitly disable flash attention
-            "use_xformers_attention": False  # Explicitly disable xformers
+            "attn_implementation": "eager"  # Always use eager
         }
 
         # Approach 1: Using attn_implementation parameter (newer method)
@@ -507,12 +497,6 @@ def load_model_safely(model_name, max_seq_length, dtype=None):
             # Disable flash attention in config
             if hasattr(config, 'attn_implementation'):
                 config.attn_implementation = "eager"
-            if hasattr(config, 'use_flash_attention'):
-                config.use_flash_attention = False
-            if hasattr(config, 'use_flash_attention_2'):
-                config.use_flash_attention_2 = False
-            if hasattr(config, 'use_xformers_attention'):
-                config.use_xformers_attention = False
 
             tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
 
@@ -539,12 +523,6 @@ def load_model_safely(model_name, max_seq_length, dtype=None):
             # Disable flash attention in config
             if hasattr(config, 'attn_implementation'):
                 config.attn_implementation = "eager"
-            if hasattr(config, 'use_flash_attention'):
-                config.use_flash_attention = False
-            if hasattr(config, 'use_flash_attention_2'):
-                config.use_flash_attention_2 = False
-            if hasattr(config, 'use_xformers_attention'):
-                config.use_xformers_attention = False
 
             tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
 
@@ -556,9 +534,7 @@ def load_model_safely(model_name, max_seq_length, dtype=None):
                 torch_dtype=dtype or torch.float16,
                 quantization_config=bnb_config,
                 trust_remote_code=True,
-                attn_implementation="eager"
-                use_flash_attention=False,
-                use_xformers_attention=False
+                attn_implementation="eager"
             )
             logger.info("Model loaded successfully with basic HF loading")
             return model, tokenizer
@@ -614,8 +590,6 @@ def train(config_path, dataset_name, output_dir):
 
     # Update hardware config to ensure eager attention
    hardware_config["attn_implementation"] = "eager"
-    hardware_config["use_flash_attention"] = False
-    hardware_config["use_xformers_attention"] = False
 
     # Verify this is training phase only
     training_phase_only = dataset_config.get("training_phase_only", True)
@@ -678,7 +652,6 @@ def train(config_path, dataset_name, output_dir):
 
     # Update hardware config to ensure eager attention
     hardware_config["attn_implementation"] = "eager"
-    hardware_config["use_flash_attention"] = False
 
     model, tokenizer = load_model_safely(model_name, max_seq_length, dtype)
 
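For reference, here is a minimal, self-contained sketch of the loading pattern this commit converges on: attn_implementation="eager" stays as the single attention setting, while the removed use_flash_attention / use_xformers_attention keywords are dropped, since they are not keyword arguments that transformers' from_pretrained defines (and, as written, the removed lines were also missing trailing commas). The model id, quantization settings, and dtype below are illustrative placeholders, not values taken from run_cloud_training.py.

# Hypothetical stand-alone example; MODEL_NAME and the quantization settings
# are placeholders, not values from run_cloud_training.py.
import torch
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)

MODEL_NAME = "org/model"  # placeholder model id

# 4-bit quantization config, analogous to the script's bnb_config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# Patch the config defensively, as load_model_safely() does:
# only touch the attribute if the model's config actually defines it.
config = AutoConfig.from_pretrained(MODEL_NAME, trust_remote_code=True)
if hasattr(config, "attn_implementation"):
    config.attn_implementation = "eager"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    config=config,
    torch_dtype=torch.float16,
    quantization_config=bnb_config,
    trust_remote_code=True,
    attn_implementation="eager",  # the one attention switch the commit keeps
)

Forcing eager attention this way avoids pulling in flash-attn or xformers kernels at load time, trading some speed and memory for compatibility on hardware or driver stacks where those kernels are unavailable.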