Upload run_cloud_training.py with huggingface_hub
run_cloud_training.py  CHANGED  (+45 -7)
@@ -56,6 +56,16 @@ except Exception as e:
     logger.warning(f"Failed to install flash-attention: {e}")
     logger.info("Continuing without flash-attention")
 
+# Check if flash attention was successfully installed
+flash_attention_available = False
+try:
+    import flash_attn
+    flash_attention_available = True
+    logger.info(f"Flash Attention will be used (version: {flash_attn.__version__})")
+    # We'll handle flash attention configuration during model loading
+except ImportError:
+    logger.info("Flash Attention not available, will use standard attention mechanism")
+
 # Check if tensorboard is available
 try:
     import tensorboard
@@ -298,6 +308,8 @@ def load_model_safely(model_name, max_seq_length, dtype=None):
     Load the model in a safe way that works with Qwen models
     by trying different loading strategies.
     """
+    global flash_attention_available
+
     try:
         logger.info(f"Attempting to load model with unsloth optimizations: {model_name}")
 
@@ -328,14 +340,30 @@ def load_model_safely(model_name, max_seq_length, dtype=None):
         logger.info("Falling back to standard Hugging Face loading...")
 
         # We'll try two approaches with HF loading
+        attn_params = {}
+
+        # If flash attention is available, try to use it
+        if flash_attention_available:
+            logger.info("Flash Attention is available - setting appropriate parameters")
+            # For newer models that support attn_implementation parameter
+            attn_params = {"attn_implementation": "eager"}  # Default to eager for compatibility
+
+            # Try to use flash attention if available
+            try:
+                # Try importing flash attention to confirm it's available
+                import flash_attn
+                logger.info(f"Using Flash Attention version {flash_attn.__version__}")
+                attn_params = {"attn_implementation": "flash_attention_2"}
+            except Exception as flash_error:
+                logger.warning(f"Flash Attention import failed: {flash_error}")
 
         # Approach 1: Using attn_implementation parameter (newer method)
         try:
-            logger.info("Trying HF loading with
+            logger.info(f"Trying HF loading with attention parameters: {attn_params}")
             config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
             tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
 
-            # The proper way to
+            # The proper way to set attention implementation in newer transformers
             model = AutoModelForCausalLM.from_pretrained(
                 model_name,
                 config=config,
@@ -343,9 +371,9 @@ def load_model_safely(model_name, max_seq_length, dtype=None):
                 torch_dtype=dtype or torch.float16,
                 quantization_config=bnb_config,
                 trust_remote_code=True,
-
+                **attn_params
            )
-            logger.info("Model loaded successfully with HF using
+            logger.info(f"Model loaded successfully with HF using attention parameters: {attn_params}")
             return model, tokenizer
 
         except Exception as e:
@@ -385,9 +413,19 @@ def train(config_path, dataset_name, output_dir):
     lora_config = config.get("lora_config", {})
     dataset_config = config.get("dataset_config", {})
 
-    #
-
-
+    # Update flash attention setting based on availability
+    global flash_attention_available
+    if flash_attention_available:
+        logger.info("Flash Attention is available - updating configuration")
+        # If flash attention is available, set attn_implementation to flash_attention_2
+        hardware_config["attn_implementation"] = "flash_attention_2"
+    else:
+        logger.info("Flash Attention not available - setting to eager attention")
+        hardware_config["attn_implementation"] = "eager"
+
+    # Override flash attention setting to disable it if there are compatibility issues
+    os.environ["TRANSFORMERS_NO_FLASH_ATTENTION"] = "1"
+    logger.info("Flash attention has been DISABLED globally via environment variable")
 
     # Verify this is training phase only
     training_phase_only = dataset_config.get("training_phase_only", True)