George-API committed
Commit 75f9a64 · verified · 1 Parent(s): bf7bd7e

Upload folder using huggingface_hub
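
The commit message indicates this file was pushed with huggingface_hub's folder-upload API. For reference, a minimal sketch of how such an upload is typically produced; the repo id, local path, and repo type below are placeholders and are not taken from this commit:

from huggingface_hub import HfApi

api = HfApi()  # picks up HF_TOKEN from the environment if it is set
api.upload_folder(
    folder_path=".",                   # local folder containing run_transformers_training.py
    repo_id="George-API/<repo-name>",  # placeholder; the actual target repo is not shown on this page
    repo_type="space",                 # assumption: the training code lives in a Space
    commit_message="Upload folder using huggingface_hub",
)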

Files changed (1)
  1. run_transformers_training.py +825 -622
run_transformers_training.py CHANGED
@@ -184,227 +184,291 @@ def load_configs(base_path):
184
  raise
185
 
186
  def parse_args():
187
- parser = argparse.ArgumentParser(description="Fine-tune a language model on a text dataset")
188
- parser.add_argument("--config", type=str, default="transformers_config.json", help="Path to configuration file")
189
  return parser.parse_args()
190
 
191
  def load_model_and_tokenizer(config):
192
- """Load model and tokenizer with proper error handling and optimizations."""
193
  try:
194
- if not unsloth_available:
195
- logger.error("Unsloth is required for training with pre-quantized model")
196
- logger.error("Please ensure unsloth is in requirements.txt")
197
- raise ImportError("Unsloth is required for this training setup")
198
-
199
- # Get model name correctly from config
200
- model_name = config.get("model_name") or config.get("model", {}).get("name")
201
- logger.info(f"Loading model: {model_name}")
202
-
203
- if not model_name:
204
- raise ValueError("Model name not found in configuration. Please check your transformers_config.json file.")
205
 
206
- logger.info("Using Unsloth optimizations with pre-quantized model")
207
-
208
- # First detect if we have a GPU
209
- if torch.cuda.is_available():
210
- gpu_count = torch.cuda.device_count()
211
- logger.info(f"Found {gpu_count} CUDA devices")
212
- else:
213
- logger.warning("No CUDA devices detected. Training will be slow on CPU!")
214
- gpu_count = 0
215
-
216
- # Set default dtype for better numerics
217
- if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
218
- # Use bfloat16 for Ampere or newer
219
- dtype = torch.bfloat16
220
- logger.info("Using bfloat16 precision (Ampere+ GPU)")
221
- elif torch.cuda.is_available():
222
- # Use float16 for older GPUs
223
- dtype = torch.float16
224
- logger.info("Using float16 precision (pre-Ampere GPU)")
225
  else:
226
- # CPU, use default dtype
227
- dtype = None
228
- logger.info("Using default precision (CPU)")
229
-
230
- # Check for flash attention as the last dependency check
231
- use_flash_attention = config.get("use_flash_attention", True)
232
- if use_flash_attention and not find_spec("flash_attn"):
233
- logger.warning("flash-attn not found. Will continue without flash attention.")
234
- logger.warning("To use flash attention, install with: pip install flash-attn --no-build-isolation")
235
- use_flash_attention = False
236
-
237
- # Set device map based on config or default to "auto"
238
- device_map = config.get("hardware", {}).get("hardware_setup", {}).get("device_map", "auto")
239
-
240
- # Calculate max memory settings if multiple GPUs are available
241
- max_memory = None
242
- if gpu_count > 1:
243
- memory_per_gpu = config.get("hardware", {}).get("specs", {}).get("vram_per_gpu", 24)
244
- max_memory = {i: f"{int(memory_per_gpu * 0.85)}GiB" for i in range(gpu_count)}
245
- max_memory["cpu"] = "64GiB" # Allow CPU offloading if needed
246
-
247
- # Load model with proper error handling for out-of-memory
248
- try:
249
- # Improved memory settings for multi-GPU setup
250
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
251
 
252
- model, tokenizer = FastLanguageModel.from_pretrained(
253
- model_name=model_name,
254
- max_seq_length=config.get("max_seq_length", 2048) or config.get("tokenizer", {}).get("max_seq_length", 2048),
255
- dtype=dtype,
256
- device_map=device_map,
257
- max_memory=max_memory,
258
- # Don't explicitly use flash attention config here, let Unsloth handle it
259
  )
260
- except RuntimeError as e:
261
- if "CUDA out of memory" in str(e):
262
- logger.error("Out of GPU memory. Consider using a smaller batch size or gradient accumulation steps.")
263
- raise
264
- else:
265
- # Try again with CPU placement to see if it's a memory issue
266
- logger.warning(f"Error loading model on default device: {str(e)}")
267
- logger.warning("Attempting to load with device_map='cpu' and no specific dtype")
268
- model, tokenizer = FastLanguageModel.from_pretrained(
269
- model_name=model_name,
270
- max_seq_length=config.get("max_seq_length", 2048) or config.get("tokenizer", {}).get("max_seq_length", 2048),
271
- dtype=None,
272
- device_map={"": "cpu"},
273
- )
274
- logger.warning("Model loaded on CPU. Training will be very slow.")
275
-
276
- # Ensure model and optimizer init is on the same device
277
- logger.info(f"Model device map: {model.hf_device_map if hasattr(model, 'hf_device_map') else 'Not available'}")
278
 
279
- # Apply Unsloth's training optimizations with config parameters
280
- unsloth_config = config.get("unsloth", {})
281
-
282
- # Get dropout value; if not explicitly zero, warn about performance implications
283
- lora_dropout = unsloth_config.get("dropout", 0.05)
284
- if lora_dropout > 0:
285
- logger.warning(f"Unsloth works best with dropout=0, but config has dropout={lora_dropout}")
286
- logger.warning("This will impact performance but training will still work")
287
- logger.warning("Consider setting dropout=0 in your config for better performance")
288
-
289
- # Apply optimizations
290
- model = FastLanguageModel.get_peft_model(
291
- model,
292
- r=unsloth_config.get("r", 32),
293
- target_modules=unsloth_config.get("target_modules",
294
- ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]),
295
- lora_alpha=unsloth_config.get("alpha", 16),
296
- lora_dropout=lora_dropout, # Using the value from config or default
297
- bias="none",
298
- use_gradient_checkpointing=config.get("gradient_checkpointing", True) or config.get("training", {}).get("gradient_checkpointing", True),
299
- random_state=config.get("seed", 42),
300
- )
301
- logger.info("Unsloth optimizations applied successfully")
302
-
303
- # Set up tokenizer settings
304
- chat_template = config.get("chat_template") or config.get("tokenizer", {}).get("chat_template")
305
- if chat_template:
306
- try:
307
- # Get the correct chat template for phi models
308
- template = get_chat_template("phi")
309
- # Correctly apply the template to the tokenizer (it's a string)
310
- if isinstance(template, str):
311
- tokenizer.chat_template = template
312
- logger.info("Set phi chat template (string)")
313
  else:
314
- # If it's not a string, it's likely already a template object
315
- tokenizer.chat_template = template
316
- logger.info("Set phi chat template (object)")
317
- except Exception as e:
318
- logger.warning(f"Failed to set chat template: {str(e)}")
319
- logger.warning("Chat formatting may not work correctly, but training can continue")
320
 
321
- # Ensure proper token settings
322
- if tokenizer.pad_token_id is None:
323
- tokenizer.pad_token_id = tokenizer.eos_token_id
324
- logger.info(f"Set pad_token_id to eos_token_id: {tokenizer.pad_token_id}")
325
 
326
  return model, tokenizer
327
-
328
  except Exception as e:
329
- logger.error(f"Error in model/tokenizer loading: {str(e)}")
330
- logger.error("If missing dependencies, check the requirements.txt file")
331
- raise
332
 
333
- def load_dataset_with_mapping(dataset_config):
334
- """Load dataset and apply appropriate column mappings."""
335
  try:
336
- # Load dataset
337
- dataset_name = dataset_config.get("dataset", {}).get("name", "")
338
- dataset_split = dataset_config.get("dataset", {}).get("split", "train")
339
 
340
- if not dataset_name:
341
- raise ValueError("Dataset name not provided in configuration")
342
 
343
- logger.info(f"Loading pre-processed dataset {dataset_name}, split {dataset_split}")
344
 
345
- try:
346
- dataset = load_dataset(dataset_name, split=dataset_split)
347
 
348
- # Verify the dataset was actually loaded and is not None
349
- if dataset is None:
350
- raise ValueError(f"Dataset {dataset_name} (split {dataset_split}) loaded as None - check dataset exists and is accessible")
351
-
352
- # Check if the dataset is empty
353
- if len(dataset) == 0:
354
- raise ValueError(f"Dataset {dataset_name} (split {dataset_split}) is empty (contains 0 examples)")
355
-
356
- # Verify conversations field specifically
357
- if "conversations" not in dataset.column_names:
358
- raise ValueError(f"Dataset {dataset_name} missing required 'conversations' column")
359
-
360
- # Validate conversation structure
361
- if len(dataset) > 0:
362
- sample = dataset[0]
363
- conversations = sample.get("conversations", [])
364
-
365
- if conversations:
366
- first_conv = conversations[0]
367
- if isinstance(first_conv, dict):
368
- # Check actual fields
369
- fields = list(first_conv.keys())
370
- logger.info(f"Conversation fields: {fields}")
371
-
372
- # Verify only 'content' field exists
373
- if fields == ["content"]:
374
- logger.info("Confirmed conversations have correct format with only 'content' field")
375
- else:
376
- logger.warning(f"Unexpected conversation fields: {fields}")
377
- logger.warning("Expected only 'content' field")
378
-
379
- # Check a sample of conversation entries to validate structure
380
- logger.info("Validating conversation structure...")
381
- for i in range(min(5, len(dataset))):
382
- conv = dataset[i].get("conversations")
383
- if conv is None:
384
- logger.warning(f"Example {i} has None as 'conversations' value")
385
- elif not isinstance(conv, list):
386
- logger.warning(f"Example {i} has non-list 'conversations': {type(conv)}")
387
- elif len(conv) == 0:
388
- logger.warning(f"Example {i} has empty conversations list")
389
  else:
390
- # Look at the first conversation entry
391
- first_entry = conv[0]
392
- if isinstance(first_entry, dict) and "content" in first_entry:
393
- logger.info(f"Content field example: {str(first_entry['content'])[:50]}...")
394
- else:
395
- logger.warning(f"Example {i} missing 'content' key in conversation")
396
-
397
- except Exception as dataset_error:
398
- logger.error(f"Failed to load dataset {dataset_name}: {str(dataset_error)}")
399
- logger.error("Make sure the dataset exists and you have proper access permissions")
400
- logger.error("This could be due to authentication issues with your HF_TOKEN")
401
- raise
402
 
403
  return dataset
404
 
405
  except Exception as e:
406
- logger.error(f"Error loading dataset: {str(e)}")
407
- return 1
408
 
409
  def format_phi_chat(messages, dataset_config):
410
  """Format messages according to phi-4's chat template and dataset config.
@@ -528,110 +592,292 @@ class SimpleDataCollator:
528
  # Return empty batch if no valid examples
529
  return {k: [] for k in batch}
530
 
531
  class LoggingCallback(TrainerCallback):
532
  def __init__(self, model=None, dataset=None):
533
  super().__init__()
534
  self.training_started = time.time()
535
  self.last_log_time = time.time()
536
- self.last_step = 0
537
  self.model = model
538
  self.dataset = dataset
539
 
540
  def on_train_begin(self, args, state, control, **kwargs):
541
- log_info(f"=== Training started at {time.strftime('%Y-%m-%d %H:%M:%S')} ===")
542
-
543
- # Log model info if available
544
- if self.model is not None:
545
- log_info(f"Model parameters: {sum(p.numel() for p in self.model.parameters())/1e6:.2f}M")
546
 
547
- # Log dataset info if available
548
- if self.dataset is not None:
549
- log_info(f"Dataset size: {len(self.dataset)} examples")
550
 
551
- # Log important training parameters for visibility
552
- total_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * NUM_GPUS
553
- total_steps = int(len(self.dataset or []) / (args.per_device_train_batch_size * NUM_GPUS * args.gradient_accumulation_steps) * args.num_train_epochs)
554
- log_info(f"Training plan: {len(self.dataset or [])} examples over {args.num_train_epochs} epochs ≈ {total_steps} steps")
555
- log_info(f"Batch size: {args.per_device_train_batch_size} × {args.gradient_accumulation_steps} steps × {NUM_GPUS} GPUs = {total_batch_size} total")
556
 
557
- # Log memory information in compact format
558
- if CUDA_AVAILABLE:
559
- memory_info = []
560
- for i in range(NUM_GPUS):
561
- allocated = torch.cuda.memory_allocated(i) / 1024**2
562
- max_mem = torch.cuda.max_memory_allocated(i) / 1024**2
563
- memory_info.append(f"GPU {i}: {allocated:.1f}MB (max: {max_mem:.1f}MB)")
564
- log_info(f"Initial memory usage - {', '.join(memory_info)}")
565
 
566
  def check_dependencies():
567
- """Check if all required dependencies are installed and in the correct order."""
568
- missing_packages = []
569
- order_issues = []
570
-
571
- # Define required packages with versions
572
  required_packages = {
573
- "unsloth": ">=2024.3",
574
- "transformers": ">=4.38.0",
575
- "peft": ">=0.9.0",
576
- "accelerate": ">=0.27.0"
577
  }
578
 
579
- # Check for required packages
580
- for package, version in required_packages.items():
581
  try:
582
  if package == "unsloth" and not unsloth_available:
583
- missing_packages.append(f"{package}{version}")
584
  elif package == "peft" and not peft_available:
585
- missing_packages.append(f"{package}{version}")
586
- else:
587
- module = __import__(package)
588
- logger.info(f"Using {package} version {getattr(module, '__version__', 'unknown')}")
589
  except ImportError:
590
- missing_packages.append(f"{package}{version}")
591
 
592
- # Check import order
593
- try:
594
- import sys
595
- modules = list(sys.modules.keys())
596
-
597
- if 'transformers' in modules and 'unsloth' in modules:
598
- try:
599
- transformers_idx = modules.index('transformers')
600
- unsloth_idx = modules.index('unsloth')
601
- if transformers_idx < unsloth_idx:
602
- order_issues.append("For optimal performance, unsloth should be imported before transformers")
603
- except ValueError:
604
- pass
605
- except Exception as e:
606
- logger.warning(f"Could not check module import order: {str(e)}")
607
-
608
- # Check optional dependencies
609
- optional_packages = {
610
- "flash_attn": "Flash attention support",
611
- "bitsandbytes": "4-bit quantization support"
612
- }
613
 
614
- for package, feature in optional_packages.items():
615
- if find_spec(package):
616
- logger.info(f"Found {package} - {feature} enabled")
617
- else:
618
- logger.warning(f"{package} not found - {feature} will not be available")
619
 
620
  # Report missing required packages
621
  if missing_packages:
622
- logger.error("Critical dependencies missing:")
623
  for pkg in missing_packages:
624
- logger.error(f" - {pkg}")
625
- logger.error("Please install the missing dependencies with:")
626
- logger.error(f" pip install {' '.join(missing_packages)}")
627
  return False
628
 
629
- # Report order issues as warnings
630
- for issue in order_issues:
631
- logger.warning(issue)
632
-
633
  return True
634
 
635
  def update_huggingface_space():
636
  """Update the Hugging Face Space with the current code."""
637
  log_info("Updating Hugging Face Space...")
@@ -709,381 +955,338 @@ def validate_huggingface_credentials():
709
  logger.warning(f"Error validating Hugging Face credentials: {str(e)}")
710
  return False
711
 
712
  def main():
713
  # Set up logging
714
  logger.info("Starting training process")
715
 
716
  try:
717
- # Check dependencies first, before any other operations
718
- if not check_dependencies():
719
- logger.error("Aborting due to missing critical dependencies")
720
- return 1
721
-
722
- # Parse arguments
723
  args = parse_args()
724
 
725
- # Load environment variables
726
- load_env_variables()
727
 
728
- # Validate Hugging Face credentials if we're going to use them
729
- validate_huggingface_credentials()
730
-
731
- # Load configuration
732
  try:
733
- transformers_config = load_configs(args.config)
734
- hardware_config = transformers_config.get("hardware", {})
735
- dataset_config = transformers_config.get("dataset", {})
736
- logger.info("Configuration loaded successfully")
737
  except Exception as e:
738
- logger.error(f"Error loading configuration: {e}")
739
  return 1
740
 
741
- # Check if we're in distributed mode
742
- is_distributed = "WORLD_SIZE" in os.environ and int(os.environ.get("WORLD_SIZE", "1")) > 1
743
- if is_distributed:
744
- local_rank = int(os.environ.get("LOCAL_RANK", "0"))
745
- log_info(f"Running in distributed mode with {os.environ.get('WORLD_SIZE')} processes, local_rank: {local_rank}")
746
- else:
747
- log_info("Running in non-distributed mode (single process)")
748
-
749
- # Set random seed for reproducibility
750
- seed = transformers_config.get("seed", 42)
751
- set_seed(seed)
752
- logger.info(f"Set random seed to {seed}")
753
-
754
- # Load model and tokenizer using the consolidated config
755
- model, tokenizer = load_model_and_tokenizer(transformers_config)
756
-
757
- # Empty CUDA cache to ensure clean state
758
- if CUDA_AVAILABLE:
759
- torch.cuda.empty_cache()
760
- log_info("Cleared CUDA cache")
761
 
762
- # Setup environment variable for CUDA memory allocation
763
- if CUDA_AVAILABLE:
764
- system_settings = hardware_config.get("system_settings", {})
765
- cuda_memory_fraction = system_settings.get("cuda_memory_fraction", 0.85)
766
 
767
- if cuda_memory_fraction < 1.0:
768
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = f"max_split_size_mb:128,expandable_segments:True"
769
- log_info(f"Set CUDA memory allocation limit to expandable with max_split_size_mb:128")
770
 
771
  try:
772
- log_info("Loading dataset...")
773
- dataset = load_dataset_with_mapping(dataset_config)
774
-
775
- # Extra validation to catch None/empty dataset issues
776
- if dataset is None:
777
- logger.error("Dataset is None! Cannot proceed with training.")
778
- return 1
779
-
780
- if not hasattr(dataset, '__len__') or len(dataset) == 0:
781
- logger.error("Dataset is empty! Cannot proceed with training.")
782
- return 1
783
-
784
- log_info(f"Dataset loaded with {len(dataset)} examples")
785
-
786
- # Minimal validation before proceeding
787
- if dataset is None or len(dataset) == 0:
788
- logger.error("Dataset is empty or None! Cannot proceed with training.")
789
- return 1
790
-
791
- # Create data collator
792
- data_collator = SimpleDataCollator(tokenizer, dataset_config)
793
-
794
- # Verify precision settings - ensure only one of bf16/fp16 is set, with bf16 taking precedence
795
- # First check hardware config, then transformers config
796
- use_bf16 = False
797
- use_fp16 = False
798
-
799
- # Check hardware config first
800
- hardware_precision = hardware_config.get("training_optimizations", {}).get("mixed_precision", "")
801
- if hardware_precision.lower() == "bf16":
802
- use_bf16 = True
803
- log_info("Using BF16 precision from hardware config")
804
- elif hardware_precision.lower() == "fp16":
805
- use_fp16 = True
806
- log_info("Using FP16 precision from hardware config")
807
- else:
808
- # Fall back to transformers config
809
- use_bf16 = transformers_config.get("bf16", False) or transformers_config.get("torch_dtype", "") == "bfloat16"
810
- use_fp16 = transformers_config.get("fp16", False) and not use_bf16 # Only use fp16 if bf16 is not set
811
- log_info(f"Using precision: {'bf16' if use_bf16 else 'fp16' if use_fp16 else 'full precision'}")
812
-
813
- # Get per device batch size - from transformers config, but possibly overridden by hardware config
814
- per_device_batch_size = transformers_config.get("training", {}).get("per_device_train_batch_size", 16)
815
- gradient_accumulation_steps = transformers_config.get("training", {}).get("gradient_accumulation_steps", 3)
816
-
817
- # Get multi-GPU strategy from hardware config (default to data_parallel)
818
- multi_gpu_strategy = hardware_config.get("training_optimizations", {}).get("multi_gpu_strategy", "data_parallel")
819
- logger.info(f"Multi-GPU strategy: {multi_gpu_strategy}")
820
-
821
- # For multi-GPU setup, adjust for better balance
822
- if CUDA_AVAILABLE and NUM_GPUS > 1:
823
- log_info(f"Multi-GPU setup: Adjusting for {NUM_GPUS} GPUs")
824
-
825
- # Set up FSDP for multi-GPU training if specified and in distributed mode
826
- fsdp_config = None
827
- if multi_gpu_strategy == "fsdp" and is_distributed and NUM_GPUS > 1:
828
- try:
829
- from torch.distributed.fsdp import (
830
- FullyShardedDataParallel as FSDP,
831
- MixedPrecision,
832
- BackwardPrefetch,
833
- ShardingStrategy,
834
- CPUOffload,
835
- )
836
- from torch.distributed.fsdp.wrap import (
837
- transformer_auto_wrap_policy,
838
- enable_wrap,
839
- wrap,
840
- )
841
-
842
- log_info("Using FSDP for distributed training")
843
-
844
- # Configure FSDP
845
- fsdp_config = {
846
- "fsdp_transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"],
847
- "fsdp_offload_params": False,
848
- "fsdp_backward_prefetch": "BACKWARD_PRE",
849
- "fsdp_min_num_params": 1e6,
850
- "fsdp_sharding_strategy": 1, # FULL_SHARD
851
- }
852
-
853
- if use_bf16 or use_fp16:
854
- precision_type = "bf16" if use_bf16 else "fp16"
855
- fsdp_config["fsdp_state_dict_type"] = "FULL_STATE_DICT"
856
- log_info(f"FSDP using mixed precision: {precision_type}")
857
- except ImportError:
858
- log_info("FSDP imports failed, falling back to standard DDP")
859
- fsdp_config = None
860
- elif multi_gpu_strategy == "fsdp" and not is_distributed:
861
- log_info("FSDP disabled: requires distributed environment (use torchrun or accelerate)")
862
- log_info("Using DataParallel for multi-GPU training instead")
863
- else:
864
- log_info(f"Using {multi_gpu_strategy} for multi-GPU training")
865
-
866
- # Get system settings from hardware config
867
- dataloader_workers = hardware_config.get("system_settings", {}).get("dataloader_num_workers", 2)
868
- pin_memory = hardware_config.get("system_settings", {}).get("dataloader_pin_memory", True)
869
-
870
- # Set up training arguments
871
- log_info("Setting up training arguments")
872
-
873
- # Handle FSDP configuration
874
- fsdp_config = transformers_config.get("distributed_training", {}).get("fsdp_config", {})
875
- fsdp_enabled = fsdp_config.get("enabled", False)
876
-
877
- # Only set FSDP args if explicitly enabled
878
- fsdp_args = None
879
- if fsdp_enabled and is_distributed and NUM_GPUS > 1:
880
- fsdp_args = {
881
- "fsdp": ["full_shard", "auto_wrap"],
882
- "fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
883
- "fsdp_offload_params": fsdp_config.get("offload_params", False),
884
- "fsdp_state_dict_type": "FULL_STATE_DICT",
885
- "fsdp_sharding_strategy": 1, # FULL_SHARD
886
- }
887
- log_info("FSDP configuration enabled")
888
- else:
889
- log_info("FSDP disabled, using standard data parallel")
890
-
891
- # Check if we're running in a Space
892
- is_space = bool(os.environ.get("SPACE_ID"))
893
-
894
- # Create training arguments
895
- training_args = TrainingArguments(
896
- output_dir=transformers_config.get("output_dir", "./results") or transformers_config.get("checkpointing", {}).get("output_dir", "./results"),
897
- num_train_epochs=transformers_config.get("training", {}).get("num_train_epochs", 3),
898
- per_device_train_batch_size=per_device_batch_size,
899
- gradient_accumulation_steps=gradient_accumulation_steps,
900
- learning_rate=transformers_config.get("training", {}).get("learning_rate", 2e-5),
901
- weight_decay=transformers_config.get("training", {}).get("weight_decay", 0.01),
902
- warmup_ratio=transformers_config.get("training", {}).get("warmup_ratio", 0.05),
903
- lr_scheduler_type=transformers_config.get("training", {}).get("lr_scheduler_type", "cosine"),
904
- logging_steps=transformers_config.get("training", {}).get("logging_steps", 10),
905
- save_strategy=transformers_config.get("checkpointing", {}).get("save_strategy", "steps"),
906
- save_steps=transformers_config.get("checkpointing", {}).get("save_steps", 100),
907
- save_total_limit=transformers_config.get("checkpointing", {}).get("save_total_limit", 3),
908
- fp16=use_fp16,
909
- bf16=use_bf16,
910
- max_grad_norm=transformers_config.get("training", {}).get("max_grad_norm", 1.0),
911
- push_to_hub=transformers_config.get("huggingface_hub", {}).get("push_to_hub", False),
912
- hub_model_id=transformers_config.get("huggingface_hub", {}).get("hub_model_id", None),
913
- hub_token=None if is_space else os.environ.get("HF_TOKEN", None),
914
- report_to="tensorboard",
915
- remove_unused_columns=False, # Keep all columns
916
- gradient_checkpointing=transformers_config.get("training", {}).get("gradient_checkpointing", True),
917
- dataloader_pin_memory=pin_memory,
918
- optim=transformers_config.get("training", {}).get("optim", "adamw_torch"),
919
- ddp_find_unused_parameters=False, # Improve distributed training efficiency
920
- dataloader_drop_last=False, # Process all examples
921
- dataloader_num_workers=dataloader_workers,
922
- no_cuda=False if CUDA_AVAILABLE else True, # Use CUDA if available
923
- **({} if fsdp_args is None else fsdp_args) # Only include FSDP args if configured
924
- )
925
-
926
- log_info("Training arguments created successfully")
927
-
928
- # Validate dataset before creating sampler
929
- if dataset is None:
930
- raise ValueError("Dataset is None - cannot create sampler")
931
-
932
- # Create sequential sampler to maintain original dataset order
933
- sequential_sampler = torch.utils.data.SequentialSampler(dataset)
934
- log_info("Sequential sampler created")
935
-
936
- # Initialize trainer first
937
- log_info("Initializing Trainer")
938
- trainer = Trainer(
939
- model=model,
940
- args=training_args,
941
- train_dataset=dataset,
942
- data_collator=data_collator,
943
- callbacks=[LoggingCallback(model=model, dataset=dataset)],
944
- )
945
-
946
- # Then override the get_train_dataloader method
947
- def custom_get_train_dataloader():
948
- """Custom dataloader that preserves original dataset order"""
949
- log_info("Creating sequential dataloader to maintain original dataset order")
950
-
951
- # Safety check - make sure dataset exists and is not None
952
- if dataset is None:
953
- raise ValueError("Dataset is None - cannot create dataloader")
954
-
955
- # Make sure dataset is not empty
956
- if len(dataset) == 0:
957
- raise ValueError("Dataset is empty - cannot create dataloader")
958
-
959
- # Create a simple sequential sampler
960
- sequential_sampler = torch.utils.data.SequentialSampler(dataset)
961
-
962
- # Verification of sequence preservation flags - simplified
963
- data_loading_config = dataset_config.get("data_loading", {})
964
- shuffle_enabled = data_loading_config.get("shuffle", False)
965
-
966
- if shuffle_enabled:
967
- log_info("WARNING: Shuffle is enabled in configuration! This will be overridden to preserve order.")
968
- # We enforce sequential processing regardless of config
969
-
970
- # Log our approach clearly
971
- log_info("Using SequentialSampler to guarantee dataset order is preserved based on prompt_number")
972
-
973
- # Verify column order and check for 'conversations' field
974
- expected_order = ["prompt_number", "article_id", "conversations"]
975
- if hasattr(dataset, 'column_names'):
976
- actual_order = dataset.column_names
977
-
978
- # Verify all required fields exist
979
- missing_fields = [field for field in ["conversations"] if field not in actual_order]
980
- if missing_fields:
981
- raise ValueError(f"Dataset missing critical fields: {missing_fields}")
982
-
983
- if actual_order == expected_order:
984
- log_info(f"Confirmed dataset columns are in expected order: {', '.join(expected_order)}")
985
- else:
986
- log_info(f"Note: Dataset columns ({', '.join(actual_order)}) are not in expected order ({', '.join(expected_order)})")
987
- log_info("This is handled correctly by field-based access, but noting for clarity")
988
-
989
- log_info("Dataset is pre-processed with prompt_number field indicating the correct sequence")
990
-
991
- # Validate a few samples before proceeding
992
- for i in range(min(3, len(dataset))):
993
- sample = dataset[i]
994
- if "conversations" not in sample:
995
- log_info(f"WARNING: Sample {i} missing 'conversations' field")
996
- elif sample["conversations"] is None:
997
- log_info(f"WARNING: Sample {i} has None 'conversations' field")
998
- elif not isinstance(sample["conversations"], list):
999
- log_info(f"WARNING: Sample {i} has non-list 'conversations' field: {type(sample['conversations'])}")
1000
-
1001
- # Calculate batch size based on device availability
1002
- if getattr(training_args, "no_cuda", False):
1003
- batch_size = training_args.per_device_train_batch_size
1004
- else:
1005
- batch_size = max(training_args.per_device_train_batch_size * max(1, NUM_GPUS), 1)
1006
-
1007
- log_info(f"Using sequential sampler with batch size {batch_size}")
1008
-
1009
- # Return DataLoader with sequential sampler and extra error handling
1010
- try:
1011
- return torch.utils.data.DataLoader(
1012
- dataset,
1013
- batch_size=batch_size,
1014
- sampler=sequential_sampler, # Always use sequential sampler
1015
- collate_fn=data_collator,
1016
- drop_last=training_args.dataloader_drop_last,
1017
- num_workers=training_args.dataloader_num_workers,
1018
- pin_memory=training_args.dataloader_pin_memory,
1019
- )
1020
- except Exception as e:
1021
- log_info(f"Error creating DataLoader: {str(e)}")
1022
- # Try again with minimal settings
1023
- log_info("Attempting to create DataLoader with minimal settings")
1024
- return torch.utils.data.DataLoader(
1025
- dataset,
1026
- batch_size=1, # Minimal batch size
1027
- sampler=sequential_sampler,
1028
- collate_fn=data_collator,
1029
- num_workers=0, # No parallel workers
1030
- pin_memory=False,
1031
- )
1032
-
1033
- # Override the get_train_dataloader method
1034
- trainer.get_train_dataloader = custom_get_train_dataloader
1035
-
1036
- # Start training
1037
- log_info("=== Starting Training ===")
1038
- try:
1039
- # Empty cache again right before training
1040
- if CUDA_AVAILABLE:
1041
- torch.cuda.empty_cache()
1042
- log_info("Cleared CUDA cache before training")
1043
-
1044
- # Display compact training info
1045
- total_steps = int((len(dataset) / (per_device_batch_size * NUM_GPUS * gradient_accumulation_steps)) * training_args.num_train_epochs)
1046
- log_info(f"Training plan: {len(dataset)} examples over {training_args.num_train_epochs} epochs ≈ {total_steps} steps")
1047
-
1048
- trainer.train()
1049
- log_info("Training completed successfully!")
1050
-
1051
- # Save the final model
1052
- log_info("Saving final model...")
1053
- trainer.save_model()
1054
- log_info(f"Model saved to {training_args.output_dir}")
1055
-
1056
- # Push to hub if enabled
1057
- if transformers_config.get("huggingface_hub", {}).get("push_to_hub", False):
1058
- hub_id = transformers_config.get("huggingface_hub", {}).get("hub_model_id", "model")
1059
- log_info(f"Pushing model to Hugging Face Hub as {hub_id}...")
1060
- trainer.push_to_hub()
1061
- log_info("Model successfully pushed to Hub")
1062
-
1063
- # Update the Hugging Face Space with current code
1064
- if os.environ.get("HF_TOKEN") and os.environ.get("HF_USERNAME") and os.environ.get("HF_SPACE_NAME"):
1065
- update_huggingface_space()
1066
-
1067
- return 0
1068
- except Exception as e:
1069
- logger.error(f"Training failed with error: {str(e)}")
1070
- # Log CUDA memory info if available in compact format
1071
- if CUDA_AVAILABLE:
1072
- memory_info = []
1073
- for i in range(NUM_GPUS):
1074
- allocated = torch.cuda.memory_allocated(i) / 1024**2
1075
- reserved = torch.cuda.memory_reserved(i) / 1024**2
1076
- max_mem = torch.cuda.max_memory_allocated(i) / 1024**2
1077
- memory_info.append(f"GPU {i}: {allocated:.1f}MB/{reserved:.1f}MB (max: {max_mem:.1f}MB)")
1078
- logger.error(f"GPU memory at failure: {', '.join(memory_info)}")
1079
- raise
1080
 
1081
  except Exception as e:
1082
- logger.error(f"Error in main training loop: {str(e)}")
1083
  return 1
1084
 
1085
  except Exception as e:
1086
  logger.error(f"Error in main function: {str(e)}")
1087
  return 1
1088
 
1089
  if __name__ == "__main__":
 
184
  raise
185
 
186
  def parse_args():
187
+ """
188
+ Parse command line arguments for the training script.
189
+
190
+ Returns:
191
+ argparse.Namespace: The parsed command line arguments
192
+ """
193
+ parser = argparse.ArgumentParser(description="Run training for language models")
194
+ parser.add_argument(
195
+ "--config_file",
196
+ type=str,
197
+ default=None,
198
+ help="Path to the configuration file (default: transformers_config.json in script directory)"
199
+ )
200
+ parser.add_argument(
201
+ "--seed",
202
+ type=int,
203
+ default=None,
204
+ help="Random seed for reproducibility (default: based on current time)"
205
+ )
206
+ parser.add_argument(
207
+ "--log_level",
208
+ type=str,
209
+ choices=["debug", "info", "warning", "error", "critical"],
210
+ default="info",
211
+ help="Logging level (default: info)"
212
+ )
213
  return parser.parse_args()
214
 
215
  def load_model_and_tokenizer(config):
216
+ """
217
+ Load the model and tokenizer according to the configuration.
218
+
219
+ Args:
220
+ config (dict): Complete configuration dictionary
221
+
222
+ Returns:
223
+ tuple: (model, tokenizer) - The loaded model and tokenizer
224
+ """
225
+ # Extract model configuration
226
+ model_config = get_config_value(config, "model", {})
227
+ model_name = get_config_value(model_config, "name", "unsloth/phi-4-unsloth-bnb-4bit")
228
+ use_fast_tokenizer = get_config_value(model_config, "use_fast_tokenizer", True)
229
+ trust_remote_code = get_config_value(model_config, "trust_remote_code", True)
230
+ model_revision = get_config_value(config, "model_revision", "main")
231
+
232
+ # Unsloth configuration
233
+ unsloth_config = get_config_value(config, "unsloth", {})
234
+ unsloth_enabled = get_config_value(unsloth_config, "enabled", True)
235
+
236
+ # Tokenizer configuration
237
+ tokenizer_config = get_config_value(config, "tokenizer", {})
238
+ max_seq_length = min(
239
+ get_config_value(tokenizer_config, "max_seq_length", 2048),
240
+ 4096 # Maximum supported by most models
241
+ )
242
+ add_eos_token = get_config_value(tokenizer_config, "add_eos_token", True)
243
+ chat_template = get_config_value(tokenizer_config, "chat_template", None)
244
+ padding_side = get_config_value(tokenizer_config, "padding_side", "right")
245
+
246
+ log_info(f"Loading model: {model_name} (revision: {model_revision})")
247
+ log_info(f"Max sequence length: {max_seq_length}")
248
+
249
  try:
250
+ if unsloth_enabled and unsloth_available:
251
+ log_info("Using Unsloth for 4-bit quantized model and LoRA")
252
+ # Load using Unsloth
253
+ from unsloth import FastLanguageModel
254
+ model, tokenizer = FastLanguageModel.from_pretrained(
255
+ model_name=model_name,
256
+ max_seq_length=max_seq_length,
257
+ dtype=get_config_value(config, "torch_dtype", "bfloat16"),
258
+ revision=model_revision,
259
+ trust_remote_code=trust_remote_code,
260
+ use_flash_attention_2=get_config_value(config, "use_flash_attention", True)
261
+ )
262
 
263
+ # Configure tokenizer settings
264
+ tokenizer.padding_side = padding_side
265
+ if add_eos_token and tokenizer.eos_token is None:
266
+ log_info("Setting EOS token")
267
+ tokenizer.add_special_tokens({"eos_token": "</s>"})
268
+
269
+ # Set chat template if specified
270
+ if chat_template:
271
+ log_info(f"Setting chat template: {chat_template}")
272
+ if hasattr(tokenizer, "chat_template"):
273
+ tokenizer.chat_template = chat_template
274
+ else:
275
+ log_info("Tokenizer does not support chat templates, using default formatting")
276
+
277
+ # Apply LoRA
278
+ lora_r = get_config_value(unsloth_config, "r", 16)
279
+ lora_alpha = get_config_value(unsloth_config, "alpha", 32)
280
+ lora_dropout = get_config_value(unsloth_config, "dropout", 0)
281
+ target_modules = get_config_value(unsloth_config, "target_modules",
282
+ ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"])
283
+
284
+ log_info(f"Applying LoRA with r={lora_r}, alpha={lora_alpha}, dropout={lora_dropout}")
285
+ model = FastLanguageModel.get_peft_model(
286
+ model,
287
+ r=lora_r,
288
+ target_modules=target_modules,
289
+ lora_alpha=lora_alpha,
290
+ lora_dropout=lora_dropout,
291
+ bias="none",
292
+ use_gradient_checkpointing=get_config_value(config, "training.gradient_checkpointing", True),
293
+ random_state=0,
294
+ max_seq_length=max_seq_length,
295
+ modules_to_save=None
296
+ )
297
  else:
298
+ # Standard HuggingFace loading
299
+ log_info("Using standard HuggingFace model loading (Unsloth not available or disabled)")
300
+ from transformers import AutoModelForCausalLM, AutoTokenizer
301
 
302
+ # Load tokenizer first
303
+ tokenizer = AutoTokenizer.from_pretrained(
304
+ model_name,
305
+ trust_remote_code=trust_remote_code,
306
+ use_fast=use_fast_tokenizer,
307
+ revision=model_revision,
308
+ padding_side=padding_side
309
  )
310
 
311
+ # Configure tokenizer settings
312
+ if add_eos_token and tokenizer.eos_token is None:
313
+ log_info("Setting EOS token")
314
+ tokenizer.add_special_tokens({"eos_token": "</s>"})
315
+
316
+ # Set chat template if specified
317
+ if chat_template:
318
+ log_info(f"Setting chat template: {chat_template}")
319
+ if hasattr(tokenizer, "chat_template"):
320
+ tokenizer.chat_template = chat_template
321
  else:
322
+ log_info("Tokenizer does not support chat templates, using default formatting")
323
+
324
+ # Now load model with updated tokenizer
325
+ model = AutoModelForCausalLM.from_pretrained(
326
+ model_name,
327
+ trust_remote_code=trust_remote_code,
328
+ revision=model_revision,
329
+ torch_dtype=torch.bfloat16 if get_config_value(config, "torch_dtype", "bfloat16") == "bfloat16" else torch.float16,
330
+ device_map="auto" if CUDA_AVAILABLE else None
331
+ )
332
+
333
+ # Apply PEFT/LoRA if enabled but using standard loading
334
+ if peft_available and get_config_value(unsloth_config, "enabled", True):
335
+ log_info("Applying standard PEFT/LoRA configuration")
336
+ from peft import LoraConfig, get_peft_model
337
+
338
+ lora_r = get_config_value(unsloth_config, "r", 16)
339
+ lora_alpha = get_config_value(unsloth_config, "alpha", 32)
340
+ lora_dropout = get_config_value(unsloth_config, "dropout", 0)
341
+ target_modules = get_config_value(unsloth_config, "target_modules",
342
+ ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"])
343
+
344
+ log_info(f"Applying LoRA with r={lora_r}, alpha={lora_alpha}, dropout={lora_dropout}")
345
+ lora_config = LoraConfig(
346
+ r=lora_r,
347
+ lora_alpha=lora_alpha,
348
+ target_modules=target_modules,
349
+ lora_dropout=lora_dropout,
350
+ bias="none",
351
+ task_type="CAUSAL_LM"
352
+ )
353
+ model = get_peft_model(model, lora_config)
354
 
355
+ # Print model summary
356
+ log_info(f"Model loaded successfully: {model.__class__.__name__}")
357
+ if hasattr(model, "print_trainable_parameters"):
358
+ model.print_trainable_parameters()
359
+ else:
360
+ total_params = sum(p.numel() for p in model.parameters())
361
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
362
+ log_info(f"Model has {total_params:,} parameters, {trainable_params:,} trainable ({trainable_params/total_params:.2%})")
363
 
364
  return model, tokenizer
365
+
366
  except Exception as e:
367
+ log_info(f"Error loading model: {str(e)}")
368
+ traceback.print_exc()
369
+ return None, None
370
 
371
+ def load_dataset_with_mapping(config):
372
+ """
373
+ Load dataset from Hugging Face or local files and apply necessary transformations.
374
+
375
+ Args:
376
+ config (dict): Dataset configuration dictionary
377
+
378
+ Returns:
379
+ Dataset: The loaded and processed dataset
380
+ """
381
+ # Extract dataset configuration
382
+ dataset_info = get_config_value(config, "dataset", {})
383
+ dataset_name = get_config_value(dataset_info, "name", None)
384
+ dataset_split = get_config_value(dataset_info, "split", "train")
385
+
386
+ # Data formatting configuration
387
+ formatting_config = get_config_value(config, "data_formatting", {})
388
+
389
+ if not dataset_name:
390
+ raise ValueError("Dataset name not specified in config")
391
+
392
+ log_info(f"Loading dataset: {dataset_name} (split: {dataset_split})")
393
+
394
  try:
395
+ # Load dataset from Hugging Face or local path
396
+ from datasets import load_dataset
397
 
398
+ # Check if it's a local path or Hugging Face dataset
399
+ if os.path.exists(dataset_name) or os.path.exists(os.path.join(os.getcwd(), dataset_name)):
400
+ log_info(f"Loading dataset from local path: {dataset_name}")
401
+ # Local dataset - check if it's a directory or file
402
+ if os.path.isdir(dataset_name):
403
+ # Directory - look for data files
404
+ dataset = load_dataset(
405
+ "json",
406
+ data_files={"train": os.path.join(dataset_name, "*.json")},
407
+ split=dataset_split
408
+ )
409
+ else:
410
+ # Single file
411
+ dataset = load_dataset(
412
+ "json",
413
+ data_files={"train": dataset_name},
414
+ split=dataset_split
415
+ )
416
+ else:
417
+ # Hugging Face dataset
418
+ log_info(f"Loading dataset from Hugging Face: {dataset_name}")
419
+ dataset = load_dataset(dataset_name, split=dataset_split)
420
 
421
+ log_info(f"Dataset loaded with {len(dataset)} examples")
422
 
423
+ # Check if dataset contains required fields
424
+ required_fields = ["conversations"]
425
+ missing_fields = [field for field in required_fields if field not in dataset.column_names]
426
+
427
+ if missing_fields:
428
+ log_info(f"WARNING: Dataset missing required fields: {missing_fields}")
429
+ log_info("Attempting to map dataset structure to required format")
430
 
431
+ # Implement conversion logic based on dataset structure
432
+ if "messages" in dataset.column_names:
433
+ log_info("Converting 'messages' field to 'conversations' format")
434
+ dataset = dataset.map(
435
+ lambda x: {"conversations": x["messages"]},
436
+ remove_columns=["messages"]
437
+ )
438
+ elif "text" in dataset.column_names:
439
+ log_info("Converting plain text to conversations format")
440
+ dataset = dataset.map(
441
+ lambda x: {"conversations": [{"role": "user", "content": x["text"]}]},
442
+ remove_columns=["text"]
443
+ )
444
+ else:
445
+ raise ValueError(f"Cannot convert dataset format - missing required fields and no conversion path available")
446
+
447
+ # Log dataset info
448
+ log_info(f"Dataset has {len(dataset)} examples and columns: {dataset.column_names}")
449
+
450
+ # Show a few examples for verification
451
+ for i in range(min(3, len(dataset))):
452
+ example = dataset[i]
453
+ log_info(f"Example {i}:")
454
+ for key, value in example.items():
455
+ if key == "conversations":
456
+ log_info(f" conversations: {len(value)} messages")
457
+ # Show first message only to avoid cluttering logs
458
+ if value and len(value) > 0:
459
+ first_msg = value[0]
460
+ if isinstance(first_msg, dict) and "content" in first_msg:
461
+ content = first_msg["content"]
462
+ log_info(f" First message: {content[:50]}..." if len(content) > 50 else f" First message: {content}")
463
  else:
464
+ log_info(f" {key}: {value}")
465
 
466
  return dataset
467
 
468
  except Exception as e:
469
+ log_info(f"Error loading dataset: {str(e)}")
470
+ traceback.print_exc()
471
+ return None
472
 
473
  def format_phi_chat(messages, dataset_config):
474
  """Format messages according to phi-4's chat template and dataset config.
 
592
  # Return empty batch if no valid examples
593
  return {k: [] for k in batch}
594
 
595
+ def log_gpu_memory_usage(step=None, frequency=50, clear_cache_threshold=0.9, label=None):
596
+ """
597
+ Log GPU memory usage statistics with optional cache clearing
598
+
599
+ Args:
600
+ step: Current training step (if None, logs regardless of frequency)
601
+ frequency: How often to log when step is provided
602
+ clear_cache_threshold: Fraction of memory used that triggers cache clearing (0-1)
603
+ label: Optional label for the log message (e.g., "Initial", "Error", "Step")
604
+ """
605
+ if not CUDA_AVAILABLE:
606
+ return
607
+
608
+ # Only log every 'frequency' steps if step is provided
609
+ if step is not None and frequency > 0 and step % frequency != 0:
610
+ return
611
+
612
+ # Get memory usage for each GPU
613
+ memory_info = []
614
+ for i in range(NUM_GPUS):
615
+ allocated = torch.cuda.memory_allocated(i) / (1024 ** 2) # MB
616
+ reserved = torch.cuda.memory_reserved(i) / (1024 ** 2) # MB
617
+ max_mem = torch.cuda.max_memory_allocated(i) / (1024 ** 2) # MB
618
+
619
+ # Calculate percentage of reserved memory that's allocated
620
+ usage_percent = (allocated / reserved) * 100 if reserved > 0 else 0
621
+ memory_info.append(f"GPU {i}: {allocated:.1f}MB/{reserved:.1f}MB ({usage_percent:.1f}%, max: {max_mem:.1f}MB)")
622
+
623
+ # Automatically clear cache if over threshold
624
+ if clear_cache_threshold > 0 and reserved > 0 and (allocated / reserved) > clear_cache_threshold:
625
+ log_info(f"Clearing CUDA cache for GPU {i} - high utilization ({allocated:.1f}/{reserved:.1f}MB)")
626
+ with torch.cuda.device(i):
627
+ torch.cuda.empty_cache()
628
+
629
+ prefix = f"{label} " if label else ""
630
+ log_info(f"{prefix}GPU Memory: {', '.join(memory_info)}")
631
+
632
  class LoggingCallback(TrainerCallback):
633
  def __init__(self, model=None, dataset=None):
634
  super().__init__()
635
  self.training_started = time.time()
636
  self.last_log_time = time.time()
637
+ self.last_step_time = None
638
+ self.step_durations = []
639
+ self.best_loss = float('inf')
640
  self.model = model
641
  self.dataset = dataset
642
 
643
  def on_train_begin(self, args, state, control, **kwargs):
644
+ """Called at the beginning of training"""
645
+ try:
646
+ log_info(f"=== Training started at {time.strftime('%Y-%m-%d %H:%M:%S')} ===")
647
+
648
+ # Log model info if available
649
+ if self.model is not None:
650
+ total_params = sum(p.numel() for p in self.model.parameters())
651
+ trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
652
+ log_info(f"Model parameters: {total_params/1e6:.2f}M total, {trainable_params/1e6:.2f}M trainable")
653
+
654
+ # Log dataset info if available
655
+ if self.dataset is not None:
656
+ log_info(f"Dataset size: {len(self.dataset)} examples")
657
+
658
+ # Log important training parameters for visibility
659
+ total_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * NUM_GPUS
660
+ total_steps = int(len(self.dataset or []) / (args.per_device_train_batch_size * NUM_GPUS * args.gradient_accumulation_steps) * args.num_train_epochs)
661
+ log_info(f"Training plan: {len(self.dataset or [])} examples over {args.num_train_epochs} epochs ≈ {total_steps} steps")
662
+ log_info(f"Batch size: {args.per_device_train_batch_size} × {args.gradient_accumulation_steps} steps × {NUM_GPUS} GPUs = {total_batch_size} total")
663
+
664
+ # Log initial GPU memory usage with label
665
+ log_gpu_memory_usage(label="Initial")
666
+ except Exception as e:
667
+ logger.warning(f"Error logging training begin statistics: {str(e)}")
668
+
669
+ def on_step_end(self, args, state, control, **kwargs):
670
+ """Called at the end of each step"""
671
+ try:
672
+ if state.global_step == 1 or state.global_step % args.logging_steps == 0:
673
+ # Track step timing
674
+ current_time = time.time()
675
+ if self.last_step_time:
676
+ step_duration = current_time - self.last_step_time
677
+ self.step_durations.append(step_duration)
678
+ # Keep only last 100 steps for averaging
679
+ if len(self.step_durations) > 100:
680
+ self.step_durations.pop(0)
681
+ avg_step_time = sum(self.step_durations) / len(self.step_durations)
682
+ log_info(f"Step {state.global_step}: {step_duration:.2f}s (avg: {avg_step_time:.2f}s)")
683
+
684
+ self.last_step_time = current_time
685
+
686
+ # Log GPU memory usage with step number
687
+ log_gpu_memory_usage(state.global_step, args.logging_steps)
688
+
689
+ # Log loss
690
+ if state.log_history:
691
+ latest_logs = state.log_history[-1] if state.log_history else {}
692
+ if "loss" in latest_logs:
693
+ loss = latest_logs["loss"]
694
+ log_info(f"Step {state.global_step} loss: {loss:.4f}")
695
+
696
+ # Track best loss
697
+ if loss < self.best_loss:
698
+ self.best_loss = loss
699
+ log_info(f"New best loss: {loss:.4f}")
700
+ except Exception as e:
701
+ logger.warning(f"Error logging step end statistics: {str(e)}")
702
+
703
+ def on_train_end(self, args, state, control, **kwargs):
704
+ """Called at the end of training"""
705
+ try:
706
+ # Calculate training duration
707
+ training_time = time.time() - self.training_started
708
+ hours, remainder = divmod(training_time, 3600)
709
+ minutes, seconds = divmod(remainder, 60)
710
+
711
+ log_info(f"=== Training completed at {time.strftime('%Y-%m-%d %H:%M:%S')} ===")
712
+ log_info(f"Training duration: {int(hours)}h {int(minutes)}m {int(seconds)}s")
713
+ log_info(f"Final step: {state.global_step}")
714
+ log_info(f"Best loss: {self.best_loss:.4f}")
715
+
716
+ # Log final GPU memory usage
717
+ log_gpu_memory_usage(label="Final")
718
+ except Exception as e:
719
+ logger.warning(f"Error logging training end statistics: {str(e)}")
720
 
721
+ # Other callback methods with proper error handling
722
+ def on_save(self, args, state, control, **kwargs):
723
+ """Called when a checkpoint is saved"""
724
+ try:
725
+ log_info(f"Saving checkpoint at step {state.global_step}")
726
+ except Exception as e:
727
+ logger.warning(f"Error in on_save: {str(e)}")
728
+
729
+ def on_log(self, args, state, control, **kwargs):
730
+ """Called when a log is created"""
731
+ pass
732
 
733
+ def on_evaluate(self, args, state, control, **kwargs):
734
+ """Called when evaluation is performed"""
735
+ pass
736
 
737
+ # Only implement the methods we actually need, remove the others
738
+ def on_prediction_step(self, args, state, control, **kwargs):
739
+ """Called when prediction is performed"""
740
+ pass
741
+
742
+ def on_save_model(self, args, state, control, **kwargs):
743
+ """Called when model is saved"""
744
+ try:
745
+ # Log memory usage after saving
746
+ log_gpu_memory_usage(label=f"Save at step {state.global_step}")
747
+ except Exception as e:
748
+ logger.warning(f"Error in on_save_model: {str(e)}")
749
+
750
+ def on_epoch_end(self, args, state, control, **kwargs):
751
+ """Called at the end of an epoch"""
752
+ try:
753
+ epoch = state.epoch
754
+ log_info(f"Completed epoch {epoch:.2f}")
755
+ log_gpu_memory_usage(label=f"Epoch {epoch:.2f}")
756
+ except Exception as e:
757
+ logger.warning(f"Error in on_epoch_end: {str(e)}")
758
+
759
+ def on_step_begin(self, args, state, control, **kwargs):
760
+ """Called at the beginning of a step"""
761
+ pass
762
 
763
  def check_dependencies():
764
+ """
765
+ Check for required and optional dependencies, ensuring proper versions and import order.
766
+ Returns True if all required dependencies are present, False otherwise.
767
+ """
768
+ # Define required packages with versions and descriptions
769
  required_packages = {
770
+ "unsloth": {"version": ">=2024.3", "feature": "fast 4-bit quantization and LoRA"},
771
+ "transformers": {"version": ">=4.38.0", "feature": "core model functionality"},
772
+ "peft": {"version": ">=0.9.0", "feature": "parameter-efficient fine-tuning"},
773
+ "accelerate": {"version": ">=0.27.0", "feature": "multi-GPU training"}
774
  }
775
 
776
+ # Optional packages that enhance functionality
777
+ optional_packages = {
778
+ "flash_attn": {"feature": "faster attention computation"},
779
+ "bitsandbytes": {"feature": "quantization support"},
780
+ "optimum": {"feature": "model optimization"},
781
+ "wandb": {"feature": "experiment tracking"}
782
+ }
783
+
784
+ # Store results
785
+ missing_packages = []
786
+ package_versions = {}
787
+ order_issues = []
788
+
789
+ # Check required packages
790
+ log_info("Checking required dependencies...")
791
+ for package, info in required_packages.items():
792
+ version_req = info["version"]
793
+ feature = info["feature"]
794
+
795
  try:
796
+ # Special handling for packages we've already checked
797
  if package == "unsloth" and not unsloth_available:
798
+ missing_packages.append(f"{package}{version_req}")
799
+ log_info(f"❌ {package} - {feature} MISSING")
800
+ continue
801
  elif package == "peft" and not peft_available:
802
+ missing_packages.append(f"{package}{version_req}")
803
+ log_info(f"❌ {package} - {feature} MISSING")
804
+ continue
805
+
806
+ # Try to import and get version
807
+ module = __import__(package)
808
+ version = getattr(module, "__version__", "unknown")
809
+ package_versions[package] = version
810
+ log_info(f"✅ {package} v{version} - {feature}")
811
+
812
  except ImportError:
813
+ missing_packages.append(f"{package}{version_req}")
814
+ log_info(f"❌ {package} - {feature} MISSING")
815
 
816
+ # Check optional packages
817
+ log_info("\nChecking optional dependencies...")
818
+ for package, info in optional_packages.items():
819
+ feature = info["feature"]
820
+ try:
821
+ __import__(package)
822
+ log_info(f"✅ {package} - {feature} available")
823
+ except ImportError:
824
+ log_info(f"⚠️ {package} - {feature} not available")
825
 
826
+ # Check import order for optimal performance
827
+ if "transformers" in package_versions and "unsloth" in package_versions:
828
+ try:
829
+ import sys
830
+ modules = list(sys.modules.keys())
831
+ transformers_idx = modules.index("transformers")
832
+ unsloth_idx = modules.index("unsloth")
833
+
834
+ if transformers_idx < unsloth_idx:
835
+ order_issue = "⚠️ For optimal performance, import unsloth before transformers"
836
+ order_issues.append(order_issue)
837
+ log_info(order_issue)
838
+ else:
839
+ log_info("✅ Import order: unsloth before transformers (optimal)")
840
+ except (ValueError, IndexError) as e:
841
+ log_info(f"⚠️ Could not verify import order: {str(e)}")
842
 
843
  # Report missing required packages
844
  if missing_packages:
845
+ log_info("\n❌ Critical dependencies missing:")
846
  for pkg in missing_packages:
847
+ log_info(f" - {pkg}")
848
+ log_info("Please install missing dependencies with:")
849
+ log_info(f" pip install {' '.join(missing_packages)}")
850
  return False
851
 
852
+ log_info("\n✅ All required dependencies satisfied!")
853
  return True
854
 
855
+ def get_config_value(config, path, default=None):
856
+ """
857
+ Safely get a nested value from a config dictionary using a dot-separated path.
858
+
859
+ Args:
860
+ config: The configuration dictionary
861
+ path: Dot-separated path to the value (e.g., "training.optimizer.lr")
862
+ default: Default value to return if path doesn't exist
863
+
864
+ Returns:
865
+ The value at the specified path or the default value
866
+ """
867
+ if not config:
868
+ return default
869
+
870
+ parts = path.split('.')
871
+ current = config
872
+
873
+ for part in parts:
874
+ if isinstance(current, dict) and part in current:
875
+ current = current[part]
876
+ else:
877
+ return default
878
+
879
+ return current
880
+
881
  def update_huggingface_space():
882
  """Update the Hugging Face Space with the current code."""
883
  log_info("Updating Hugging Face Space...")
 
955
  logger.warning(f"Error validating Hugging Face credentials: {str(e)}")
956
  return False
957
 
958
+ def setup_environment(args):
959
+ """
960
+ Set up the training environment including logging, seed, and configurations.
961
+
962
+ Args:
963
+ args: Command line arguments
964
+
965
+ Returns:
966
+ tuple: (transformers_config, seed) - The loaded configuration and random seed
967
+ """
968
+ # Load environment variables first
969
+ load_env_variables()
970
+
971
+ # Set random seed for reproducibility
972
+ seed = args.seed if args.seed is not None else int(time.time()) % 10000
973
+ set_seed(seed)
974
+ log_info(f"Using random seed: {seed}")
975
+
976
+ # Load configuration
977
+ base_path = os.path.dirname(os.path.abspath(__file__))
978
+ config_file = args.config_file or os.path.join(base_path, "transformers_config.json")
979
+
980
+ if not os.path.exists(config_file):
981
+ raise FileNotFoundError(f"Config file not found: {config_file}")
982
+
983
+ log_info(f"Loading configuration from {config_file}")
984
+ transformers_config = load_configs(config_file)
985
+
986
+ # Set up hardware environment variables if CUDA is available
987
+ if CUDA_AVAILABLE:
988
+ memory_fraction = get_config_value(transformers_config, "hardware.system_settings.cuda_memory_fraction", 0.75)
989
+ if memory_fraction < 1.0:
990
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
991
+ log_info("Set PYTORCH_CUDA_ALLOC_CONF to max_split_size_mb:128,expandable_segments:True")
992
+
993
+ # Check dependencies before proceeding
994
+ if not check_dependencies():
995
+ raise RuntimeError("Critical dependencies missing")
996
+
997
+ return transformers_config, seed
998
+
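The hardware-related lookups above assume a config section shaped roughly like the sketch below. The key names come from the get_config_value calls in this file; the variable name and the values are illustrative only:

    hardware_section_sketch = {
        "hardware": {
            "system_settings": {
                "cuda_memory_fraction": 0.75,   # values < 1.0 trigger the PYTORCH_CUDA_ALLOC_CONF tweak
                "dataloader_num_workers": 2,
                "dataloader_pin_memory": True,
            }
        }
    }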
999
+ def setup_model_and_tokenizer(config):
1000
+ """
1001
+ Load and configure the model and tokenizer.
1002
+
1003
+ Args:
1004
+ config: Complete configuration dictionary
1005
+
1006
+ Returns:
1007
+ tuple: (model, tokenizer) - The loaded model and tokenizer
1008
+ """
1009
+ log_info("Loading model and tokenizer...")
1010
+ model, tokenizer = load_model_and_tokenizer(config)
1011
+
1012
+ if model is None or tokenizer is None:
1013
+ raise ValueError("Failed to load model or tokenizer")
1014
+
1015
+ log_info(f"Model loaded successfully: {model.__class__.__name__}")
1016
+ log_info(f"Tokenizer loaded: {tokenizer.__class__.__name__} (vocab size: {tokenizer.vocab_size})")
1017
+
1018
+ return model, tokenizer
1019
+
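A quick smoke test after loading can surface device or tokenizer mismatches early. This is a sketch using the standard transformers generate/decode API; the prompt and generation length are arbitrary and not part of the script:

    inputs = tokenizer("Hello, world", return_tensors="pt").to(model.device)
    output = model.generate(**inputs, max_new_tokens=8)
    print(tokenizer.decode(output[0], skip_special_tokens=True))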
1020
+ def setup_dataset_and_collator(config, tokenizer):
1021
+ """
1022
+ Load and configure the dataset and data collator.
1023
+
1024
+ Args:
1025
+ config: Complete configuration dictionary
1026
+ tokenizer: The tokenizer for the data collator
1027
+
1028
+ Returns:
1029
+ tuple: (dataset, data_collator) - The loaded dataset and configured data collator
1030
+ """
1031
+ dataset_config = get_config_value(config, "dataset", {})
1032
+
1033
+ log_info("Loading dataset...")
1034
+ dataset = load_dataset_with_mapping(dataset_config)
1035
+
1036
+ # Validate dataset
1037
+ if dataset is None:
1038
+ raise ValueError("Dataset is None! Cannot proceed with training.")
1039
+
1040
+ if not hasattr(dataset, '__len__') or len(dataset) == 0:
1041
+ raise ValueError("Dataset is empty! Cannot proceed with training.")
1042
+
1043
+ log_info(f"Dataset loaded with {len(dataset)} examples")
1044
+
1045
+ # Create data collator
1046
+ data_collator = SimpleDataCollator(tokenizer, dataset_config)
1047
+
1048
+ return dataset, data_collator
1049
+
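A simple way to sanity-check the pair returned here is to collate a couple of examples by hand. This sketch assumes SimpleDataCollator follows the usual collator convention of taking a list of examples and returning a dict of tensors:

    sample_batch = data_collator([dataset[0], dataset[1]])
    for key, value in sample_batch.items():
        print(key, getattr(value, "shape", type(value)))  # expect input_ids / attention_mask / labels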
1050
+ def create_training_arguments(config, dataset):
1051
+ """
1052
+ Create and configure training arguments for the Trainer.
1053
+
1054
+ Args:
1055
+ config: Complete configuration dictionary
1056
+ dataset: The training dataset (currently unused here; reserved for step estimation)
1057
+
1058
+ Returns:
1059
+ TrainingArguments: Configured training arguments
1060
+ """
1061
+ # Extract configuration sections
1062
+ training_config = get_config_value(config, "training", {})
1063
+ hardware_config = get_config_value(config, "hardware", {})
1064
+ huggingface_config = get_config_value(config, "huggingface_hub", {})
1065
+ distributed_config = get_config_value(config, "distributed_training", {})
1066
+
1067
+ # Extract key training parameters
1068
+ per_device_batch_size = get_config_value(training_config, "per_device_train_batch_size", 4)
1069
+ gradient_accumulation_steps = get_config_value(training_config, "gradient_accumulation_steps", 8)
1070
+ learning_rate = get_config_value(training_config, "learning_rate", 2e-5)
1071
+ num_train_epochs = get_config_value(training_config, "num_train_epochs", 3)
1072
+
1073
+ # Extract hardware settings
1074
+ dataloader_workers = get_config_value(hardware_config, "system_settings.dataloader_num_workers",
1075
+ get_config_value(distributed_config, "dataloader_num_workers", 2))
1076
+ pin_memory = get_config_value(hardware_config, "system_settings.dataloader_pin_memory", True)
1077
+
1078
+ # BF16/FP16 settings - ensure only one is enabled
1079
+ use_bf16 = get_config_value(training_config, "bf16", False)
1080
+ use_fp16 = get_config_value(training_config, "fp16", False) if not use_bf16 else False
1081
+
1082
+ # Configure distributed training
1083
+ fsdp_config = get_config_value(distributed_config, "fsdp_config", {})
1084
+ fsdp_enabled = get_config_value(fsdp_config, "enabled", False)
1085
+
1086
+ ddp_config = get_config_value(distributed_config, "ddp_config", {})
1087
+ ddp_find_unused_parameters = get_config_value(ddp_config, "find_unused_parameters", False)
1088
+
1089
+ # Set up FSDP args if enabled
1090
+ fsdp_args = None
1091
+ if fsdp_enabled and NUM_GPUS > 1:
1092
+ from accelerate import FullyShardedDataParallelPlugin
1093
+ from torch.distributed.fsdp.fully_sharded_data_parallel import (
1094
+ FullOptimStateDictConfig, FullStateDictConfig
1095
+ )
1096
+
1097
+ fsdp_plugin = FullyShardedDataParallelPlugin(
1098
+ sharding_strategy=get_config_value(fsdp_config, "sharding_strategy", "FULL_SHARD"),
1099
+ mixed_precision_policy=get_config_value(fsdp_config, "mixed_precision", "BF16"),
1100
+ state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
1101
+ optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
1102
+ )
1103
+
1104
+ fsdp_args = {
1105
+ "fsdp": fsdp_plugin,
1106
+ "fsdp_transformer_layer_cls_to_wrap": ["LlamaDecoderLayer", "PhiDecoderLayer"]
1107
+ }
1108
+
1109
+ # Create and return training arguments
1110
+ training_args = TrainingArguments(
1111
+ output_dir=get_config_value(config, "checkpointing.output_dir", "./results"),
1112
+ overwrite_output_dir=True,
1113
+ num_train_epochs=num_train_epochs,
1114
+ per_device_train_batch_size=per_device_batch_size,
1115
+ gradient_accumulation_steps=gradient_accumulation_steps,
1116
+ learning_rate=learning_rate,
1117
+ weight_decay=get_config_value(training_config, "weight_decay", 0.01),
1118
+ max_grad_norm=get_config_value(training_config, "max_grad_norm", 1.0),
1119
+ warmup_ratio=get_config_value(training_config, "warmup_ratio", 0.03),
1120
+ lr_scheduler_type=get_config_value(training_config, "lr_scheduler_type", "cosine"),
1121
+ logging_steps=get_config_value(training_config, "logging_steps", 10),
1122
+ save_strategy=get_config_value(config, "checkpointing.save_strategy", "steps"),
1123
+ save_steps=get_config_value(config, "checkpointing.save_steps", 500),
1124
+ save_total_limit=get_config_value(config, "checkpointing.save_total_limit", 3),
1125
+ bf16=use_bf16,
1126
+ fp16=use_fp16,
1127
+ push_to_hub=get_config_value(huggingface_config, "push_to_hub", False),
1128
+ hub_model_id=get_config_value(huggingface_config, "hub_model_id", None),
1129
+ hub_strategy=get_config_value(huggingface_config, "hub_strategy", "every_save"),
1130
+ hub_private_repo=get_config_value(huggingface_config, "hub_private_repo", True),
1131
+ gradient_checkpointing=get_config_value(training_config, "gradient_checkpointing", True),
1132
+ dataloader_pin_memory=pin_memory,
1133
+ optim=get_config_value(training_config, "optim", "adamw_torch"),
1134
+ ddp_find_unused_parameters=ddp_find_unused_parameters,
1135
+ dataloader_drop_last=False,
1136
+ dataloader_num_workers=dataloader_workers,
1137
+ no_cuda=not CUDA_AVAILABLE,
1138
+ **({} if fsdp_args is None else fsdp_args)
1139
+ )
1140
+
1141
+ log_info("Training arguments created successfully")
1142
+ return training_args
1143
+
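The dot-paths read above imply a configuration layout roughly like the sketch below. Only keys actually referenced in this function are listed; the variable name and values are illustrative, not the project's real settings:

    training_config_sketch = {
        "training": {
            "per_device_train_batch_size": 4,
            "gradient_accumulation_steps": 8,
            "learning_rate": 2e-5,
            "num_train_epochs": 3,
            "bf16": True,                    # fp16 is ignored when bf16 is enabled
            "gradient_checkpointing": True,
            "optim": "adamw_torch",
        },
        "checkpointing": {"output_dir": "./results", "save_strategy": "steps", "save_steps": 500},
        "huggingface_hub": {"push_to_hub": False, "hub_model_id": None},
        "distributed_training": {"fsdp_config": {"enabled": False}, "ddp_config": {"find_unused_parameters": False}},
    }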
1144
+ def configure_custom_dataloader(trainer, dataset, config, training_args):
1145
+ """
1146
+ Configure a custom dataloader for the trainer if needed.
1147
+
1148
+ Args:
1149
+ trainer: The Trainer instance to configure
1150
+ dataset: The dataset to use
1151
+ config: Complete configuration dictionary
1152
+ training_args: The training arguments
1153
+
1154
+ Returns:
1155
+ None (modifies trainer in-place)
1156
+ """
1157
+ dataset_config = get_config_value(config, "dataset", {})
1158
+
1159
+ # Check if we need a custom dataloader
1160
+ if get_config_value(dataset_config, "data_loading.sequential_processing", True):
1161
+ log_info("Using custom sequential dataloader")
1162
+
1163
+ # Create sequential sampler to maintain dataset order
1164
+ sequential_sampler = torch.utils.data.SequentialSampler(dataset)
1165
+ log_info("Sequential sampler created")
1166
+
1167
+ # Define custom dataloader getter
1168
+ def custom_get_train_dataloader():
1169
+ """Create a custom dataloader that maintains dataset order"""
1170
+ # Get configuration values
1171
+ batch_size = training_args.per_device_train_batch_size
1172
+ drop_last = get_config_value(dataset_config, "data_loading.drop_last", False)
1173
+ num_workers = training_args.dataloader_num_workers
1174
+ pin_memory = training_args.dataloader_pin_memory
1175
+ prefetch_factor = get_config_value(dataset_config, "data_loading.prefetch_factor", 2)
1176
+ persistent_workers = get_config_value(dataset_config, "data_loading.persistent_workers", False)
1177
+
1178
+ # Create DataLoader with sequential sampler
1179
+ return DataLoader(
1180
+ dataset,
1181
+ batch_size=batch_size,
1182
+ sampler=sequential_sampler,
1183
+ collate_fn=trainer.data_collator,
1184
+ drop_last=drop_last,
1185
+ num_workers=num_workers,
1186
+ pin_memory=pin_memory,
1187
+ prefetch_factor=prefetch_factor if num_workers > 0 else None,
1188
+ persistent_workers=persistent_workers if num_workers > 0 else False,
1189
+ )
1190
+
1191
+ # Override the default dataloader
1192
+ trainer.get_train_dataloader = custom_get_train_dataloader
1193
+
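The custom dataloader is driven by a dataset.data_loading block along these lines (a sketch; the keys mirror the lookups above, the variable name and values are illustrative):

    dataset_loading_sketch = {
        "dataset": {
            "data_loading": {
                "sequential_processing": True,   # keep examples in their original order
                "drop_last": False,
                "prefetch_factor": 2,
                "persistent_workers": False,
            }
        }
    }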
1194
+ def run_training(trainer, tokenizer, training_args):
1195
+ """
1196
+ Run the training process and handle model saving.
1197
+
1198
+ Args:
1199
+ trainer: Configured Trainer instance
1200
+ tokenizer: The tokenizer to save with the model
1201
+ training_args: Training arguments
1202
+
1203
+ Returns:
1204
+ int: 0 for success, 1 for failure
1205
+ """
1206
+ log_info("Starting training...")
1207
+ trainer.train()
1208
+
1209
+ log_info("Training complete! Saving final model...")
1210
+ trainer.save_model()
1211
+ tokenizer.save_pretrained(training_args.output_dir)
1212
+
1213
+ # Push to Hub if configured
1214
+ if training_args.push_to_hub:
1215
+ log_info(f"Pushing model to Hugging Face Hub: {training_args.hub_model_id}")
1216
+ trainer.push_to_hub()
1217
+
1218
+ log_info("Training completed successfully!")
1219
+ return 0
1220
+
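If a run is interrupted, the standard Trainer API can resume from the most recent checkpoint under output_dir; that would be a one-line variation of the call above (not something this script currently does):

    trainer.train(resume_from_checkpoint=True)  # resumes from the latest checkpoint in output_dir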
1221
  def main():
1222
+ """
1223
+ Main entry point for the training script.
1224
+
1225
+ Returns:
1226
+ int: 0 for success, non-zero for failure
1227
+ """
1228
  # Set up logging
1229
  logger.info("Starting training process")
1230
 
1231
  try:
1232
+ # Parse command line arguments
1233
  args = parse_args()
1234
 
1235
+ # Set up environment and load configuration
1236
+ transformers_config, seed = setup_environment(args)
1237
 
1238
+ # Load model and tokenizer
1239
  try:
1240
+ model, tokenizer = setup_model_and_tokenizer(transformers_config)
1241
  except Exception as e:
1242
+ logger.error(f"Error setting up model: {str(e)}")
1243
  return 1
1244
 
1245
+ # Load dataset and create data collator
1246
+ try:
1247
+ dataset, data_collator = setup_dataset_and_collator(transformers_config, tokenizer)
1248
+ except Exception as e:
1249
+ logger.error(f"Error setting up dataset: {str(e)}")
1250
+ return 1
1251
 
1252
+ # Configure training arguments
1253
+ try:
1254
+ training_args = create_training_arguments(transformers_config, dataset)
1255
+ except Exception as e:
1256
+ logger.error(f"Error configuring training arguments: {str(e)}")
1257
+ return 1
1258
 
1259
+ # Initialize trainer with callbacks
1260
+ log_info("Initializing Trainer")
1261
+ trainer = Trainer(
1262
+ model=model,
1263
+ args=training_args,
1264
+ train_dataset=dataset,
1265
+ data_collator=data_collator,
1266
+ callbacks=[LoggingCallback(model=model, dataset=dataset)],
1267
+ )
1268
 
1269
+ # Configure custom dataloader if needed
1270
  try:
1271
+ configure_custom_dataloader(trainer, dataset, transformers_config, training_args)
1272
+ except Exception as e:
1273
+ logger.error(f"Error configuring custom dataloader: {str(e)}")
1274
+ return 1
1275
 
1276
+ # Run training process
1277
+ try:
1278
+ return run_training(trainer, tokenizer, training_args)
1279
  except Exception as e:
1280
+ logger.error(f"Training failed with error: {str(e)}")
1281
+ # Log GPU memory for debugging
1282
+ log_gpu_memory_usage(label="Error")
1283
+ # Print full stack trace
1284
+ traceback.print_exc()
1285
  return 1
1286
 
1287
  except Exception as e:
1288
  logger.error(f"Error in main function: {str(e)}")
1289
+ traceback.print_exc()
1290
  return 1
1291
 
1292
  if __name__ == "__main__":