George-API committed
Commit 0364d5c · verified · 1 Parent(s): 9f8478c

Upload folder using huggingface_hub

Files changed (1):
  1. run_transformers_training.py +142 -99
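The commit message indicates the folder was pushed with the huggingface_hub client. For reference, a minimal sketch of that kind of upload is shown below; the local path and repository name are placeholders, not values taken from this commit.

from huggingface_hub import HfApi

api = HfApi()  # uses the token from `huggingface-cli login` or the HF_TOKEN env var
api.upload_folder(
    folder_path="./local_training_code",   # placeholder: local folder to upload
    repo_id="George-API/example-space",    # placeholder repository id
    repo_type="space",                     # or "model" / "dataset"
    commit_message="Upload folder using huggingface_hub",
)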
run_transformers_training.py CHANGED
@@ -151,13 +151,55 @@ def load_model_and_tokenizer(config):
         use_flash_attention = False
         logger.warning("Flash attention not available, falling back to standard attention")
 
-    model, tokenizer = FastLanguageModel.from_pretrained(
-        model_name=model_name,
-        max_seq_length=config.get("max_seq_length", 2048) or config.get("tokenizer", {}).get("max_seq_length", 2048),
-        dtype=None, # Let Unsloth choose optimal dtype
-        device_map="auto",
-        # Don't explicitly use flash attention config here, let Unsloth handle it
-    )
+    # First detect if we have a GPU
+    if torch.cuda.is_available():
+        logger.info(f"CUDA available, found {torch.cuda.device_count()} GPU(s)")
+        device_map = "auto"
+    else:
+        logger.warning("No CUDA available, falling back to CPU")
+        device_map = {"": "cpu"} # Force CPU placement
+
+    # Set default dtype for better numerics
+    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
+        # Use bfloat16 for Ampere or newer
+        dtype = torch.bfloat16
+        logger.info("Using bfloat16 precision (Ampere+ GPU)")
+    elif torch.cuda.is_available():
+        # Use float16 for older GPUs
+        dtype = torch.float16
+        logger.info("Using float16 precision (pre-Ampere GPU)")
+    else:
+        # CPU, use default dtype
+        dtype = None
+        logger.info("Using default precision (CPU)")
+
+    # Load model with proper error handling for out-of-memory
+    try:
+        model, tokenizer = FastLanguageModel.from_pretrained(
+            model_name=model_name,
+            max_seq_length=config.get("max_seq_length", 2048) or config.get("tokenizer", {}).get("max_seq_length", 2048),
+            dtype=dtype,
+            device_map=device_map,
+            # Don't explicitly use flash attention config here, let Unsloth handle it
+        )
+    except RuntimeError as e:
+        if "CUDA out of memory" in str(e):
+            logger.error("Out of GPU memory. Consider using a smaller batch size or gradient accumulation steps.")
+            raise
+        else:
+            # Try again with CPU placement to see if it's a memory issue
+            logger.warning(f"Error loading model on default device: {str(e)}")
+            logger.warning("Attempting to load with device_map='cpu' and no specific dtype")
+            model, tokenizer = FastLanguageModel.from_pretrained(
+                model_name=model_name,
+                max_seq_length=config.get("max_seq_length", 2048) or config.get("tokenizer", {}).get("max_seq_length", 2048),
+                dtype=None,
+                device_map={"": "cpu"},
+            )
+            logger.warning("Model loaded on CPU. Training will be very slow.")
+
+    # Ensure model and optimizer init is on the same device
+    logger.info(f"Model device map: {model.hf_device_map if hasattr(model, 'hf_device_map') else 'Not available'}")
 
     # Apply Unsloth's training optimizations with config parameters
     unsloth_config = config.get("unsloth", {})
@@ -332,14 +374,16 @@ class SimpleDataCollator:
         self.dataset_config = dataset_config
         self.stats = {"processed": 0, "skipped": 0, "total_tokens": 0}
         self.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
-        self.prompt_counter = 0
         self.paper_counters = {}
         self.max_seq_length = dataset_config.get("dataset", {}).get("processing", {}).get("max_seq_length", 2048)
-        self.include_metadata = dataset_config.get("data_formatting", {}).get("metadata_handling", {}).get("include_paper_id", True)
-        self.include_chunk = dataset_config.get("data_formatting", {}).get("metadata_handling", {}).get("include_chunk_number", True)
-        self.metadata_format = dataset_config.get("data_formatting", {}).get("metadata_handling", {}).get("metadata_format", "Paper ID: {paper_id} | Chunk: {chunk_number}")
+        self.include_metadata = False # Disable automatic metadata inclusion as it's already in content
         self.roles = dataset_config.get("data_formatting", {}).get("roles", {})
         logger.info(f"SimpleDataCollator initialized - using phi-4 chat format with max_seq_length={self.max_seq_length}")
+        logger.info("Metadata handling disabled - using metadata from content field")
+
+        # Check if we're on GPU
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        logger.info(f"SimpleDataCollator using device: {self.device}")
 
     def normalize_conversation(self, conversation):
         """Normalize conversation format to ensure consistent structure."""
@@ -353,6 +397,23 @@ class SimpleDataCollator:
         else:
             return []
 
+        # Get introductory message if present (should be first and without chunk number)
+        intro_msg = None
+        for i, turn in enumerate(conversation):
+            if isinstance(turn, dict) and turn.get('content') and "[RESEARCH INTRODUCTION]" in turn.get('content', ''):
+                intro_msg = turn
+                break
+
+        # Process introduction message first if found
+        if intro_msg:
+            normalized.append({
+                "role": "system",
+                "content": intro_msg.get('content', '')
+            })
+            # Remove intro from further processing
+            conversation = [t for t in conversation if t != intro_msg]
+
+        # Process remaining messages
         for turn in conversation:
             # Skip empty or None entries
             if not turn:
@@ -406,23 +467,6 @@ class SimpleDataCollator:
                 self.stats["skipped"] += 1
                 continue
 
-            # Track paper chunks
-            if paper_id not in self.paper_counters:
-                self.paper_counters[paper_id] = 0
-            self.paper_counters[paper_id] += 1
-
-            # Add metadata if configured
-            if self.include_metadata:
-                # Format metadata according to configured format
-                metadata_content = self.metadata_format.format(
-                    paper_id=paper_id,
-                    chunk_number=self.paper_counters[paper_id]
-                )
-
-                # Add as system message if not already in conversation
-                if not any(msg.get("role") == "system" for msg in conversation):
-                    conversation = [{"role": "system", "content": metadata_content}] + conversation
-
             # Format conversation with research introduction and chunk info
             formatted_content = format_phi_chat(conversation, self.dataset_config)
 
@@ -433,6 +477,7 @@ class SimpleDataCollator:
                 truncation=True,
                 max_length=self.max_seq_length,
                 return_tensors=None,
+                padding=False, # Don't pad here, we'll pad the batch later
             )
 
             if len(inputs["input_ids"]) > 0:
@@ -450,7 +495,7 @@ class SimpleDataCollator:
                 log_samples = self.dataset_config.get("validation", {}).get("log_samples", 3)
                 if self.stats["processed"] <= log_samples:
                     logger.info(f"Example {self.stats['processed']} format:")
-                    logger.info(f"Paper ID: {paper_id} | Chunk: {self.paper_counters[paper_id]}")
+                    logger.info(f"Paper ID: {paper_id}")
                     logger.info(f"Token count: {len(inputs['input_ids'])}")
                     logger.info(f"Content preview:\n{formatted_content[:500]}...")
                     logger.info(f"Conversation structure: {conversation[:2]}...")
@@ -464,6 +509,7 @@ class SimpleDataCollator:
 
         if not batch["input_ids"]:
             logger.warning("Empty batch, returning dummy tensors")
+            # Return tensors on the right device
             return {
                 "input_ids": torch.zeros((1, 1), dtype=torch.long),
                 "attention_mask": torch.zeros((1, 1), dtype=torch.long),
@@ -480,8 +526,8 @@ class SimpleDataCollator:
             batch["attention_mask"][i].extend([0] * padding_length)
             batch["labels"][i].extend([-100] * padding_length)
 
-        # Convert to tensors
-        batch = {k: torch.tensor(v) for k, v in batch.items()}
+        # Convert to tensors on CPU first
+        batch = {k: torch.tensor(v, dtype=torch.long) for k, v in batch.items()}
 
         # Log stats periodically
         log_interval = self.dataset_config.get("validation", {}).get("log_interval", 100)
@@ -601,6 +647,18 @@ def main():
     set_seed(seed)
     logger.info(f"Set random seed to {seed}")
 
+    # Check CUDA and set environment variables for better memory management
+    if torch.cuda.is_available():
+        # Empty CUDA cache
+        torch.cuda.empty_cache()
+        # Set memory management env vars (optional)
+        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
+        # Log memory information
+        for i in range(torch.cuda.device_count()):
+            logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}")
+            logger.info(f"Memory Allocated: {torch.cuda.memory_allocated(i) / 1024**2:.2f} MB")
+            logger.info(f"Memory Reserved: {torch.cuda.memory_reserved(i) / 1024**2:.2f} MB")
+
     try:
         model, tokenizer = load_model_and_tokenizer(model_config)
         logger.info("Model and tokenizer loaded successfully")
@@ -612,7 +670,7 @@ def main():
         except Exception as e:
            logger.error(f"Error loading dataset: {e}")
            return 1
-
+
        # Create data collator
        data_collator = SimpleDataCollator(tokenizer, dataset_config)
 
@@ -627,6 +685,13 @@ def main():
                if (state.global_step % 50 == 0) or (current_time - self.last_log_time > 300):
                    logger.info(f"Step {state.global_step}: Loss {state.log_history[-1]['loss'] if state.log_history else 'N/A'}")
                    self.last_log_time = current_time
+
+            def on_train_begin(self, args, state, control, **kwargs):
+                logger.info("Training is starting...")
+                # Log memory information
+                if torch.cuda.is_available():
+                    for i in range(torch.cuda.device_count()):
+                        logger.info(f"GPU {i} Memory Allocated: {torch.cuda.memory_allocated(i) / 1024**2:.2f} MB")
 
        # Verify precision settings - ensure only one of bf16/fp16 is set, with bf16 taking precedence
        use_bf16 = model_config.get("bf16", False) or model_config.get("torch_dtype", "") == "bfloat16"
@@ -658,95 +723,73 @@ def main():
            report_to="tensorboard",
            remove_unused_columns=False, # Keep all columns
            gradient_checkpointing=model_config.get("training", {}).get("gradient_checkpointing", True),
-            dataloader_pin_memory=False, # Reduce memory usage
+            dataloader_pin_memory=True, # Keep data in pinned memory for faster transfer
            optim=model_config.get("training", {}).get("optim", "adamw_torch"),
            ddp_find_unused_parameters=False, # Improve distributed training efficiency
            dataloader_drop_last=False, # Process all examples
            dataloader_num_workers=4, # Sequential data loading
+            no_cuda=False if torch.cuda.is_available() else True, # Use CUDA if available
        )
 
-        # Create a sequential sampler to ensure dataset is processed in order
-        logger.info("Creating sequential sampler to maintain dataset order")
-
-        # Create trainer with callback
-        logger.info("Creating trainer")
-
-        # Check if we should resume from checkpoint
-        resume_from_checkpoint = False
-        output_dir = model_config.get("output_dir", "./results")
-        if os.path.exists(output_dir):
-            checkpoints = [folder for folder in os.listdir(output_dir) if folder.startswith("checkpoint-")]
-            if checkpoints:
-                latest_checkpoint = max(checkpoints, key=lambda x: int(x.split("-")[1]))
-                resume_from_checkpoint = os.path.join(output_dir, latest_checkpoint)
-                logger.info(f"Found checkpoint: {resume_from_checkpoint}. Training will resume from this point.")
+        # Custom dataloader to ensure no shuffling of dataset
+        # This preserves the order of chunks in papers
+        def get_train_dataloader_no_shuffle():
+            logger.info("Creating data loader with sequential sampler to maintain paper order")
+            if getattr(training_args, "no_cuda", False):
+                batch_size = training_args.per_device_train_batch_size
+            else:
+                batch_size = max(training_args.per_device_train_batch_size * torch.cuda.device_count(), 1)
+
+            # Use sequential sampler to preserve order
+            sequential_sampler = torch.utils.data.SequentialSampler(dataset["train"])
+            logger.info(f"Using sequential sampler for batch size {batch_size}")
+
+            return torch.utils.data.DataLoader(
+                dataset["train"],
+                batch_size=batch_size,
+                sampler=sequential_sampler,
+                collate_fn=data_collator,
+                drop_last=training_args.dataloader_drop_last,
+                num_workers=training_args.dataloader_num_workers,
+                pin_memory=training_args.dataloader_pin_memory,
+            )
 
+        # Set up trainer with custom dataloader
+        logger.info("Initializing Trainer")
        trainer = Trainer(
            model=model,
            args=training_args,
-            train_dataset=dataset,
+            get_train_dataloader=get_train_dataloader_no_shuffle,
+            tokenizer=tokenizer,
            data_collator=data_collator,
            callbacks=[LoggingCallback()]
        )
 
-        # Override the default data loader to disable shuffling
-        # This is necessary because TrainingArguments doesn't have a direct shuffle parameter
-        def get_train_dataloader_no_shuffle():
-            """Create a train DataLoader with shuffling disabled."""
-            logger.info("Creating train dataloader with sequential sampler (no shuffling)")
-
-            # Create a sequential sampler to ensure dataset is processed in order
-            train_sampler = torch.utils.data.SequentialSampler(dataset)
-
-            return torch.utils.data.DataLoader(
-                dataset,
-                batch_size=training_args.per_device_train_batch_size,
-                sampler=train_sampler, # Use sequential sampler instead of shuffle parameter
-                collate_fn=data_collator,
-                drop_last=False,
-                num_workers=0,
-                pin_memory=False
-            )
-
-        # Replace the default data loader with our non-shuffling version
-        trainer.get_train_dataloader = get_train_dataloader_no_shuffle
-
        # Start training
-        logger.info("Starting training")
-        logger.info(f"Processing with batch size = {training_args.per_device_train_batch_size}, each entry processed independently")
-
-        # Create a lock file to indicate training is in progress
-        lock_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "TRAINING_IN_PROGRESS.lock")
-        with open(lock_file, "w") as f:
-            f.write(f"Training started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
-            f.write(f"Expected completion: After {training_args.num_train_epochs} epochs\n")
-            f.write("DO NOT UPDATE OR RESTART THIS SPACE UNTIL TRAINING COMPLETES\n")
-        logger.info(f"Created lock file: {lock_file}")
-
+        logger.info("Starting training process")
        try:
-            trainer.train(resume_from_checkpoint=resume_from_checkpoint)
+            trainer.train()
            logger.info("Training completed successfully")
 
-            # Save model
-            if model_config.get("push_to_hub", False):
-                logger.info(f"Pushing model to hub: {model_config.get('hub_model_id')}")
+            # Save the final model
+            logger.info("Saving final model")
+            trainer.save_model()
+
+            # Push to hub if enabled
+            if model_config.get("huggingface_hub", {}).get("push_to_hub", False):
+                logger.info("Pushing model to Hugging Face Hub")
                trainer.push_to_hub()
-                logger.info("Model pushed to hub successfully")
-            else:
-                logger.info(f"Saving model to {model_config.get('output_dir', './results')}")
-                trainer.save_model()
-                logger.info("Model saved successfully")
+
+            return 0
        except Exception as e:
            logger.error(f"Training failed with error: {str(e)}")
+            # Log CUDA memory info if available
+            if torch.cuda.is_available():
+                for i in range(torch.cuda.device_count()):
+                    logger.info(f"GPU {i} Memory Allocated: {torch.cuda.memory_allocated(i) / 1024**2:.2f} MB")
+                    logger.info(f"GPU {i} Memory Reserved: {torch.cuda.memory_reserved(i) / 1024**2:.2f} MB")
            raise
-        finally:
-            # Remove the lock file when training completes or fails
-            if os.path.exists(lock_file):
-                os.remove(lock_file)
-                logger.info(f"Removed lock file: {lock_file}")
-
-            return 0
-
+
    except Exception as e:
        logger.error(f"Error in main training loop: {str(e)}")
        return 1
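A note on the custom dataloader in the last hunk: the stock transformers.Trainer constructor does not accept a get_train_dataloader keyword, and the new Trainer call no longer passes a train_dataset, so the more conventional way to obtain unshuffled, in-order batches is to subclass Trainer and override its get_train_dataloader method. The sketch below is illustrative only; the SequentialTrainer name and wiring are assumptions, not part of this commit.

from torch.utils.data import DataLoader, SequentialSampler
from transformers import Trainer

class SequentialTrainer(Trainer):
    """Trainer variant that iterates the training set in its original order."""

    def get_train_dataloader(self) -> DataLoader:
        # Reuse the collator and dataloader settings from TrainingArguments,
        # but swap in a SequentialSampler so paper chunks keep their order.
        return DataLoader(
            self.train_dataset,
            batch_size=self.args.per_device_train_batch_size,
            sampler=SequentialSampler(self.train_dataset),
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )

Constructed as SequentialTrainer(model=model, args=training_args, train_dataset=dataset["train"], data_collator=data_collator, callbacks=[LoggingCallback()]), it would behave like the Trainer in the diff while relying only on documented constructor arguments.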