Upload folder using huggingface_hub
run_transformers_training.py  CHANGED  +142 -99
@@ -151,13 +151,55 @@ def load_model_and_tokenizer(config):
         use_flash_attention = False
         logger.warning("Flash attention not available, falling back to standard attention")
 
-
-
-
-
-
-
-
+    # First detect if we have a GPU
+    if torch.cuda.is_available():
+        logger.info(f"CUDA available, found {torch.cuda.device_count()} GPU(s)")
+        device_map = "auto"
+    else:
+        logger.warning("No CUDA available, falling back to CPU")
+        device_map = {"": "cpu"}  # Force CPU placement
+
+    # Set default dtype for better numerics
+    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
+        # Use bfloat16 for Ampere or newer
+        dtype = torch.bfloat16
+        logger.info("Using bfloat16 precision (Ampere+ GPU)")
+    elif torch.cuda.is_available():
+        # Use float16 for older GPUs
+        dtype = torch.float16
+        logger.info("Using float16 precision (pre-Ampere GPU)")
+    else:
+        # CPU, use default dtype
+        dtype = None
+        logger.info("Using default precision (CPU)")
+
+    # Load model with proper error handling for out-of-memory
+    try:
+        model, tokenizer = FastLanguageModel.from_pretrained(
+            model_name=model_name,
+            max_seq_length=config.get("max_seq_length", 2048) or config.get("tokenizer", {}).get("max_seq_length", 2048),
+            dtype=dtype,
+            device_map=device_map,
+            # Don't explicitly use flash attention config here, let Unsloth handle it
+        )
+    except RuntimeError as e:
+        if "CUDA out of memory" in str(e):
+            logger.error("Out of GPU memory. Consider using a smaller batch size or gradient accumulation steps.")
+            raise
+        else:
+            # Try again with CPU placement to see if it's a memory issue
+            logger.warning(f"Error loading model on default device: {str(e)}")
+            logger.warning("Attempting to load with device_map='cpu' and no specific dtype")
+            model, tokenizer = FastLanguageModel.from_pretrained(
+                model_name=model_name,
+                max_seq_length=config.get("max_seq_length", 2048) or config.get("tokenizer", {}).get("max_seq_length", 2048),
+                dtype=None,
+                device_map={"": "cpu"},
+            )
+            logger.warning("Model loaded on CPU. Training will be very slow.")
+
+    # Ensure model and optimizer init is on the same device
+    logger.info(f"Model device map: {model.hf_device_map if hasattr(model, 'hf_device_map') else 'Not available'}")
 
     # Apply Unsloth's training optimizations with config parameters
     unsloth_config = config.get("unsloth", {})
@@ -332,14 +374,16 @@ class SimpleDataCollator:
         self.dataset_config = dataset_config
         self.stats = {"processed": 0, "skipped": 0, "total_tokens": 0}
         self.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
-        self.prompt_counter = 0
         self.paper_counters = {}
         self.max_seq_length = dataset_config.get("dataset", {}).get("processing", {}).get("max_seq_length", 2048)
-        self.include_metadata =
-        self.include_chunk = dataset_config.get("data_formatting", {}).get("metadata_handling", {}).get("include_chunk_number", True)
-        self.metadata_format = dataset_config.get("data_formatting", {}).get("metadata_handling", {}).get("metadata_format", "Paper ID: {paper_id} | Chunk: {chunk_number}")
+        self.include_metadata = False  # Disable automatic metadata inclusion as it's already in content
         self.roles = dataset_config.get("data_formatting", {}).get("roles", {})
         logger.info(f"SimpleDataCollator initialized - using phi-4 chat format with max_seq_length={self.max_seq_length}")
+        logger.info("Metadata handling disabled - using metadata from content field")
+
+        # Check if we're on GPU
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        logger.info(f"SimpleDataCollator using device: {self.device}")
 
     def normalize_conversation(self, conversation):
         """Normalize conversation format to ensure consistent structure."""
@@ -353,6 +397,23 @@ class SimpleDataCollator:
        else:
            return []
 
+        # Get introductory message if present (should be first and without chunk number)
+        intro_msg = None
+        for i, turn in enumerate(conversation):
+            if isinstance(turn, dict) and turn.get('content') and "[RESEARCH INTRODUCTION]" in turn.get('content', ''):
+                intro_msg = turn
+                break
+
+        # Process introduction message first if found
+        if intro_msg:
+            normalized.append({
+                "role": "system",
+                "content": intro_msg.get('content', '')
+            })
+            # Remove intro from further processing
+            conversation = [t for t in conversation if t != intro_msg]
+
+        # Process remaining messages
        for turn in conversation:
            # Skip empty or None entries
            if not turn:
@@ -406,23 +467,6 @@ class SimpleDataCollator:
                self.stats["skipped"] += 1
                continue
 
-            # Track paper chunks
-            if paper_id not in self.paper_counters:
-                self.paper_counters[paper_id] = 0
-            self.paper_counters[paper_id] += 1
-
-            # Add metadata if configured
-            if self.include_metadata:
-                # Format metadata according to configured format
-                metadata_content = self.metadata_format.format(
-                    paper_id=paper_id,
-                    chunk_number=self.paper_counters[paper_id]
-                )
-
-                # Add as system message if not already in conversation
-                if not any(msg.get("role") == "system" for msg in conversation):
-                    conversation = [{"role": "system", "content": metadata_content}] + conversation
-
            # Format conversation with research introduction and chunk info
            formatted_content = format_phi_chat(conversation, self.dataset_config)
 
@@ -433,6 +477,7 @@
                truncation=True,
                max_length=self.max_seq_length,
                return_tensors=None,
+                padding=False,  # Don't pad here, we'll pad the batch later
            )
 
            if len(inputs["input_ids"]) > 0:
@@ -450,7 +495,7 @@
                log_samples = self.dataset_config.get("validation", {}).get("log_samples", 3)
                if self.stats["processed"] <= log_samples:
                    logger.info(f"Example {self.stats['processed']} format:")
-                    logger.info(f"Paper ID: {paper_id}
+                    logger.info(f"Paper ID: {paper_id}")
                    logger.info(f"Token count: {len(inputs['input_ids'])}")
                    logger.info(f"Content preview:\n{formatted_content[:500]}...")
                    logger.info(f"Conversation structure: {conversation[:2]}...")
@@ -464,6 +509,7 @@
 
        if not batch["input_ids"]:
            logger.warning("Empty batch, returning dummy tensors")
+            # Return tensors on the right device
            return {
                "input_ids": torch.zeros((1, 1), dtype=torch.long),
                "attention_mask": torch.zeros((1, 1), dtype=torch.long),
@@ -480,8 +526,8 @@
                batch["attention_mask"][i].extend([0] * padding_length)
                batch["labels"][i].extend([-100] * padding_length)
 
-        # Convert to tensors
-        batch = {k: torch.tensor(v) for k, v in batch.items()}
+        # Convert to tensors on CPU first
+        batch = {k: torch.tensor(v, dtype=torch.long) for k, v in batch.items()}
 
        # Log stats periodically
        log_interval = self.dataset_config.get("validation", {}).get("log_interval", 100)
@@ -601,6 +647,18 @@ def main():
        set_seed(seed)
        logger.info(f"Set random seed to {seed}")
 
+        # Check CUDA and set environment variables for better memory management
+        if torch.cuda.is_available():
+            # Empty CUDA cache
+            torch.cuda.empty_cache()
+            # Set memory management env vars (optional)
+            os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
+            # Log memory information
+            for i in range(torch.cuda.device_count()):
+                logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}")
+                logger.info(f"Memory Allocated: {torch.cuda.memory_allocated(i) / 1024**2:.2f} MB")
+                logger.info(f"Memory Reserved: {torch.cuda.memory_reserved(i) / 1024**2:.2f} MB")
+
        try:
            model, tokenizer = load_model_and_tokenizer(model_config)
            logger.info("Model and tokenizer loaded successfully")
@@ -612,7 +670,7 @@
        except Exception as e:
            logger.error(f"Error loading dataset: {e}")
            return 1
-
+
        # Create data collator
        data_collator = SimpleDataCollator(tokenizer, dataset_config)
 
@@ -627,6 +685,13 @@
                if (state.global_step % 50 == 0) or (current_time - self.last_log_time > 300):
                    logger.info(f"Step {state.global_step}: Loss {state.log_history[-1]['loss'] if state.log_history else 'N/A'}")
                    self.last_log_time = current_time
+
+            def on_train_begin(self, args, state, control, **kwargs):
+                logger.info("Training is starting...")
+                # Log memory information
+                if torch.cuda.is_available():
+                    for i in range(torch.cuda.device_count()):
+                        logger.info(f"GPU {i} Memory Allocated: {torch.cuda.memory_allocated(i) / 1024**2:.2f} MB")
 
        # Verify precision settings - ensure only one of bf16/fp16 is set, with bf16 taking precedence
        use_bf16 = model_config.get("bf16", False) or model_config.get("torch_dtype", "") == "bfloat16"
@@ -658,95 +723,73 @@
            report_to="tensorboard",
            remove_unused_columns=False,  # Keep all columns
            gradient_checkpointing=model_config.get("training", {}).get("gradient_checkpointing", True),
-            dataloader_pin_memory=
+            dataloader_pin_memory=True,  # Keep data in pinned memory for faster transfer
            optim=model_config.get("training", {}).get("optim", "adamw_torch"),
            ddp_find_unused_parameters=False,  # Improve distributed training efficiency
            dataloader_drop_last=False,  # Process all examples
            dataloader_num_workers=4,  # Sequential data loading
+            no_cuda=False if torch.cuda.is_available() else True,  # Use CUDA if available
        )
 
-        #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # Custom dataloader to ensure no shuffling of dataset
+        # This preserves the order of chunks in papers
+        def get_train_dataloader_no_shuffle():
+            logger.info("Creating data loader with sequential sampler to maintain paper order")
+            if getattr(training_args, "no_cuda", False):
+                batch_size = training_args.per_device_train_batch_size
+            else:
+                batch_size = max(training_args.per_device_train_batch_size * torch.cuda.device_count(), 1)
+
+            # Use sequential sampler to preserve order
+            sequential_sampler = torch.utils.data.SequentialSampler(dataset["train"])
+            logger.info(f"Using sequential sampler for batch size {batch_size}")
+
+            return torch.utils.data.DataLoader(
+                dataset["train"],
+                batch_size=batch_size,
+                sampler=sequential_sampler,
+                collate_fn=data_collator,
+                drop_last=training_args.dataloader_drop_last,
+                num_workers=training_args.dataloader_num_workers,
+                pin_memory=training_args.dataloader_pin_memory,
+            )
 
+        # Set up trainer with custom dataloader
+        logger.info("Initializing Trainer")
        trainer = Trainer(
            model=model,
            args=training_args,
-
+            get_train_dataloader=get_train_dataloader_no_shuffle,
+            tokenizer=tokenizer,
            data_collator=data_collator,
            callbacks=[LoggingCallback()]
        )
 
-        # Override the default data loader to disable shuffling
-        # This is necessary because TrainingArguments doesn't have a direct shuffle parameter
-        def get_train_dataloader_no_shuffle():
-            """Create a train DataLoader with shuffling disabled."""
-            logger.info("Creating train dataloader with sequential sampler (no shuffling)")
-
-            # Create a sequential sampler to ensure dataset is processed in order
-            train_sampler = torch.utils.data.SequentialSampler(dataset)
-
-            return torch.utils.data.DataLoader(
-                dataset,
-                batch_size=training_args.per_device_train_batch_size,
-                sampler=train_sampler,  # Use sequential sampler instead of shuffle parameter
-                collate_fn=data_collator,
-                drop_last=False,
-                num_workers=0,
-                pin_memory=False
-            )
-
-        # Replace the default data loader with our non-shuffling version
-        trainer.get_train_dataloader = get_train_dataloader_no_shuffle
-
        # Start training
-        logger.info("Starting training")
-        logger.info(f"Processing with batch size = {training_args.per_device_train_batch_size}, each entry processed independently")
-
-        # Create a lock file to indicate training is in progress
-        lock_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "TRAINING_IN_PROGRESS.lock")
-        with open(lock_file, "w") as f:
-            f.write(f"Training started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
-            f.write(f"Expected completion: After {training_args.num_train_epochs} epochs\n")
-            f.write("DO NOT UPDATE OR RESTART THIS SPACE UNTIL TRAINING COMPLETES\n")
-        logger.info(f"Created lock file: {lock_file}")
-
+        logger.info("Starting training process")
        try:
-            trainer.train(
+            trainer.train()
            logger.info("Training completed successfully")
 
-            # Save model
-
-
+            # Save the final model
+            logger.info("Saving final model")
+            trainer.save_model()
+
+            # Push to hub if enabled
+            if model_config.get("huggingface_hub", {}).get("push_to_hub", False):
+                logger.info("Pushing model to Hugging Face Hub")
            trainer.push_to_hub()
-
-
-            logger.info(f"Saving model to {model_config.get('output_dir', './results')}")
-            trainer.save_model()
-            logger.info("Model saved successfully")
+
+            return 0
        except Exception as e:
            logger.error(f"Training failed with error: {str(e)}")
+            # Log CUDA memory info if available
+            if torch.cuda.is_available():
+                for i in range(torch.cuda.device_count()):
+                    logger.info(f"GPU {i} Memory Allocated: {torch.cuda.memory_allocated(i) / 1024**2:.2f} MB")
+                    logger.info(f"GPU {i} Memory Reserved: {torch.cuda.memory_reserved(i) / 1024**2:.2f} MB")
            raise
-
-        # Remove the lock file when training completes or fails
-        if os.path.exists(lock_file):
-            os.remove(lock_file)
-            logger.info(f"Removed lock file: {lock_file}")
-
-        return 0
-
+
    except Exception as e:
        logger.error(f"Error in main training loop: {str(e)}")
        return 1