Commit 65829fc (verified) · committed by George-API · 1 Parent(s): 4d8bc74

Upload run_cloud_training.py with huggingface_hub

Files changed (1):
  1. run_cloud_training.py +57 -193
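The commit message indicates the file was pushed programmatically. For reference, a minimal sketch of such an upload with the huggingface_hub client (the repo ID below is a placeholder, not the actual target repository):

    from huggingface_hub import HfApi

    api = HfApi()  # uses the token from `huggingface-cli login` or the HF_TOKEN env var
    api.upload_file(
        path_or_fileobj="run_cloud_training.py",   # local file to push
        path_in_repo="run_cloud_training.py",      # destination path inside the repo
        repo_id="your-username/your-repo",         # placeholder: replace with the real repo ID
        commit_message="Upload run_cloud_training.py with huggingface_hub",
    )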
run_cloud_training.py CHANGED
@@ -401,147 +401,62 @@ def remove_training_marker():
         os.remove("TRAINING_ACTIVE")
         logger.info("Removed training active marker")

-def load_model_safely(model_name, max_seq_length, dtype=None):
+def load_model_safely(model_name, max_seq_length, dtype=None, use_flash_attention=False):
     """
-    Load the model in a safe way that works with Qwen models
-    by trying different loading strategies.
+    Load the model with appropriate attention settings based on hardware capability
     """
-    global flash_attention_available
+    logger.info(f"Loading model: {model_name}")

-    # Force disable flash attention and xformers
-    flash_attention_available = False
-    os.environ["TRANSFORMERS_NO_FLASH_ATTENTION"] = "1"
-    os.environ["XFORMERS_DISABLED"] = "1"
+    # Create BitsAndBytesConfig for 4-bit quantization
+    from transformers import BitsAndBytesConfig
+    bnb_config = BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_compute_dtype=torch.float16,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_use_double_quant=True
+    )

-    # Patch transformers attention implementation
-    try:
-        # Try to patch transformers attention implementation to avoid xformers
-        import transformers.models.llama.modeling_llama as llama_modeling
-
-        # Store original attention implementation
-        if not hasattr(llama_modeling, '_original_forward'):
-            # Only patch if not already patched
-            logger.info("Patching LLaMA attention implementation to avoid xformers")
-
-            # Store original implementation
-            if hasattr(llama_modeling.LlamaAttention, 'forward'):
-                llama_modeling._original_forward = llama_modeling.LlamaAttention.forward
-
-            # Define a new forward method that doesn't use xformers
-            def safe_attention_forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, output_attentions=False, use_cache=False):
-                logger.info("Using safe attention implementation (no xformers)")
-
-                # Force use_flash_attention to False
-                self._attn_implementation = "eager"
-                if hasattr(self, 'use_flash_attention'):
-                    self.use_flash_attention = False
-                if hasattr(self, 'use_flash_attention_2'):
-                    self.use_flash_attention_2 = False
-
-                # Call original implementation with flash attention disabled
-                return llama_modeling._original_forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache)
-
-            # Replace the forward method
-            llama_modeling.LlamaAttention.forward = safe_attention_forward
-            logger.info("Successfully patched LLaMA attention implementation")
-    except Exception as e:
-        logger.warning(f"Failed to patch attention implementation: {e}")
-        logger.info("Will try to proceed with standard loading")
+    # Determine appropriate attention implementation
+    attn_implementation = "sdpa"  # Default to PyTorch's scaled dot product attention
+
+    if use_flash_attention and flash_attention_available:
+        logger.info("Using Flash Attention for faster training")
+        attn_implementation = "flash_attention_2"
+    else:
+        logger.info("Using standard attention mechanism (sdpa)")

+    # Try loading with unsloth
     try:
-        logger.info(f"Attempting to load model with unsloth optimizations: {model_name}")
-
-        # Create BitsAndBytesConfig for 4-bit quantization
-        from transformers import BitsAndBytesConfig
-        bnb_config = BitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_compute_dtype=torch.float16,
-            bnb_4bit_quant_type="nf4",
-            bnb_4bit_use_double_quant=True
+        logger.info("Loading model with unsloth optimizations")
+        model, tokenizer = FastLanguageModel.from_pretrained(
+            model_name=model_name,
+            max_seq_length=max_seq_length,
+            dtype=dtype,
+            quantization_config=bnb_config,
+            attn_implementation=attn_implementation
         )
+        logger.info("Model loaded successfully with unsloth")
+        return model, tokenizer

-        # First try loading with unsloth but without flash attention
-        try:
-            logger.info("Loading model with unsloth optimizations")
-            # Don't pass any flash attention parameters to unsloth
-            model, tokenizer = FastLanguageModel.from_pretrained(
-                model_name=model_name,
-                max_seq_length=max_seq_length,
-                dtype=dtype,
-                quantization_config=bnb_config,
-                attn_implementation="eager"  # Force eager attention
-            )
-            logger.info("Model loaded successfully with unsloth")
-
-            # Explicitly disable flash attention in model config
-            if hasattr(model, 'config'):
-                if hasattr(model.config, 'attn_implementation'):
-                    model.config.attn_implementation = "eager"
-
-            return model, tokenizer
-
-        except Exception as e:
-            logger.warning(f"Unsloth loading failed: {e}")
-            logger.info("Falling back to standard Hugging Face loading...")
-
-            # We'll try with HF loading
-            attn_params = {
-                "attn_implementation": "eager"  # Always use eager
-            }
-
-            # Approach 1: Using attn_implementation parameter (newer method)
-            try:
-                logger.info(f"Trying HF loading with attention parameters: {attn_params}")
-                config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
-
-                # Disable flash attention in config
-                if hasattr(config, 'attn_implementation'):
-                    config.attn_implementation = "eager"
-
-                tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-
-                # The proper way to set attention implementation in newer transformers
-                model = AutoModelForCausalLM.from_pretrained(
-                    model_name,
-                    config=config,
-                    device_map="auto",
-                    torch_dtype=dtype or torch.float16,
-                    quantization_config=bnb_config,
-                    trust_remote_code=True,
-                    **attn_params
-                )
-                logger.info(f"Model loaded successfully with HF using attention parameters: {attn_params}")
-                return model, tokenizer
-
-            except Exception as e:
-                logger.warning(f"HF loading with attn_implementation failed: {e}")
-                logger.info("Trying fallback method...")
-
-                # Approach 2: Complete fallback with minimal parameters
-                config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
-
-                # Disable flash attention in config
-                if hasattr(config, 'attn_implementation'):
-                    config.attn_implementation = "eager"
-
-                tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-
-                # Most basic loading without any attention parameters
-                model = AutoModelForCausalLM.from_pretrained(
-                    model_name,
-                    config=config,
-                    device_map="auto",
-                    torch_dtype=dtype or torch.float16,
-                    quantization_config=bnb_config,
-                    trust_remote_code=True,
-                    attn_implementation="eager"
-                )
-                logger.info("Model loaded successfully with basic HF loading")
-                return model, tokenizer
-
     except Exception as e:
-        logger.error(f"All model loading attempts failed: {e}")
-        raise
+        logger.warning(f"Unsloth loading failed: {e}")
+        logger.info("Falling back to standard Hugging Face loading...")
+
+        # Fallback to standard HF loading
+        config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            config=config,
+            device_map="auto",
+            torch_dtype=dtype or torch.float16,
+            quantization_config=bnb_config,
+            trust_remote_code=True,
+            attn_implementation=attn_implementation
+        )
+        logger.info("Model loaded successfully with standard HF loading")
+        return model, tokenizer

 def train(config_path, dataset_name, output_dir):
     """Main training function - RESEARCH TRAINING PHASE ONLY"""
@@ -556,50 +471,6 @@ def train(config_path, dataset_name, output_dir):
     lora_config = config.get("lora_config", {})
     dataset_config = config.get("dataset_config", {})

-    # Force disable flash attention and xformers
-    os.environ["TRANSFORMERS_NO_FLASH_ATTENTION"] = "1"
-    os.environ["XFORMERS_DISABLED"] = "1"
-    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
-
-    # Monkey patch torch.nn.functional to disable memory_efficient_attention
-    try:
-        import torch.nn.functional as F
-        if hasattr(F, 'scaled_dot_product_attention'):
-            logger.info("Monkey patching torch.nn.functional.scaled_dot_product_attention")
-            original_sdpa = F.scaled_dot_product_attention
-
-            def safe_sdpa(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
-                # Force disable memory efficient attention
-                logger.info("Using safe scaled_dot_product_attention (no xformers)")
-                return original_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale)
-
-            F.scaled_dot_product_attention = safe_sdpa
-    except Exception as e:
-        logger.warning(f"Failed to patch scaled_dot_product_attention: {e}")
-
-    # Completely remove xformers from sys.modules if it's loaded
-    for module_name in list(sys.modules.keys()):
-        if 'xformers' in module_name:
-            logger.info(f"Removing {module_name} from sys.modules")
-            del sys.modules[module_name]
-
-    # Update flash attention setting to always use eager
-    global flash_attention_available
-    flash_attention_available = False
-    logger.info("Flash Attention has been DISABLED globally")
-
-    # Update hardware config to ensure eager attention
-    hardware_config["attn_implementation"] = "eager"
-
-    # Verify this is training phase only
-    training_phase_only = dataset_config.get("training_phase_only", True)
-    if not training_phase_only:
-        logger.warning("This script is meant for research training phase only")
-        logger.warning("Setting training_phase_only=True")
-
-    # Verify dataset is pre-tokenized
-    logger.info("IMPORTANT: Using pre-tokenized dataset - No tokenization will be performed")
-
     # Set the output directory
     output_dir = output_dir or training_config.get("output_dir", "fine_tuned_model")
     os.makedirs(output_dir, exist_ok=True)
@@ -628,8 +499,8 @@ def train(config_path, dataset_name, output_dir):
     )
     tokenizer.pad_token = tokenizer.eos_token

-    # Initialize model with unsloth
-    logger.info("Initializing model with unsloth (preserving 4-bit quantization)")
+    # Initialize model
+    logger.info("Initializing model (preserving 4-bit quantization)")
     max_seq_length = training_config.get("max_seq_length", 2048)

     # Create LoRA config directly
@@ -642,29 +513,21 @@ def train(config_path, dataset_name, output_dir):
         target_modules=lora_config.get("target_modules", ["q_proj", "k_proj", "v_proj", "o_proj"])
     )

+    # Determine if we should use flash attention
+    use_flash_attention = hardware_config.get("use_flash_attention", False)
+
     # Initialize model with our safe loading function
-    logger.info("Loading pre-quantized model safely")
+    logger.info("Loading pre-quantized model")
     dtype = torch.float16 if hardware_config.get("fp16", True) else None
-
-    # Force eager attention implementation
-    os.environ["TRANSFORMERS_NO_FLASH_ATTENTION"] = "1"
-    logger.info("Flash attention has been DISABLED globally via environment variable")
-
-    # Update hardware config to ensure eager attention
-    hardware_config["attn_implementation"] = "eager"
-
-    model, tokenizer = load_model_safely(model_name, max_seq_length, dtype)
+    model, tokenizer = load_model_safely(model_name, max_seq_length, dtype, use_flash_attention)

     # Disable generation capabilities for research training
     logger.info("Disabling generation capabilities - Research training only")
     model.config.is_decoder = False
     model.config.task_specific_params = None

-    # Try different approaches to apply LoRA
+    # Apply LoRA to model
     logger.info("Applying LoRA to model")
-
-    # Skip unsloth's method and go directly to PEFT
-    logger.info("Using standard PEFT method to apply LoRA")
     from peft import get_peft_model
     model = get_peft_model(model, lora_config_obj)
     logger.info("Successfully applied LoRA with standard PEFT")
@@ -692,7 +555,6 @@ def train(config_path, dataset_name, output_dir):
         logger.warning("No reporting backends available - training metrics won't be logged")

     # Set up training arguments with correct parameters
-    # Extract only the valid parameters from hardware_config
     training_args_dict = {
         "output_dir": output_dir,
         "num_train_epochs": training_config.get("num_train_epochs", 3),
@@ -764,6 +626,8 @@ if __name__ == "__main__":
                         help="Dataset name or path")
     parser.add_argument("--output_dir", type=str, default=None,
                         help="Output directory for the fine-tuned model")
+    parser.add_argument("--use_flash_attention", action="store_true",
+                        help="Use Flash Attention if available")

     args = parser.parse_args()
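The new load_model_safely() consults a module-level flash_attention_available flag that is defined outside the hunks shown here. A minimal sketch of how such a flag is commonly computed, assuming the flash-attn package and an Ampere-or-newer GPU are the criteria:

    import importlib.util
    import torch

    # Assumed detection logic (not part of this diff): flash_attention_2 requires the
    # flash-attn package and a GPU with compute capability >= 8.0 (Ampere or newer).
    flash_attention_available = (
        importlib.util.find_spec("flash_attn") is not None
        and torch.cuda.is_available()
        and torch.cuda.get_device_capability(0)[0] >= 8
    )

With a check like this in place, the loader only upgrades from "sdpa" to "flash_attention_2" when the kernel can actually run; the new --use_flash_attention CLI flag presumably feeds hardware_config["use_flash_attention"], though that wiring is outside the hunks shown.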