George-API committed (verified)
Commit ae4e1de · 1 Parent(s): b033a7b

Upload folder using huggingface_hub

Files changed (1): run_transformers_training.py (+35 -43)
run_transformers_training.py CHANGED
@@ -494,7 +494,7 @@ class SimpleDataCollator:
         self.stats = {"processed": 0, "skipped": 0, "total_tokens": 0}
         self.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
         self.max_seq_length = dataset_config.get("dataset", {}).get("processing", {}).get("max_seq_length", 2048)
-        logger.info(f"SimpleDataCollator initialized with max_seq_length={self.max_seq_length}")
+        logger.info(f"SimpleDataCollator initialized - using pre-tokenized chunks with max_seq_length={self.max_seq_length}")
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
 
     def __call__(self, features):
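
Note: the chained .get() lookup for max_seq_length in __init__ above encodes the nested config shape the collator expects. Below is a minimal sketch of a dataset_config that satisfies it; the three key names come straight from the lookup, while everything else (including the sample value) is hypothetical:

    # Shape implied by the chained .get() calls; any extra keys are hypothetical.
    dataset_config = {
        "dataset": {
            "processing": {
                "max_seq_length": 2048,
            },
        },
    }

    # Each .get() supplies an empty dict (or the final 2048 default) when a
    # level is missing, so a partial or empty config still yields a usable value.
    max_seq_length = (
        dataset_config.get("dataset", {})
        .get("processing", {})
        .get("max_seq_length", 2048)
    )
    assert max_seq_length == 2048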
@@ -504,65 +504,57 @@ class SimpleDataCollator:
             try:
                 # Get ID for logging
                 paper_id = example.get("article_id", "unknown")
+                prompt_num = example.get("prompt_number", "unknown")
 
-                # Get conversations - we expect a list with a single dict containing 'content'
+                # Get the conversations list - should be a single item
                 conversations = example.get("conversations", [])
 
-                # Skip if conversations is None or empty
+                # Skip if no conversations
                 if not conversations:
-                    logger.warning(f"Empty conversations for paper_id {paper_id}")
+                    logger.warning(f"Empty conversations for paper_id {paper_id}, prompt {prompt_num}")
                     self.stats["skipped"] += 1
                     continue
 
-                # Get the first (and should be only) conversation item
-                conv_item = conversations[0] if conversations else None
+                # Get the first conversation item (should be the only one)
+                conv_item = conversations[0]
 
-                # Skip if no valid conversation item
-                if not isinstance(conv_item, dict):
-                    logger.warning(f"Invalid conversation format for paper_id {paper_id}")
+                # Skip if invalid format
+                if not isinstance(conv_item, dict) or "content" not in conv_item:
+                    logger.warning(f"Invalid conversation format for paper_id {paper_id}, prompt {prompt_num}")
                     self.stats["skipped"] += 1
                     continue
 
-                # Get the content directly
-                content = conv_item.get("content", "")
+                # Get the pre-tokenized content
+                content = conv_item["content"]
 
-                # Skip if no content
+                # Skip if empty content
                 if not content:
-                    logger.warning(f"Empty content for paper_id {paper_id}")
+                    logger.warning(f"Empty content for paper_id {paper_id}, prompt {prompt_num}")
                     self.stats["skipped"] += 1
                     continue
 
-                # Tokenize the content directly
-                try:
-                    inputs = self.tokenizer(
-                        content,
-                        add_special_tokens=True,
-                        return_tensors=None,
-                        truncation=True,
-                        max_length=self.max_seq_length
-                    )
-
-                    input_ids = inputs["input_ids"]
-                    attention_mask = inputs["attention_mask"]
-
-                    if len(input_ids) > 0:
-                        batch["input_ids"].append(input_ids)
-                        batch["attention_mask"].append(attention_mask)
-                        batch["labels"].append(input_ids.copy())  # For causal LM, labels = input_ids
-
-                        self.stats["processed"] += 1
-                        self.stats["total_tokens"] += len(input_ids)
-                    else:
-                        logger.warning(f"Empty tokenization output for paper_id {paper_id}")
-                        self.stats["skipped"] += 1
-
-                except Exception as e:
-                    logger.warning(f"Tokenization failed for paper_id {paper_id}: {str(e)}")
-                    self.stats["skipped"] += 1
-                    continue
-
+                # Create input IDs and attention mask directly from the content
+                # The content is already pre-tokenized and properly chunked
+                input_ids = self.tokenizer.encode(content, add_special_tokens=False)
+
+                # Truncate if needed
+                if len(input_ids) > self.max_seq_length:
+                    input_ids = input_ids[:self.max_seq_length]
+                    logger.warning(f"Truncated sequence for paper_id {paper_id}, prompt {prompt_num}")
+
+                # Create attention mask (1s for all tokens)
+                attention_mask = [1] * len(input_ids)
+
+                # Add to batch
+                batch["input_ids"].append(input_ids)
+                batch["attention_mask"].append(attention_mask)
+                batch["labels"].append(input_ids.copy())  # For causal LM, labels = input_ids
+
+                self.stats["processed"] += 1
+                self.stats["total_tokens"] += len(input_ids)
+
             except Exception as e:
-                logger.warning(f"Error processing example: {str(e)}")
+                logger.warning(f"Error processing example {paper_id}, prompt {prompt_num}: {str(e)}")
                 self.stats["skipped"] += 1
                 continue
 
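Note: for readers tracing the new control flow, this is a condensed, self-contained sketch of what the rewritten loop body does to a single feature dict. The "gpt2" tokenizer and all sample values are stand-ins, not the project's actual setup:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in tokenizer
    max_seq_length = 2048

    # Mirrors the fields the collator reads from each dataset example.
    example = {
        "article_id": "paper-0001",  # placeholder ID
        "prompt_number": 3,          # placeholder prompt index
        "conversations": [{"content": "Pre-chunked text for one training example."}],
    }

    content = example["conversations"][0]["content"]
    input_ids = tokenizer.encode(content, add_special_tokens=False)
    input_ids = input_ids[:max_seq_length]  # truncate if over the limit
    attention_mask = [1] * len(input_ids)   # all ones: no padding at this stage
    labels = input_ids.copy()               # causal LM: labels mirror input_ids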
 
 
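Note: the diff does not show how the collected Python lists become model-ready tensors. Since __init__ keeps self.pad_token_id around, a padding step along these lines presumably follows later in __call__; the function below is an assumed sketch of that step, not code from this repository:

    import torch

    def pad_batch(batch, pad_token_id):
        # Right-pad every sequence to the longest one in the batch (assumed step).
        max_len = max(len(ids) for ids in batch["input_ids"])
        out = {"input_ids": [], "attention_mask": [], "labels": []}
        for ids, mask, labels in zip(
            batch["input_ids"], batch["attention_mask"], batch["labels"]
        ):
            pad = max_len - len(ids)
            out["input_ids"].append(ids + [pad_token_id] * pad)
            out["attention_mask"].append(mask + [0] * pad)
            # -100 marks padded label positions so the cross-entropy loss ignores them.
            out["labels"].append(labels + [-100] * pad)
        return {k: torch.tensor(v, dtype=torch.long) for k, v in out.items()}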