George-API committed · verified
Commit b033a7b · 1 Parent(s): a843fab

Upload folder using huggingface_hub

Files changed (1):
  1. run_transformers_training.py +46 -107
run_transformers_training.py CHANGED
@@ -494,144 +494,84 @@ class SimpleDataCollator:
         self.stats = {"processed": 0, "skipped": 0, "total_tokens": 0}
         self.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
         self.max_seq_length = dataset_config.get("dataset", {}).get("processing", {}).get("max_seq_length", 2048)
-        logger.info(f"SimpleDataCollator initialized - using pre-audited dataset with max_seq_length={self.max_seq_length}")
-        logger.info("Using exact dataset structure without reformatting")
-
-        # Check if we're on GPU
+        logger.info(f"SimpleDataCollator initialized with max_seq_length={self.max_seq_length}")
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        logger.info(f"SimpleDataCollator using device: {self.device}")
-
+
     def __call__(self, features):
-        """Process examples preserving exact JSONL structure"""
         batch = {"input_ids": [], "attention_mask": [], "labels": []}

         for example in features:
             try:
                 # Get ID for logging
-                paper_id = example.get("article_id", example.get("id", "unknown"))
-
-                # Safely get conversations with explicit None check
-                raw_conversations = example.get("conversations")
-                if raw_conversations is None:
-                    logger.warning(f"Conversations is None for example {paper_id}")
+                paper_id = example.get("article_id", "unknown")
+
+                # Get conversations - we expect a list with a single dict containing 'content'
+                conversations = example.get("conversations", [])
+
+                # Skip if conversations is None or empty
+                if not conversations:
+                    logger.warning(f"Empty conversations for paper_id {paper_id}")
                     self.stats["skipped"] += 1
                     continue

-                # Ensure conversations is a list
-                if not isinstance(raw_conversations, list):
-                    logger.warning(f"Conversations is not a list for example {paper_id} (type: {type(raw_conversations)})")
+                # Get the first (and should be only) conversation item
+                conv_item = conversations[0] if conversations else None
+
+                # Skip if no valid conversation item
+                if not isinstance(conv_item, dict):
+                    logger.warning(f"Invalid conversation format for paper_id {paper_id}")
                     self.stats["skipped"] += 1
                     continue

-                # Check for empty conversations list
-                if not raw_conversations:
-                    logger.warning(f"Empty conversations list for example {paper_id}")
+                # Get the content directly
+                content = conv_item.get("content", "")
+
+                # Skip if no content
+                if not content:
+                    logger.warning(f"Empty content for paper_id {paper_id}")
                     self.stats["skipped"] += 1
                     continue

-                # Extract only the 'content' field from each conversation item
+                # Tokenize the content directly
                 try:
-                    # Convert conversations to the simple format with only content
-                    simplified_conversations = []
-                    for item in raw_conversations:
-                        # Skip None items
-                        if item is None:
-                            logger.warning(f"Skipping None conversation item in example {paper_id}")
-                            continue
-
-                        if isinstance(item, dict):
-                            # Get content with explicit None check
-                            content = item.get("content")
-                            if content is not None:
-                                simplified_conversations.append({"role": "user", "content": content})
-                            else:
-                                logger.warning(f"Skipping conversation item with None content in example {paper_id}")
-                        elif isinstance(item, str):
-                            # If it's just a string, treat it as content
-                            simplified_conversations.append({"role": "user", "content": item})
-                        else:
-                            logger.warning(f"Skipping invalid conversation item type: {type(item)} in example {paper_id}")
-
-                    # Skip if no valid conversations after filtering
-                    if not simplified_conversations:
-                        logger.warning(f"No valid conversations after filtering for example {paper_id}")
-                        self.stats["skipped"] += 1
-                        continue
-
-                    # Log the simplified content for debugging
-                    if len(simplified_conversations) > 0:
-                        first_content = simplified_conversations[0].get("content", "")
-                        if first_content:
-                            logger.debug(f"First content: {first_content[:50]}...")
-
-                    # Let tokenizer handle the simplified conversations
-                    try:
-                        inputs = self.tokenizer.apply_chat_template(
-                            simplified_conversations,
-                            return_tensors=None,
-                            add_generation_prompt=False
-                        )
-                    except Exception as chat_error:
-                        # Fallback if apply_chat_template fails
-                        logger.warning(f"Chat template application failed for example {paper_id}: {str(chat_error)}")
-
-                        # Create a basic representation of just the content
-                        conversation_text = ""
-                        for msg in simplified_conversations:
-                            if isinstance(msg, dict) and msg.get("content"):
-                                conversation_text += msg["content"] + "\n\n"
-
-                        if not conversation_text:
-                            logger.warning(f"No valid content to tokenize in example {paper_id}")
-                            self.stats["skipped"] += 1
-                            continue
-
-                        # Basic tokenization
-                        inputs = self.tokenizer(
-                            conversation_text,
-                            add_special_tokens=True,
-                            return_tensors=None
-                        )
-
-                    # Apply length cap if needed
-                    if self.max_seq_length > 0 and len(inputs) > self.max_seq_length:
-                        logger.warning(f"Example {paper_id} exceeds max_seq_length ({len(inputs)} > {self.max_seq_length})")
-                        inputs = inputs[:self.max_seq_length]
+                    inputs = self.tokenizer(
+                        content,
+                        add_special_tokens=True,
+                        return_tensors=None,
+                        truncation=True,
+                        max_length=self.max_seq_length
+                    )

-                    # Create attention mask (1 for all tokens)
-                    attention_mask = [1] * len(inputs)
+                    input_ids = inputs["input_ids"]
+                    attention_mask = inputs["attention_mask"]

-                    if len(inputs) > 0:
-                        # For causal language modeling, labels are the same as inputs
-                        labels = inputs.copy()
-
-                        batch["input_ids"].append(inputs)
+                    if len(input_ids) > 0:
+                        batch["input_ids"].append(input_ids)
                         batch["attention_mask"].append(attention_mask)
-                        batch["labels"].append(labels)
+                        batch["labels"].append(input_ids.copy())  # For causal LM, labels = input_ids

                         self.stats["processed"] += 1
-                        self.stats["total_tokens"] += len(inputs)
+                        self.stats["total_tokens"] += len(input_ids)
                     else:
-                        logger.warning(f"Empty inputs after tokenization for example {paper_id}")
+                        logger.warning(f"Empty tokenization output for paper_id {paper_id}")
                         self.stats["skipped"] += 1

                 except Exception as e:
-                    logger.warning(f"Error processing conversations in example {paper_id}: {str(e)}")
+                    logger.warning(f"Tokenization failed for paper_id {paper_id}: {str(e)}")
                     self.stats["skipped"] += 1
                     continue

             except Exception as e:
-                logger.warning(f"Error processing example: {str(e)[:100]}...")
-                logger.warning(f"Problematic example ID: {example.get('id', 'unknown')}")
+                logger.warning(f"Error processing example: {str(e)}")
                 self.stats["skipped"] += 1
                 continue

         if not batch["input_ids"]:
             logger.warning("Empty batch, returning dummy tensors")
             return {
-                "input_ids": torch.zeros((1, 1), dtype=torch.long),
-                "attention_mask": torch.zeros((1, 1), dtype=torch.long),
-                "labels": torch.zeros((1, 1), dtype=torch.long)
+                "input_ids": torch.zeros((1, 1), dtype=torch.long, device=self.device),
+                "attention_mask": torch.zeros((1, 1), dtype=torch.long, device=self.device),
+                "labels": torch.zeros((1, 1), dtype=torch.long, device=self.device)
             }

         # Pad the batch
@@ -642,17 +582,16 @@ class SimpleDataCollator:
             if padding_length > 0:
                 batch["input_ids"][i].extend([self.pad_token_id] * padding_length)
                 batch["attention_mask"][i].extend([0] * padding_length)
-                batch["labels"][i].extend([-100] * padding_length)
+                batch["labels"][i].extend([-100] * padding_length)  # -100 is the ignore index for loss

         # Convert to tensors
-        batch = {k: torch.tensor(v, dtype=torch.long) for k, v in batch.items()}
+        batch = {k: torch.tensor(v, dtype=torch.long, device=self.device) for k, v in batch.items()}

         # Log stats periodically
-        log_interval = self.dataset_config.get("validation", {}).get("log_interval", 100)
-        if self.stats["processed"] % log_interval == 0 and self.stats["processed"] > 0:
-            logger.info(f"Data collator stats: processed={self.stats['processed']}, "
+        if self.stats["processed"] % 100 == 0:
+            logger.info(f"Collator stats: processed={self.stats['processed']}, "
                         f"skipped={self.stats['skipped']}, "
-                        f"avg_tokens={self.stats['total_tokens']/self.stats['processed']:.1f}")
+                        f"avg_tokens={self.stats['total_tokens']/max(1, self.stats['processed']):.1f}")

         return batch
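Two details in the new version are worth calling out: labels are a copy of input_ids (standard causal-LM training), and padded label positions are filled with -100, which is PyTorch's default ignore_index for cross-entropy, so padding contributes nothing to the loss. A minimal standalone sketch of that masking behavior (toy shapes and values, not from this repo):

import torch
import torch.nn.functional as F

# Label positions set to -100 are skipped: F.cross_entropy defaults to
# ignore_index=-100, the same value the collator pads labels with.
logits = torch.randn(1, 4, 10)               # (batch, seq_len, vocab_size)
labels = torch.tensor([[2, 7, -100, -100]])  # last two positions are padding

loss = F.cross_entropy(logits.view(-1, 10), labels.view(-1))
print(loss)  # averaged over the two real label positions only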
 
 
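A quick smoke test of the simplified collator might look like the sketch below; the gpt2 checkpoint and the shape of dataset_config are assumptions inferred from the fields __init__ reads, not part of this commit:

from transformers import AutoTokenizer

# Hypothetical usage; assumes SimpleDataCollator(tokenizer, dataset_config),
# matching the attribute assignments visible in __init__.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
dataset_config = {"dataset": {"processing": {"max_seq_length": 2048}}}
collator = SimpleDataCollator(tokenizer, dataset_config)

features = [
    {"article_id": "demo-1",
     "conversations": [{"content": "A short passage of training text."}]},
]
batch = collator(features)
print(batch["input_ids"].shape)  # labels and attention_mask share this shape

One caveat on building the batch tensors on self.device: CUDA tensors cannot be created inside forked DataLoader worker processes, so this pattern effectively assumes num_workers=0; the more common convention is to return CPU tensors and let the Trainer move them.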