George-API committed (verified)
Commit 00a06ef · 1 Parent(s): 2457cec

Upload run_cloud_training.py with huggingface_hub
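For reference, an upload like this can be reproduced with a short huggingface_hub call. The snippet below is only a sketch: the repo_id and repo_type are placeholders, not values taken from this commit.

# Hypothetical upload sketch using huggingface_hub (repo_id/repo_type are placeholders)
from huggingface_hub import HfApi

api = HfApi()  # picks up the token from `huggingface-cli login` or the HF_TOKEN env var
api.upload_file(
    path_or_fileobj="run_cloud_training.py",
    path_in_repo="run_cloud_training.py",
    repo_id="George-API/<target-repo>",   # placeholder
    repo_type="model",                    # or "space"/"dataset", depending on the target repo
    commit_message="Upload run_cloud_training.py with huggingface_hub",
)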

Files changed (1)
  1. run_cloud_training.py +95 -67
run_cloud_training.py CHANGED
@@ -21,12 +21,14 @@ from transformers.data.data_collator import DataCollatorMixin
 from peft import LoraConfig
 from unsloth import FastLanguageModel
 
+# Disable all attention optimizations that might cause issues
+os.environ["TRANSFORMERS_NO_FLASH_ATTENTION"] = "1"
+os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+os.environ["XFORMERS_DISABLED"] = "1"
+
 # Configure PyTorch memory allocator for better memory management
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
 
-# Disable flash attention globally
-os.environ["TRANSFORMERS_NO_FLASH_ATTENTION"] = "1"
-
 # Configure logging first
 logging.basicConfig(
     level=logging.INFO,
@@ -211,7 +213,7 @@ class PreTokenizedCollator(DataCollatorMixin):
     """
     def __init__(self, pad_token_id=0, tokenizer=None):
         self.pad_token_id = pad_token_id
-        self.tokenizer = tokenizer  # Keep a reference to the tokenizer for string conversion
+        self.tokenizer = tokenizer  # Keep a reference to the tokenizer for debugging only
 
     def __call__(self, features):
         # Print a sample feature to understand structure
@@ -221,66 +223,73 @@ class PreTokenizedCollator(DataCollatorMixin):
         # Extract input_ids from conversations if needed
         processed_features = []
         for feature in features:
+            # If input_ids is directly available, use it without tokenization
+            if 'input_ids' in feature and isinstance(feature['input_ids'], list):
+                # Already tokenized, no processing needed
+                processed_features.append(feature)
+                continue
+
             # If input_ids is not directly available, try to extract from conversations
             if 'input_ids' not in feature and 'conversations' in feature:
                 # Extract from conversations based on your dataset structure
                 conversations = feature['conversations']
 
-                # Debug the conversations structure
-                logger.info(f"Conversations type: {type(conversations)}")
-                if isinstance(conversations, list) and len(conversations) > 0:
-                    logger.info(f"First conversation type: {type(conversations[0])}")
-                    logger.info(f"First conversation: {conversations[0]}")
+                # Debug the conversations structure (only for first batch)
+                if len(processed_features) == 0:
+                    logger.info(f"Conversations type: {type(conversations)}")
+                    if isinstance(conversations, list) and len(conversations) > 0:
+                        logger.info(f"First conversation type: {type(conversations[0])}")
 
                 # Try different approaches to extract input_ids
                 if isinstance(conversations, list) and len(conversations) > 0:
-                    # Case 1: If conversations is a list of dicts with 'content' field
-                    if isinstance(conversations[0], dict) and 'content' in conversations[0]:
-                        content = conversations[0]['content']
-                        logger.info(f"Found content field: {type(content)}")
-
-                        # If content is a string, tokenize it
-                        if isinstance(content, str) and self.tokenizer:
-                            logger.info(f"Tokenizing string content: {content[:50]}...")
-                            feature['input_ids'] = self.tokenizer.encode(content, add_special_tokens=False)
-                        # If content is already a list of integers, use it directly
-                        elif isinstance(content, list) and all(isinstance(x, int) for x in content):
-                            feature['input_ids'] = content
-                        # If content is already tokenized in some other format
-                        else:
-                            logger.warning(f"Unexpected content format: {type(content)}")
-
-                    # Case 2: If conversations is a list of dicts with 'input_ids' field
-                    elif isinstance(conversations[0], dict) and 'input_ids' in conversations[0]:
+                    # Case 1: If conversations is a list of dicts with 'input_ids' field (pre-tokenized)
+                    if isinstance(conversations[0], dict) and 'input_ids' in conversations[0]:
                         feature['input_ids'] = conversations[0]['input_ids']
 
-                    # Case 3: If conversations itself contains the input_ids
+                    # Case 2: If conversations itself contains the input_ids (pre-tokenized)
                    elif all(isinstance(x, int) for x in conversations):
                         feature['input_ids'] = conversations
 
-                    # Case 4: If conversations is a list of strings
-                    elif all(isinstance(x, str) for x in conversations) and self.tokenizer:
-                        # Join all strings and tokenize
-                        full_text = " ".join(conversations)
-                        feature['input_ids'] = self.tokenizer.encode(full_text, add_special_tokens=False)
+                    # Case 3: If conversations is a list of dicts with 'content' field
+                    # This should be avoided for pre-tokenized datasets
+                    elif isinstance(conversations[0], dict) and 'content' in conversations[0]:
+                        content = conversations[0]['content']
+
+                        # If content is already a list of integers, use it directly
+                        if isinstance(content, list) and all(isinstance(x, int) for x in content):
+                            feature['input_ids'] = content
+                        # AVOID TOKENIZATION: Log warning if content is a string
+                        elif isinstance(content, str):
+                            logger.warning("Found string content in pre-tokenized dataset. This should not happen.")
+                            logger.warning("Skipping this example to avoid tokenization.")
+                            continue
 
             # Ensure input_ids is a list of integers
             if 'input_ids' in feature:
-                # If input_ids is a string, tokenize it
-                if isinstance(feature['input_ids'], str) and self.tokenizer:
-                    logger.info(f"Converting string input_ids to tokens: {feature['input_ids'][:50]}...")
-                    feature['input_ids'] = self.tokenizer.encode(feature['input_ids'], add_special_tokens=False)
+                # AVOID TOKENIZATION: Skip string input_ids
+                if isinstance(feature['input_ids'], str):
+                    logger.warning("Found string input_ids in pre-tokenized dataset. This should not happen.")
+                    logger.warning("Skipping this example to avoid tokenization.")
+                    continue
                 # If input_ids is not a list, convert it
                 elif not isinstance(feature['input_ids'], list):
                     try:
                         feature['input_ids'] = list(feature['input_ids'])
                     except:
                         logger.error(f"Could not convert input_ids to list: {type(feature['input_ids'])}")
+                        continue
+            else:
+                logger.warning("No input_ids found in this example. Skipping.")
+                continue
 
             processed_features.append(feature)
 
         # If we still don't have input_ids, log an error
-        if len(processed_features) > 0 and 'input_ids' not in processed_features[0]:
+        if len(processed_features) == 0:
+            logger.error("No valid examples found in batch. Check dataset format.")
+            raise ValueError("No valid examples found. Please check dataset structure.")
+
+        if 'input_ids' not in processed_features[0]:
             logger.error(f"Could not find input_ids in features. Available keys: {list(processed_features[0].keys())}")
             if 'conversations' in processed_features[0]:
                 logger.error(f"Conversations structure: {processed_features[0]['conversations'][:1]}")
@@ -344,6 +353,11 @@ def load_model_safely(model_name, max_seq_length, dtype=None):
     """
     global flash_attention_available
 
+    # Force disable flash attention and xformers
+    flash_attention_available = False
+    os.environ["TRANSFORMERS_NO_FLASH_ATTENTION"] = "1"
+    os.environ["XFORMERS_DISABLED"] = "1"
+
     try:
         logger.info(f"Attempting to load model with unsloth optimizations: {model_name}")
 
@@ -364,37 +378,42 @@ def load_model_safely(model_name, max_seq_length, dtype=None):
             model_name=model_name,
             max_seq_length=max_seq_length,
             dtype=dtype,
-            quantization_config=bnb_config
+            quantization_config=bnb_config,
+            attn_implementation="eager"  # Force eager attention
         )
         logger.info("Model loaded successfully with unsloth")
+
+        # Explicitly disable flash attention in model config
+        if hasattr(model, 'config'):
+            if hasattr(model.config, 'attn_implementation'):
+                model.config.attn_implementation = "eager"
+            if hasattr(model.config, 'use_flash_attention'):
+                model.config.use_flash_attention = False
+            if hasattr(model.config, 'use_flash_attention_2'):
+                model.config.use_flash_attention_2 = False
+
         return model, tokenizer
 
     except Exception as e:
         logger.warning(f"Unsloth loading failed: {e}")
         logger.info("Falling back to standard Hugging Face loading...")
 
-        # We'll try two approaches with HF loading
-        attn_params = {}
-
-        # If flash attention is available, try to use it
-        if flash_attention_available:
-            logger.info("Flash Attention is available - setting appropriate parameters")
-            # For newer models that support attn_implementation parameter
-            attn_params = {"attn_implementation": "eager"}  # Default to eager for compatibility
-
-            # Try to use flash attention if available
-            try:
-                # Try importing flash attention to confirm it's available
-                import flash_attn
-                logger.info(f"Using Flash Attention version {flash_attn.__version__}")
-                attn_params = {"attn_implementation": "flash_attention_2"}
-            except Exception as flash_error:
-                logger.warning(f"Flash Attention import failed: {flash_error}")
+        # We'll try with HF loading
+        attn_params = {"attn_implementation": "eager"}  # Always use eager
 
         # Approach 1: Using attn_implementation parameter (newer method)
         try:
             logger.info(f"Trying HF loading with attention parameters: {attn_params}")
             config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+
+            # Disable flash attention in config
+            if hasattr(config, 'attn_implementation'):
+                config.attn_implementation = "eager"
+            if hasattr(config, 'use_flash_attention'):
+                config.use_flash_attention = False
+            if hasattr(config, 'use_flash_attention_2'):
+                config.use_flash_attention_2 = False
+
             tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
 
             # The proper way to set attention implementation in newer transformers
@@ -416,6 +435,15 @@ def load_model_safely(model_name, max_seq_length, dtype=None):
 
             # Approach 2: Complete fallback with minimal parameters
             config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+
+            # Disable flash attention in config
+            if hasattr(config, 'attn_implementation'):
+                config.attn_implementation = "eager"
+            if hasattr(config, 'use_flash_attention'):
+                config.use_flash_attention = False
+            if hasattr(config, 'use_flash_attention_2'):
+                config.use_flash_attention_2 = False
+
             tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
 
             # Most basic loading without any attention parameters
@@ -447,19 +475,19 @@ def train(config_path, dataset_name, output_dir):
     lora_config = config.get("lora_config", {})
     dataset_config = config.get("dataset_config", {})
 
-    # Update flash attention setting based on availability
+    # Force disable flash attention and xformers
+    os.environ["TRANSFORMERS_NO_FLASH_ATTENTION"] = "1"
+    os.environ["XFORMERS_DISABLED"] = "1"
+    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+
+    # Update flash attention setting to always use eager
     global flash_attention_available
-    if flash_attention_available:
-        logger.info("Flash Attention is available - updating configuration")
-        # If flash attention is available, set attn_implementation to flash_attention_2
-        hardware_config["attn_implementation"] = "flash_attention_2"
-    else:
-        logger.info("Flash Attention not available - setting to eager attention")
-        hardware_config["attn_implementation"] = "eager"
+    flash_attention_available = False
+    logger.info("Flash Attention has been DISABLED globally")
 
-    # Override flash attention setting to disable it if there are compatibility issues
-    os.environ["TRANSFORMERS_NO_FLASH_ATTENTION"] = "1"
-    logger.info("Flash attention has been DISABLED globally via environment variable")
+    # Update hardware config to ensure eager attention
+    hardware_config["attn_implementation"] = "eager"
+    hardware_config["use_flash_attention"] = False
 
     # Verify this is training phase only
     training_phase_only = dataset_config.get("training_phase_only", True)
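As a quick sanity check of the collator changes above, here is a minimal sketch (not part of the commit) that exercises the input handling this diff introduces; it assumes run_cloud_training.py is importable and that the padding/collation logic later in __call__, outside this diff, builds the final batch:

# Hypothetical smoke test for PreTokenizedCollator's input handling (illustration only)
from run_cloud_training import PreTokenizedCollator

collator = PreTokenizedCollator(pad_token_id=0)
features = [
    {"input_ids": [101, 2023, 102]},                 # already tokenized: fast path, appended as-is
    {"conversations": [101, 2742, 102]},             # Case 2: conversations is a raw list of token ids
    {"conversations": [{"input_ids": [101, 102]}]},  # Case 1: dict carrying pre-tokenized input_ids
    {"conversations": [{"content": "raw text"}]},    # string content is skipped with a warning, never re-tokenized
]
batch = collator(features)  # padding and tensor conversion happen later in __call__, outside this diff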
 