Upload run_cloud_training.py with huggingface_hub
run_cloud_training.py (CHANGED, +92 -3)
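The commit title says the script was pushed with huggingface_hub rather than through the web editor. A minimal sketch of such an upload, using HfApi.upload_file but with a placeholder repo_id (not this Space's actual id), looks like:

from huggingface_hub import HfApi

api = HfApi()  # reads the token from HF_TOKEN or a cached login
api.upload_file(
    path_or_fileobj="run_cloud_training.py",  # local file to push
    path_in_repo="run_cloud_training.py",     # destination path inside the repo
    repo_id="user/space-name",                # placeholder Space id
    repo_type="space",                        # target a Space rather than a model repo
    commit_message="Upload run_cloud_training.py with huggingface_hub",
)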
@@ -14,6 +14,7 @@ import argparse
 import numpy as np
 from dotenv import load_dotenv
 import torch
+import sys
 from datasets import load_dataset
 import transformers
 from transformers import AutoTokenizer, TrainingArguments, Trainer, AutoModelForCausalLM, AutoConfig
@@ -26,6 +27,21 @@ os.environ["TRANSFORMERS_NO_FLASH_ATTENTION"] = "1"
 os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
 os.environ["XFORMERS_DISABLED"] = "1"

+# Completely disable xformers by removing it from sys.modules if it's loaded
+if 'xformers' in sys.modules:
+    del sys.modules['xformers']
+if 'xformers.ops' in sys.modules:
+    del sys.modules['xformers.ops']
+
+# Patch transformers to prevent xformers import
+def prevent_xformers_import(name, *args, **kwargs):
+    if 'xformers' in name:
+        raise ImportError(f"Import of {name} prevented")
+    return original_import(name, *args, **kwargs)
+
+original_import = __import__
+__builtins__['__import__'] = prevent_xformers_import
+
 # Configure PyTorch memory allocator for better memory management
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

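A caveat on the import guard above: what `__builtins__` refers to is a CPython implementation detail, and depending on how the module is executed it can be the `builtins` module itself rather than a dict, in which case `__builtins__['__import__'] = ...` raises a TypeError. A sketch of the same guard written against the `builtins` module, which works in either case:

import builtins

_original_import = builtins.__import__

def _prevent_xformers_import(name, *args, **kwargs):
    # Refuse any import whose module path mentions xformers
    if 'xformers' in name:
        raise ImportError(f"Import of {name} prevented")
    return _original_import(name, *args, **kwargs)

builtins.__import__ = _prevent_xformers_import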
@@ -391,6 +407,41 @@ def load_model_safely(model_name, max_seq_length, dtype=None):
     os.environ["TRANSFORMERS_NO_FLASH_ATTENTION"] = "1"
     os.environ["XFORMERS_DISABLED"] = "1"

+    # Patch transformers attention implementation
+    try:
+        # Try to patch transformers attention implementation to avoid xformers
+        import transformers.models.llama.modeling_llama as llama_modeling
+
+        # Store original attention implementation
+        if not hasattr(llama_modeling, '_original_forward'):
+            # Only patch if not already patched
+            logger.info("Patching LLaMA attention implementation to avoid xformers")
+
+            # Store original implementation
+            if hasattr(llama_modeling.LlamaAttention, 'forward'):
+                llama_modeling._original_forward = llama_modeling.LlamaAttention.forward
+
+                # Define a new forward method that doesn't use xformers
+                def safe_attention_forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, output_attentions=False, use_cache=False):
+                    logger.info("Using safe attention implementation (no xformers)")
+
+                    # Force use_flash_attention to False
+                    self._attn_implementation = "eager"
+                    if hasattr(self, 'use_flash_attention'):
+                        self.use_flash_attention = False
+                    if hasattr(self, 'use_flash_attention_2'):
+                        self.use_flash_attention_2 = False
+
+                    # Call original implementation with flash attention disabled
+                    return llama_modeling._original_forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache)
+
+                # Replace the forward method
+                llama_modeling.LlamaAttention.forward = safe_attention_forward
+                logger.info("Successfully patched LLaMA attention implementation")
+    except Exception as e:
+        logger.warning(f"Failed to patch attention implementation: {e}")
+        logger.info("Will try to proceed with standard loading")
+
     try:
         logger.info(f"Attempting to load model with unsloth optimizations: {model_name}")

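One caution about the patched forward above: its argument list is pinned to an older LlamaAttention.forward signature, and transformers passes different keyword arguments in different releases, so the patch can raise a TypeError on versions it was not written against. A more version-tolerant wrapper, sketched here under the assumption that _original_forward was stored as in the diff, simply forwards whatever it receives:

def safe_attention_forward(self, *args, **kwargs):
    # Force the eager code path without pinning the argument list
    self._attn_implementation = "eager"
    for flag in ("use_flash_attention", "use_flash_attention_2"):
        if hasattr(self, flag):
            setattr(self, flag, False)
    # Pass through whatever the installed transformers version supplies
    return llama_modeling._original_forward(self, *args, **kwargs)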
@@ -412,7 +463,9 @@ def load_model_safely(model_name, max_seq_length, dtype=None):
             max_seq_length=max_seq_length,
             dtype=dtype,
             quantization_config=bnb_config,
-            attn_implementation="eager" # Force eager attention
+            attn_implementation="eager", # Force eager attention
+            use_flash_attention=False, # Explicitly disable flash attention
+            use_xformers_attention=False # Explicitly disable xformers
         )
         logger.info("Model loaded successfully with unsloth")

@@ -424,6 +477,8 @@ def load_model_safely(model_name, max_seq_length, dtype=None):
             model.config.use_flash_attention = False
         if hasattr(model.config, 'use_flash_attention_2'):
             model.config.use_flash_attention_2 = False
+        if hasattr(model.config, 'use_xformers_attention'):
+            model.config.use_xformers_attention = False

         return model, tokenizer

@@ -432,7 +487,11 @@ def load_model_safely(model_name, max_seq_length, dtype=None):
         logger.info("Falling back to standard Hugging Face loading...")

         # We'll try with HF loading
-        attn_params = {
+        attn_params = {
+            "attn_implementation": "eager", # Always use eager
+            "use_flash_attention": False, # Explicitly disable flash attention
+            "use_xformers_attention": False # Explicitly disable xformers
+        }

         # Approach 1: Using attn_implementation parameter (newer method)
         try:
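The attn_params dict defined above is presumably unpacked into the fallback from_pretrained call, which sits in unchanged code outside this diff. The usual pattern, shown only as a hypothetical sketch with names taken from the surrounding function, would be:

# Hypothetical use of attn_params; the real call is not part of this diff
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=config,
    device_map="auto",
    quantization_config=bnb_config,
    trust_remote_code=True,
    **attn_params,  # expands to attn_implementation="eager", use_flash_attention=False, ...
)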
@@ -446,6 +505,8 @@ def load_model_safely(model_name, max_seq_length, dtype=None):
                 config.use_flash_attention = False
             if hasattr(config, 'use_flash_attention_2'):
                 config.use_flash_attention_2 = False
+            if hasattr(config, 'use_xformers_attention'):
+                config.use_xformers_attention = False

             tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

@@ -476,6 +537,8 @@ def load_model_safely(model_name, max_seq_length, dtype=None):
                 config.use_flash_attention = False
             if hasattr(config, 'use_flash_attention_2'):
                 config.use_flash_attention_2 = False
+            if hasattr(config, 'use_xformers_attention'):
+                config.use_xformers_attention = False

             tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

@@ -486,7 +549,10 @@ def load_model_safely(model_name, max_seq_length, dtype=None):
                 device_map="auto",
                 torch_dtype=dtype or torch.float16,
                 quantization_config=bnb_config,
-                trust_remote_code=True
+                trust_remote_code=True,
+                attn_implementation="eager",
+                use_flash_attention=False,
+                use_xformers_attention=False
             )
             logger.info("Model loaded successfully with basic HF loading")
             return model, tokenizer
@@ -513,6 +579,28 @@ def train(config_path, dataset_name, output_dir):
     os.environ["XFORMERS_DISABLED"] = "1"
     os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

+    # Monkey patch torch.nn.functional to disable memory_efficient_attention
+    try:
+        import torch.nn.functional as F
+        if hasattr(F, 'scaled_dot_product_attention'):
+            logger.info("Monkey patching torch.nn.functional.scaled_dot_product_attention")
+            original_sdpa = F.scaled_dot_product_attention
+
+            def safe_sdpa(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
+                # Force disable memory efficient attention
+                logger.info("Using safe scaled_dot_product_attention (no xformers)")
+                return original_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale)
+
+            F.scaled_dot_product_attention = safe_sdpa
+    except Exception as e:
+        logger.warning(f"Failed to patch scaled_dot_product_attention: {e}")
+
+    # Completely remove xformers from sys.modules if it's loaded
+    for module_name in list(sys.modules.keys()):
+        if 'xformers' in module_name:
+            logger.info(f"Removing {module_name} from sys.modules")
+            del sys.modules[module_name]
+
     # Update flash attention setting to always use eager
     global flash_attention_available
     flash_attention_available = False
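Two notes on the scaled_dot_product_attention wrapper above. First, in some PyTorch releases `scale` is a keyword-only argument, so forwarding it positionally can raise a TypeError, and a fixed signature also hides any parameters added in later releases; a pass-through wrapper avoids both issues. Second, the wrapper as written only logs and delegates, so it does not itself disable the memory-efficient backend; PyTorch exposes explicit toggles for that. A sketch assuming original_sdpa, F, and logger as defined in the diff:

def safe_sdpa(*args, **kwargs):
    logger.info("Using safe scaled_dot_product_attention (no xformers)")
    # Forward positional and keyword arguments unchanged so keyword-only
    # parameters such as scale keep working across PyTorch versions
    return original_sdpa(*args, **kwargs)

F.scaled_dot_product_attention = safe_sdpa

# Explicitly turn off the fused backends (toggles available in PyTorch 2.x)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)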
@@ -521,6 +609,7 @@ def train(config_path, dataset_name, output_dir):
     # Update hardware config to ensure eager attention
     hardware_config["attn_implementation"] = "eager"
     hardware_config["use_flash_attention"] = False
+    hardware_config["use_xformers_attention"] = False

     # Verify this is training phase only
     training_phase_only = dataset_config.get("training_phase_only", True)