Upload run_cloud_training.py with huggingface_hub
run_cloud_training.py CHANGED (+54 -13)
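The commit message says the file was pushed with huggingface_hub. For context, a minimal sketch of such an upload (the repo id and token handling are placeholders, not taken from this commit):

from huggingface_hub import HfApi

api = HfApi()  # picks up a cached login or the HF_TOKEN environment variable
api.upload_file(
    path_or_fileobj="run_cloud_training.py",  # local script to push
    path_in_repo="run_cloud_training.py",     # destination path inside the repo
    repo_id="your-username/your-space",       # placeholder Space id
    repo_type="space",
    commit_message="Upload run_cloud_training.py with huggingface_hub",
)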
@@ -407,6 +407,10 @@ def load_model_safely(model_name, max_seq_length, dtype=None, use_flash_attention
     """
     logger.info(f"Loading model: {model_name}")
 
+    # Explicitly disable xformers and flash attention in environment
+    os.environ["XFORMERS_DISABLED"] = "1"
+    os.environ["TRANSFORMERS_NO_FLASH_ATTENTION"] = "1"
+
     # Create BitsAndBytesConfig for 4-bit quantization
     from transformers import BitsAndBytesConfig
     bnb_config = BitsAndBytesConfig(
@@ -416,14 +420,9 @@ def load_model_safely(model_name, max_seq_length, dtype=None, use_flash_attention
         bnb_4bit_use_double_quant=True
     )
 
-    #
-    attn_implementation = "
-
-    if use_flash_attention and flash_attention_available:
-        logger.info("Using Flash Attention for faster training")
-        attn_implementation = "flash_attention_2"
-    else:
-        logger.info("Using standard attention mechanism (sdpa)")
+    # Force eager implementation to avoid BMGHK format issues
+    attn_implementation = "eager"  # Use eager implementation to avoid BMGHK format issues
+    logger.info(f"Forcing eager attention implementation to avoid BMGHK format issues")
 
     # Try loading with unsloth
     try:
@@ -436,6 +435,12 @@ def load_model_safely(model_name, max_seq_length, dtype=None, use_flash_attention
             attn_implementation=attn_implementation
         )
         logger.info("Model loaded successfully with unsloth")
+
+        # Explicitly set attention implementation in model config
+        if hasattr(model, 'config'):
+            model.config.attn_implementation = attn_implementation
+            logger.info(f"Explicitly set model config attention implementation to {attn_implementation}")
+
         return model, tokenizer
 
     except Exception as e:
@@ -444,6 +449,10 @@ def load_model_safely(model_name, max_seq_length, dtype=None, use_flash_attention
 
         # Fallback to standard HF loading
         config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+
+        # Set attention implementation in config
+        config.attn_implementation = attn_implementation
+
         tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
 
         model = AutoModelForCausalLM.from_pretrained(
@@ -464,6 +473,32 @@ def train(config_path, dataset_name, output_dir):
     load_dotenv()
     config = load_config(config_path)
 
+    # Explicitly disable xformers and flash attention in environment
+    os.environ["XFORMERS_DISABLED"] = "1"
+    os.environ["TRANSFORMERS_NO_FLASH_ATTENTION"] = "1"
+    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+
+    # Try to unload xformers if it's loaded
+    if 'xformers' in sys.modules:
+        logger.info("Removing xformers from sys.modules")
+        del sys.modules['xformers']
+
+    # Patch torch.nn.functional to avoid memory_efficient_attention
+    try:
+        import torch.nn.functional as F
+        if hasattr(F, 'scaled_dot_product_attention'):
+            logger.info("Patching torch.nn.functional.scaled_dot_product_attention")
+            original_sdpa = F.scaled_dot_product_attention
+
+            def safe_sdpa(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
+                # Force disable memory efficient attention
+                logger.info("Using safe scaled_dot_product_attention (no xformers)")
+                return original_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale)
+
+            F.scaled_dot_product_attention = safe_sdpa
+    except Exception as e:
+        logger.warning(f"Failed to patch scaled_dot_product_attention: {e}")
+
     # Extract configs
     model_config = config.get("model_config", {})
     training_config = config.get("training_config", {})
@@ -513,11 +548,11 @@ def train(config_path, dataset_name, output_dir):
         target_modules=lora_config.get("target_modules", ["q_proj", "k_proj", "v_proj", "o_proj"])
     )
 
-    #
-    use_flash_attention =
+    # Force eager attention implementation
+    use_flash_attention = False  # Override to force eager implementation
 
     # Initialize model with our safe loading function
-    logger.info("Loading pre-quantized model")
+    logger.info("Loading pre-quantized model with eager attention")
     dtype = torch.float16 if hardware_config.get("fp16", True) else None
     model, tokenizer = load_model_safely(model_name, max_seq_length, dtype, use_flash_attention)
 
@@ -531,7 +566,10 @@ def train(config_path, dataset_name, output_dir):
         from peft import get_peft_model
         model = get_peft_model(model, lora_config_obj)
         logger.info("Successfully applied LoRA with standard PEFT")
-
+
+        # Explicitly set attention implementation in model config again after PEFT
+        model.config.attn_implementation = "eager"
+
     # No need to format the dataset - it's already pre-tokenized
     logger.info("Using dataset with flexible tokenization handling")
    logger.info("Will use pre-tokenized data if available, or tokenize strings as fallback")
@@ -627,10 +665,13 @@ if __name__ == "__main__":
     parser.add_argument("--output_dir", type=str, default=None,
                         help="Output directory for the fine-tuned model")
     parser.add_argument("--use_flash_attention", action="store_true",
-                        help="Use Flash Attention if available")
+                        help="Use Flash Attention if available (NOT RECOMMENDED)")
 
     args = parser.parse_args()
 
+    # Override flash attention setting to force eager implementation
+    args.use_flash_attention = False
+
    # Run training - Research phase only
     try:
        output_path = train(args.config, args.dataset, args.output_dir)