Upload run_cloud_training.py with huggingface_hub

run_cloud_training.py  CHANGED  (+57 -193)
@@ -401,147 +401,62 @@ def remove_training_marker():
         os.remove("TRAINING_ACTIVE")
         logger.info("Removed training active marker")
 
-def load_model_safely(model_name, max_seq_length, dtype=None):
+def load_model_safely(model_name, max_seq_length, dtype=None, use_flash_attention=False):
     """
-    Load the model
-    by trying different loading strategies.
+    Load the model with appropriate attention settings based on hardware capability
     """
-    try:
-        logger.info("Patching LLaMA attention implementation to avoid xformers")
-
-        # Store original implementation
-        if hasattr(llama_modeling.LlamaAttention, 'forward'):
-            llama_modeling._original_forward = llama_modeling.LlamaAttention.forward
-
-        # Define a new forward method that doesn't use xformers
-        def safe_attention_forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, output_attentions=False, use_cache=False):
-            logger.info("Using safe attention implementation (no xformers)")
-
-            # Force use_flash_attention to False
-            self._attn_implementation = "eager"
-            if hasattr(self, 'use_flash_attention'):
-                self.use_flash_attention = False
-            if hasattr(self, 'use_flash_attention_2'):
-                self.use_flash_attention_2 = False
-
-            # Call original implementation with flash attention disabled
-            return llama_modeling._original_forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache)
-
-        # Replace the forward method
-        llama_modeling.LlamaAttention.forward = safe_attention_forward
-        logger.info("Successfully patched LLaMA attention implementation")
-    except Exception as e:
-        logger.warning(f"Failed to patch attention implementation: {e}")
-        logger.info("Will try to proceed with standard loading")
+    logger.info(f"Loading model: {model_name}")
+
+    # Create BitsAndBytesConfig for 4-bit quantization
+    from transformers import BitsAndBytesConfig
+    bnb_config = BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_compute_dtype=torch.float16,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_use_double_quant=True
+    )
+
+    # Determine appropriate attention implementation
+    attn_implementation = "sdpa"  # Default to PyTorch's scaled dot product attention
+
+    if use_flash_attention and flash_attention_available:
+        logger.info("Using Flash Attention for faster training")
+        attn_implementation = "flash_attention_2"
+    else:
+        logger.info("Using standard attention mechanism (sdpa)")
 
+    # Try loading with unsloth
     try:
-        logger.info(
-            bnb_4bit_quant_type="nf4",
-            bnb_4bit_use_double_quant=True
+        logger.info("Loading model with unsloth optimizations")
+        model, tokenizer = FastLanguageModel.from_pretrained(
+            model_name=model_name,
+            max_seq_length=max_seq_length,
+            dtype=dtype,
+            quantization_config=bnb_config,
+            attn_implementation=attn_implementation
         )
+        logger.info("Model loaded successfully with unsloth")
+        return model, tokenizer
 
-        # First try loading with unsloth but without flash attention
-        try:
-            logger.info("Loading model with unsloth optimizations")
-            # Don't pass any flash attention parameters to unsloth
-            model, tokenizer = FastLanguageModel.from_pretrained(
-                model_name=model_name,
-                max_seq_length=max_seq_length,
-                dtype=dtype,
-                quantization_config=bnb_config,
-                attn_implementation="eager"  # Force eager attention
-            )
-            logger.info("Model loaded successfully with unsloth")
-
-            # Explicitly disable flash attention in model config
-            if hasattr(model, 'config'):
-                if hasattr(model.config, 'attn_implementation'):
-                    model.config.attn_implementation = "eager"
-
-            return model, tokenizer
-
-        except Exception as e:
-            logger.warning(f"Unsloth loading failed: {e}")
-            logger.info("Falling back to standard Hugging Face loading...")
-
-            # We'll try with HF loading
-            attn_params = {
-                "attn_implementation": "eager"  # Always use eager
-            }
-
-            # Approach 1: Using attn_implementation parameter (newer method)
-            try:
-                logger.info(f"Trying HF loading with attention parameters: {attn_params}")
-                config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
-
-                # Disable flash attention in config
-                if hasattr(config, 'attn_implementation'):
-                    config.attn_implementation = "eager"
-
-                tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-
-                # The proper way to set attention implementation in newer transformers
-                model = AutoModelForCausalLM.from_pretrained(
-                    model_name,
-                    config=config,
-                    device_map="auto",
-                    torch_dtype=dtype or torch.float16,
-                    quantization_config=bnb_config,
-                    trust_remote_code=True,
-                    **attn_params
-                )
-                logger.info(f"Model loaded successfully with HF using attention parameters: {attn_params}")
-                return model, tokenizer
-
-            except Exception as e:
-                logger.warning(f"HF loading with attn_implementation failed: {e}")
-                logger.info("Trying fallback method...")
-
-                # Approach 2: Complete fallback with minimal parameters
-                config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
-
-                # Disable flash attention in config
-                if hasattr(config, 'attn_implementation'):
-                    config.attn_implementation = "eager"
-
-                tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-
-                # Most basic loading without any attention parameters
-                model = AutoModelForCausalLM.from_pretrained(
-                    model_name,
-                    config=config,
-                    device_map="auto",
-                    torch_dtype=dtype or torch.float16,
-                    quantization_config=bnb_config,
-                    trust_remote_code=True,
-                    attn_implementation="eager"
-                )
-                logger.info("Model loaded successfully with basic HF loading")
-                return model, tokenizer
-
     except Exception as e:
-        logger.
+        logger.warning(f"Unsloth loading failed: {e}")
+        logger.info("Falling back to standard Hugging Face loading...")
+
+        # Fallback to standard HF loading
+        config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            config=config,
+            device_map="auto",
+            torch_dtype=dtype or torch.float16,
+            quantization_config=bnb_config,
+            trust_remote_code=True,
+            attn_implementation=attn_implementation
+        )
+        logger.info("Model loaded successfully with standard HF loading")
+        return model, tokenizer
 
 def train(config_path, dataset_name, output_dir):
     """Main training function - RESEARCH TRAINING PHASE ONLY"""
@@ -556,50 +471,6 @@ def train(config_path, dataset_name, output_dir):
     lora_config = config.get("lora_config", {})
     dataset_config = config.get("dataset_config", {})
 
-    # Force disable flash attention and xformers
-    os.environ["TRANSFORMERS_NO_FLASH_ATTENTION"] = "1"
-    os.environ["XFORMERS_DISABLED"] = "1"
-    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
-
-    # Monkey patch torch.nn.functional to disable memory_efficient_attention
-    try:
-        import torch.nn.functional as F
-        if hasattr(F, 'scaled_dot_product_attention'):
-            logger.info("Monkey patching torch.nn.functional.scaled_dot_product_attention")
-            original_sdpa = F.scaled_dot_product_attention
-
-            def safe_sdpa(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
-                # Force disable memory efficient attention
-                logger.info("Using safe scaled_dot_product_attention (no xformers)")
-                return original_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale)
-
-            F.scaled_dot_product_attention = safe_sdpa
-    except Exception as e:
-        logger.warning(f"Failed to patch scaled_dot_product_attention: {e}")
-
-    # Completely remove xformers from sys.modules if it's loaded
-    for module_name in list(sys.modules.keys()):
-        if 'xformers' in module_name:
-            logger.info(f"Removing {module_name} from sys.modules")
-            del sys.modules[module_name]
-
-    # Update flash attention setting to always use eager
-    global flash_attention_available
-    flash_attention_available = False
-    logger.info("Flash Attention has been DISABLED globally")
-
-    # Update hardware config to ensure eager attention
-    hardware_config["attn_implementation"] = "eager"
-
-    # Verify this is training phase only
-    training_phase_only = dataset_config.get("training_phase_only", True)
-    if not training_phase_only:
-        logger.warning("This script is meant for research training phase only")
-        logger.warning("Setting training_phase_only=True")
-
-    # Verify dataset is pre-tokenized
-    logger.info("IMPORTANT: Using pre-tokenized dataset - No tokenization will be performed")
-
     # Set the output directory
     output_dir = output_dir or training_config.get("output_dir", "fine_tuned_model")
     os.makedirs(output_dir, exist_ok=True)
@@ -628,8 +499,8 @@ def train(config_path, dataset_name, output_dir):
     )
     tokenizer.pad_token = tokenizer.eos_token
 
-    # Initialize model
-    logger.info("Initializing model
+    # Initialize model
+    logger.info("Initializing model (preserving 4-bit quantization)")
     max_seq_length = training_config.get("max_seq_length", 2048)
 
     # Create LoRA config directly
@@ -642,29 +513,21 @@ def train(config_path, dataset_name, output_dir):
         target_modules=lora_config.get("target_modules", ["q_proj", "k_proj", "v_proj", "o_proj"])
     )
 
+    # Determine if we should use flash attention
+    use_flash_attention = hardware_config.get("use_flash_attention", False)
+
     # Initialize model with our safe loading function
-    logger.info("Loading pre-quantized model
+    logger.info("Loading pre-quantized model")
     dtype = torch.float16 if hardware_config.get("fp16", True) else None
-
-    # Force eager attention implementation
-    os.environ["TRANSFORMERS_NO_FLASH_ATTENTION"] = "1"
-    logger.info("Flash attention has been DISABLED globally via environment variable")
-
-    # Update hardware config to ensure eager attention
-    hardware_config["attn_implementation"] = "eager"
-
-    model, tokenizer = load_model_safely(model_name, max_seq_length, dtype)
+    model, tokenizer = load_model_safely(model_name, max_seq_length, dtype, use_flash_attention)
 
     # Disable generation capabilities for research training
     logger.info("Disabling generation capabilities - Research training only")
     model.config.is_decoder = False
    model.config.task_specific_params = None
 
-    #
+    # Apply LoRA to model
     logger.info("Applying LoRA to model")
-
-    # Skip unsloth's method and go directly to PEFT
-    logger.info("Using standard PEFT method to apply LoRA")
     from peft import get_peft_model
     model = get_peft_model(model, lora_config_obj)
     logger.info("Successfully applied LoRA with standard PEFT")
@@ -692,7 +555,6 @@ def train(config_path, dataset_name, output_dir):
     logger.warning("No reporting backends available - training metrics won't be logged")
 
     # Set up training arguments with correct parameters
-    # Extract only the valid parameters from hardware_config
    training_args_dict = {
         "output_dir": output_dir,
         "num_train_epochs": training_config.get("num_train_epochs", 3),
@@ -764,6 +626,8 @@ if __name__ == "__main__":
                         help="Dataset name or path")
     parser.add_argument("--output_dir", type=str, default=None,
                         help="Output directory for the fine-tuned model")
+    parser.add_argument("--use_flash_attention", action="store_true",
+                        help="Use Flash Attention if available")
 
     args = parser.parse_args()
 
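For context on the change: the new load_model_safely picks one attention backend up front and passes it to both the unsloth and plain Transformers loaders, instead of patching LLaMA attention and disabling xformers at runtime. The following is a minimal, self-contained sketch of that selection pattern; it is not part of the commit, and the availability probe used here (a CUDA device plus an installed flash_attn package) is an assumption standing in for the script's own flash_attention_available flag.

# Illustrative sketch only: mirrors the "sdpa by default, flash_attention_2 on
# request" decision made in load_model_safely above. The availability probe is
# an assumption; the real script keeps its own flash_attention_available flag.
import importlib.util

import torch


def pick_attn_implementation(use_flash_attention: bool) -> str:
    """Return the attn_implementation string to pass to from_pretrained()."""
    flash_available = (
        torch.cuda.is_available()
        and importlib.util.find_spec("flash_attn") is not None
    )
    if use_flash_attention and flash_available:
        # FlashAttention kernels are installed and a GPU is present.
        return "flash_attention_2"
    # PyTorch's built-in scaled dot product attention is the safe default.
    return "sdpa"


if __name__ == "__main__":
    print(pick_attn_implementation(use_flash_attention=True))

Note that in the hunks shown, train() reads use_flash_attention from hardware_config, while the new --use_flash_attention CLI flag is only added to the argument parser; how the parsed flag reaches hardware_config is outside the displayed hunks.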