George-API committed (verified)
Commit 6704e73 · 1 parent: 65829fc

Upload run_cloud_training.py with huggingface_hub

Files changed (1)
  1. run_cloud_training.py +54 -13
run_cloud_training.py CHANGED
@@ -407,6 +407,10 @@ def load_model_safely(model_name, max_seq_length, dtype=None, use_flash_attentio
     """
     logger.info(f"Loading model: {model_name}")
 
+    # Explicitly disable xformers and flash attention in the environment
+    os.environ["XFORMERS_DISABLED"] = "1"
+    os.environ["TRANSFORMERS_NO_FLASH_ATTENTION"] = "1"
+
     # Create BitsAndBytesConfig for 4-bit quantization
     from transformers import BitsAndBytesConfig
     bnb_config = BitsAndBytesConfig(
@@ -416,14 +420,9 @@ def load_model_safely(model_name, max_seq_length, dtype=None, use_flash_attentio
         bnb_4bit_use_double_quant=True
     )
 
-    # Determine appropriate attention implementation
-    attn_implementation = "sdpa"  # Default to PyTorch's scaled dot product attention
-
-    if use_flash_attention and flash_attention_available:
-        logger.info("Using Flash Attention for faster training")
-        attn_implementation = "flash_attention_2"
-    else:
-        logger.info("Using standard attention mechanism (sdpa)")
+    # Force eager implementation to avoid BMGHK format issues
+    attn_implementation = "eager"
+    logger.info("Forcing eager attention implementation to avoid BMGHK format issues")
 
     # Try loading with unsloth
     try:
@@ -436,6 +435,12 @@ def load_model_safely(model_name, max_seq_length, dtype=None, use_flash_attentio
             attn_implementation=attn_implementation
         )
         logger.info("Model loaded successfully with unsloth")
+
+        # Explicitly set the attention implementation in the model config
+        if hasattr(model, 'config'):
+            model.config.attn_implementation = attn_implementation
+            logger.info(f"Explicitly set model config attention implementation to {attn_implementation}")
+
         return model, tokenizer
 
     except Exception as e:
@@ -444,6 +449,10 @@ def load_model_safely(model_name, max_seq_length, dtype=None, use_flash_attentio
 
         # Fallback to standard HF loading
        config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+
+        # Set the attention implementation in the config
+        config.attn_implementation = attn_implementation
+
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
 
        model = AutoModelForCausalLM.from_pretrained(
@@ -464,6 +473,32 @@ def train(config_path, dataset_name, output_dir):
     load_dotenv()
     config = load_config(config_path)
 
+    # Explicitly disable xformers and flash attention in the environment
+    os.environ["XFORMERS_DISABLED"] = "1"
+    os.environ["TRANSFORMERS_NO_FLASH_ATTENTION"] = "1"
+    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+
+    # Try to unload xformers if it's loaded
+    if 'xformers' in sys.modules:
+        logger.info("Removing xformers from sys.modules")
+        del sys.modules['xformers']
+
+    # Patch torch.nn.functional to avoid memory_efficient_attention
+    try:
+        import torch.nn.functional as F
+        if hasattr(F, 'scaled_dot_product_attention'):
+            logger.info("Patching torch.nn.functional.scaled_dot_product_attention")
+            original_sdpa = F.scaled_dot_product_attention
+
+            def safe_sdpa(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
+                # Delegate to the original kernel; keyword args keep keyword-only parameters valid
+                logger.info("Using safe scaled_dot_product_attention (no xformers)")
+                return original_sdpa(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
+
+            F.scaled_dot_product_attention = safe_sdpa
+    except Exception as e:
+        logger.warning(f"Failed to patch scaled_dot_product_attention: {e}")
+
     # Extract configs
     model_config = config.get("model_config", {})
     training_config = config.get("training_config", {})
@@ -513,11 +548,11 @@ def train(config_path, dataset_name, output_dir):
         target_modules=lora_config.get("target_modules", ["q_proj", "k_proj", "v_proj", "o_proj"])
     )
 
-    # Determine if we should use flash attention
-    use_flash_attention = hardware_config.get("use_flash_attention", False)
+    # Force eager attention implementation
+    use_flash_attention = False  # Override to force eager implementation
 
     # Initialize model with our safe loading function
-    logger.info("Loading pre-quantized model")
+    logger.info("Loading pre-quantized model with eager attention")
     dtype = torch.float16 if hardware_config.get("fp16", True) else None
     model, tokenizer = load_model_safely(model_name, max_seq_length, dtype, use_flash_attention)
 
@@ -531,7 +566,10 @@ def train(config_path, dataset_name, output_dir):
         from peft import get_peft_model
         model = get_peft_model(model, lora_config_obj)
         logger.info("Successfully applied LoRA with standard PEFT")
-
+
+        # Explicitly set the attention implementation in the model config again after PEFT
+        model.config.attn_implementation = "eager"
+
     # No need to format the dataset - it's already pre-tokenized
     logger.info("Using dataset with flexible tokenization handling")
     logger.info("Will use pre-tokenized data if available, or tokenize strings as fallback")
@@ -627,10 +665,13 @@ if __name__ == "__main__":
     parser.add_argument("--output_dir", type=str, default=None,
                         help="Output directory for the fine-tuned model")
     parser.add_argument("--use_flash_attention", action="store_true",
-                        help="Use Flash Attention if available")
+                        help="Use Flash Attention if available (NOT RECOMMENDED)")
 
     args = parser.parse_args()
 
+    # Override flash attention setting to force eager implementation
+    args.use_flash_attention = False
+
     # Run training - Research phase only
     try:
         output_path = train(args.config, args.dataset, args.output_dir)
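For context, the fallback branch above amounts to passing the eager-attention choice and the 4-bit quantization config straight to the standard transformers loader. A minimal sketch of that pattern, assuming a recent transformers/bitsandbytes install that accepts the attn_implementation keyword; the model id is a placeholder and, apart from bnb_4bit_use_double_quant, the quantization fields are typical values rather than ones taken from this diff:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name = "org/model-id"  # placeholder, not the model used by run_cloud_training.py

# 4-bit quantization config; only bnb_4bit_use_double_quant is visible in the diff,
# the remaining fields are common choices shown here as assumptions
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    attn_implementation="eager",  # the same override the commit enforces
    device_map="auto",
    trust_remote_code=True,
)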
 
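The scaled_dot_product_attention patch added in train() is a delegation wrapper: it stores a reference to the stock PyTorch kernel and forwards every call to it unchanged. A standalone sketch of the same wrapping technique, assuming PyTorch 2.x; the tensor shapes and print statements are illustrative only:

import torch
import torch.nn.functional as F

# Keep a handle on the original kernel before replacing it
_original_sdpa = F.scaled_dot_product_attention

def logged_sdpa(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
    # Forward to the stock implementation; the wrapper only adds visibility
    print(f"sdpa called with query shape {tuple(query.shape)}")
    return _original_sdpa(query, key, value, attn_mask=attn_mask,
                          dropout_p=dropout_p, is_causal=is_causal)

F.scaled_dot_product_attention = logged_sdpa

# Smoke test: (batch=1, heads=2, seq_len=4, head_dim=8)
q = k = v = torch.randn(1, 2, 4, 8)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 2, 4, 8])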
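Because the commit pins the attention choice in several places (environment variables, the loader argument, and the model config after PEFT wrapping), a small check after loading helps confirm what actually took effect. A sketch under the assumption that recent transformers versions record the choice on the config; the private _attn_implementation attribute is read defensively with getattr in case it is absent:

import os

def report_attention_setup(model):
    # Environment switches set by run_cloud_training.py
    for var in ("XFORMERS_DISABLED", "TRANSFORMERS_NO_FLASH_ATTENTION"):
        print(f"{var} = {os.environ.get(var)}")

    # Attention implementation as recorded on the model config, if present
    cfg = getattr(model, "config", None)
    if cfg is not None:
        print("config.attn_implementation =", getattr(cfg, "attn_implementation", None))
        print("config._attn_implementation =", getattr(cfg, "_attn_implementation", None))

# Example: call report_attention_setup(model) right after load_model_safely(...) returns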