George-API committed
Commit 69ba4cd · verified · 1 Parent(s): 6bbd6b2

Upload run_cloud_training.py with huggingface_hub

Files changed (1)
  1. run_cloud_training.py +45 -7
run_cloud_training.py CHANGED
@@ -56,6 +56,16 @@ except Exception as e:
     logger.warning(f"Failed to install flash-attention: {e}")
     logger.info("Continuing without flash-attention")
 
+# Check if flash attention was successfully installed
+flash_attention_available = False
+try:
+    import flash_attn
+    flash_attention_available = True
+    logger.info(f"Flash Attention will be used (version: {flash_attn.__version__})")
+    # We'll handle flash attention configuration during model loading
+except ImportError:
+    logger.info("Flash Attention not available, will use standard attention mechanism")
+
 # Check if tensorboard is available
 try:
     import tensorboard
@@ -298,6 +308,8 @@ def load_model_safely(model_name, max_seq_length, dtype=None):
     Load the model in a safe way that works with Qwen models
     by trying different loading strategies.
     """
+    global flash_attention_available
+
     try:
         logger.info(f"Attempting to load model with unsloth optimizations: {model_name}")
 
@@ -328,14 +340,30 @@ def load_model_safely(model_name, max_seq_length, dtype=None):
         logger.info("Falling back to standard Hugging Face loading...")
 
         # We'll try two approaches with HF loading
+        attn_params = {}
+
+        # If flash attention is available, try to use it
+        if flash_attention_available:
+            logger.info("Flash Attention is available - setting appropriate parameters")
+            # For newer models that support attn_implementation parameter
+            attn_params = {"attn_implementation": "eager"}  # Default to eager for compatibility
+
+            # Try to use flash attention if available
+            try:
+                # Try importing flash attention to confirm it's available
+                import flash_attn
+                logger.info(f"Using Flash Attention version {flash_attn.__version__}")
+                attn_params = {"attn_implementation": "flash_attention_2"}
+            except Exception as flash_error:
+                logger.warning(f"Flash Attention import failed: {flash_error}")
 
         # Approach 1: Using attn_implementation parameter (newer method)
         try:
-            logger.info("Trying HF loading with attn_implementation parameter")
+            logger.info(f"Trying HF loading with attention parameters: {attn_params}")
             config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
             tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
 
-            # The proper way to disable flash attention in newer transformers
+            # The proper way to set attention implementation in newer transformers
             model = AutoModelForCausalLM.from_pretrained(
                 model_name,
                 config=config,
@@ -343,9 +371,9 @@ def load_model_safely(model_name, max_seq_length, dtype=None):
                 torch_dtype=dtype or torch.float16,
                 quantization_config=bnb_config,
                 trust_remote_code=True,
-                attn_implementation="eager"  # Use eager instead of flash_attention_2
+                **attn_params
             )
-            logger.info("Model loaded successfully with HF using attn_implementation='eager'")
+            logger.info(f"Model loaded successfully with HF using attention parameters: {attn_params}")
             return model, tokenizer
 
         except Exception as e:
@@ -385,9 +413,19 @@ def train(config_path, dataset_name, output_dir):
     lora_config = config.get("lora_config", {})
     dataset_config = config.get("dataset_config", {})
 
-    # Override flash attention setting to disable it
-    hardware_config["use_flash_attention"] = False
-    logger.info("Flash attention has been DISABLED due to GPU compatibility issues")
+    # Update flash attention setting based on availability
+    global flash_attention_available
+    if flash_attention_available:
+        logger.info("Flash Attention is available - updating configuration")
+        # If flash attention is available, set attn_implementation to flash_attention_2
+        hardware_config["attn_implementation"] = "flash_attention_2"
+    else:
+        logger.info("Flash Attention not available - setting to eager attention")
+        hardware_config["attn_implementation"] = "eager"
+
+    # Override flash attention setting to disable it if there are compatibility issues
+    os.environ["TRANSFORMERS_NO_FLASH_ATTENTION"] = "1"
+    logger.info("Flash attention has been DISABLED globally via environment variable")
 
     # Verify this is training phase only
     training_phase_only = dataset_config.get("training_phase_only", True)
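
The hunks above boil down to one decision: probe for the optional flash-attn package once, then pick an attn_implementation based on the result. A minimal sketch of that probe follows; the helper name pick_attn_implementation is hypothetical (the script itself stores the result in a module-level flash_attention_available flag instead of returning a string).

import logging

logger = logging.getLogger(__name__)

def pick_attn_implementation() -> str:
    """Return "flash_attention_2" if the flash-attn package imports cleanly,
    otherwise fall back to the "eager" attention implementation."""
    try:
        import flash_attn  # present only if the optional wheel installed successfully
        logger.info(f"Flash Attention available (version {flash_attn.__version__})")
        return "flash_attention_2"
    except ImportError:
        logger.info("Flash Attention not available, using eager attention")
        return "eager"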
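In load_model_safely the chosen setting is splatted into AutoModelForCausalLM.from_pretrained via **attn_params. A trimmed sketch of that fallback path, assuming a transformers release recent enough to accept the attn_implementation keyword; the script's bitsandbytes quantization_config is omitted for brevity and the model name is a placeholder.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_with_attention_fallback(model_name: str):
    # Tokenizer loading does not depend on the attention backend.
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    for impl in ("flash_attention_2", "eager"):
        try:
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                trust_remote_code=True,
                attn_implementation=impl,  # retried with "eager" if flash-attn kernels are unusable
            )
            return model, tokenizer
        except Exception as exc:
            print(f"Loading with attn_implementation={impl!r} failed: {exc}")
    raise RuntimeError(f"Could not load {model_name} with any attention implementation")

If the flash_attention_2 attempt fails (for example, the installed flash-attn wheel does not match the GPU), the loop retries with eager, which mirrors the two-approach fallback in the commit.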