Twelve2five commited on
Commit
14bbc11
·
verified ·
1 Parent(s): 192b89f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -60
app.py CHANGED
@@ -298,7 +298,7 @@ def load_model():
298
  log.append(f"LoRA rank: 8, alpha: 16 (optimized for 1B model)")
299
  model_to_train.print_trainable_parameters()
300
 
301
- return model, tokenizer # Return both model and tokenizer
302
 
303
  def load_dataset():
304
  # --- Download the dataset repository files ---
@@ -670,77 +670,108 @@ def train_model(
670
  sample_data = torch.load(sample_file)
671
  log.append(f"Sample data type: {type(sample_data)}")
672
 
673
- if isinstance(sample_data, dict):
674
- log.append(f"Sample data is a dictionary with keys: {list(sample_data.keys())}")
675
- # Print a few sample values to understand the structure
676
- for key in list(sample_data.keys())[:3]:
677
- log.append(f"Key '{key}' has value of type {type(sample_data[key])}")
678
- if isinstance(sample_data[key], torch.Tensor):
679
- log.append(f" Shape: {sample_data[key].shape}, Dtype: {sample_data[key].dtype}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
680
 
681
- # Load all files with the appropriate structure
682
  input_ids_list = []
683
  labels_list = []
684
 
685
- for pt_file in tqdm(pt_files, desc="Loading .pt files"):
686
- data = torch.load(pt_file)
687
-
688
- # Handling dictionary structure
689
- if isinstance(data, dict):
690
- # Assume dictionary contains input_ids and labels keys
691
- if 'input_ids' in data and 'labels' in data:
692
- input_ids_list.append(data['input_ids'])
693
- labels_list.append(data['labels'])
694
- # Or maybe it has other keys that we need to convert
695
- elif 'prompt' in data and 'response' in data:
696
- input_ids_list.append(data['prompt'])
697
- labels_list.append(data['response'])
698
- # Or maybe it has source and target keys
699
- elif 'source' in data and 'target' in data:
700
- input_ids_list.append(data['source'])
701
- labels_list.append(data['target'])
702
- # If none of these patterns match, try to figure out the structure
703
  else:
704
- log.append(f"Unknown dictionary structure in {pt_file} with keys: {list(data.keys())}")
705
- # Try the first two keys as input/output
706
- keys = list(data.keys())
707
- if len(keys) >= 2:
708
- input_ids_list.append(data[keys[0]])
709
- labels_list.append(data[keys[1]])
710
- # Handling tuple/list structure - the original expected format
711
- elif isinstance(data, (tuple, list)) and len(data) >= 2:
712
- input_ids_list.append(data[0])
713
- labels_list.append(data[1])
714
- else:
715
- log.append(f"Unsupported data format in {pt_file}: {type(data)}")
716
 
717
- log.append(f"Processed {len(input_ids_list)} input/label pairs")
 
 
 
 
 
 
718
 
719
- # Process tensors to ensure they're the right format
720
- processed_inputs = []
721
- processed_labels = []
722
 
 
 
723
  for i, (inputs, labels) in enumerate(zip(input_ids_list, labels_list)):
724
- # Convert to tensor if not already
725
- if not isinstance(inputs, torch.Tensor):
726
- inputs = torch.tensor(inputs)
727
- if not isinstance(labels, torch.Tensor):
728
- labels = torch.tensor(labels)
729
-
730
- # Ensure they're integer tensors
731
- inputs = inputs.long()
732
- labels = labels.long()
733
-
734
- # Append to lists, converting to standard Python lists for the Dataset
735
- processed_inputs.append(inputs.tolist())
736
- processed_labels.append(labels.tolist())
737
 
738
- # Log some diagnostics for the first few pairs
739
- if i < 3:
740
- log.append(f"Pair {i}: Input shape: {inputs.shape}, Label shape: {labels.shape}")
 
 
 
 
 
 
 
 
 
 
741
 
742
  # Create the dataset
743
- log.append("Creating dataset from processed pairs...")
 
 
 
 
744
  dataset = Dataset.from_dict({
745
  "input_ids": processed_inputs,
746
  "labels": processed_labels
@@ -754,6 +785,7 @@ def train_model(
754
  log.append(f"Created dataset with {len(train_dataset)} training examples and {len(val_dataset)} validation examples")
755
 
756
  except Exception as e:
 
757
  error_msg = f"Error processing dataset: {str(e)}\n{traceback.format_exc()}"
758
  log.append(error_msg)
759
  return "\n".join(log)
 
298
  log.append(f"LoRA rank: 8, alpha: 16 (optimized for 1B model)")
299
  model_to_train.print_trainable_parameters()
300
 
301
+ return model, tokenizer
302
 
303
  def load_dataset():
304
  # --- Download the dataset repository files ---
 
670
  sample_data = torch.load(sample_file)
671
  log.append(f"Sample data type: {type(sample_data)}")
672
 
673
+ # Function to recursively explore the data structure
674
+ def explore_data(data, prefix=""):
675
+ if isinstance(data, (list, tuple)):
676
+ log.append(f"{prefix}List/Tuple with {len(data)} items")
677
+ if len(data) > 0:
678
+ explore_data(data[0], prefix + " [0]: ")
679
+ elif isinstance(data, dict):
680
+ log.append(f"{prefix}Dictionary with keys: {list(data.keys())}")
681
+ for key in list(data.keys())[:2]: # Look at first 2 keys
682
+ explore_data(data[key], prefix + f" ['{key}']: ")
683
+ elif isinstance(data, torch.Tensor):
684
+ log.append(f"{prefix}Tensor with shape {data.shape} and dtype {data.dtype}")
685
+ else:
686
+ log.append(f"{prefix}Other type: {type(data)}")
687
+
688
+ # Explore the sample data
689
+ explore_data(sample_data, "Sample data: ")
690
+
691
+ # Function to extract tensor data from complex structures
692
+ def extract_tensor_data(data):
693
+ if isinstance(data, torch.Tensor):
694
+ return data
695
+ elif isinstance(data, (list, tuple)) and len(data) > 0:
696
+ if all(isinstance(item, (int, float)) for item in data):
697
+ return torch.tensor(data)
698
+ # For lists of tensors/complex structures, use the first item
699
+ return extract_tensor_data(data[0])
700
+ elif isinstance(data, dict):
701
+ # Try common keys for input data
702
+ for key in ['input_ids', 'prompt', 'source', 'inputs', 'data']:
703
+ if key in data:
704
+ return extract_tensor_data(data[key])
705
+ # If none found, use the first key
706
+ if len(data) > 0:
707
+ return extract_tensor_data(next(iter(data.values())))
708
+ return None
709
 
710
+ # Process all files
711
  input_ids_list = []
712
  labels_list = []
713
 
714
+ # Capture any errors for later analysis
715
+ file_errors = []
716
+
717
+ for i, pt_file in enumerate(tqdm(pt_files, desc="Loading .pt files")):
718
+ try:
719
+ data = torch.load(pt_file)
720
+
721
+ if isinstance(data, (list, tuple)) and len(data) >= 2:
722
+ # Standard format: list/tuple with [input, label]
723
+ input_tensor = extract_tensor_data(data[0])
724
+ label_tensor = extract_tensor_data(data[1])
725
+
726
+ if input_tensor is not None and label_tensor is not None:
727
+ input_ids_list.append(input_tensor)
728
+ labels_list.append(label_tensor)
729
+ else:
730
+ file_errors.append(f"Could not extract tensors from {pt_file}")
 
731
  else:
732
+ log.append(f"File {pt_file} has unexpected format. Skipping.")
733
+ file_errors.append(f"Unexpected format in {pt_file}: {type(data)}")
734
+ except Exception as e:
735
+ file_errors.append(f"Error processing file {pt_file}: {str(e)}")
 
 
 
 
 
 
 
 
736
 
737
+ # Log errors if any
738
+ if file_errors:
739
+ log.append(f"Encountered {len(file_errors)} errors during file processing:")
740
+ for i, error in enumerate(file_errors[:5]): # Log first 5 errors
741
+ log.append(f" Error {i+1}: {error}")
742
+ if len(file_errors) > 5:
743
+ log.append(f" ...and {len(file_errors) - 5} more errors")
744
 
745
+ log.append(f"Successfully processed {len(input_ids_list)} input/label pairs")
 
 
746
 
747
+ # Verify all tensors are valid
748
+ valid_pairs = []
749
  for i, (inputs, labels) in enumerate(zip(input_ids_list, labels_list)):
750
+ # Perform safety checks on tensors
751
+ if not isinstance(inputs, torch.Tensor) or not isinstance(labels, torch.Tensor):
752
+ log.append(f"Pair {i}: Invalid tensor types - skipping")
753
+ continue
 
 
 
 
 
 
 
 
 
754
 
755
+ # Ensure tensors contain integers
756
+ try:
757
+ inputs = inputs.long()
758
+ labels = labels.long()
759
+
760
+ # Convert to lists and add to valid pairs
761
+ valid_pairs.append((inputs.tolist(), labels.tolist()))
762
+
763
+ # Log some diagnostics for the first few pairs
764
+ if i < 3:
765
+ log.append(f"Pair {i}: Input shape: {inputs.shape}, Label shape: {labels.shape}")
766
+ except Exception as e:
767
+ log.append(f"Error converting tensors for pair {i}: {str(e)}")
768
 
769
  # Create the dataset
770
+ log.append(f"Creating dataset from {len(valid_pairs)} valid pairs...")
771
+
772
+ processed_inputs = [pair[0] for pair in valid_pairs]
773
+ processed_labels = [pair[1] for pair in valid_pairs]
774
+
775
  dataset = Dataset.from_dict({
776
  "input_ids": processed_inputs,
777
  "labels": processed_labels
 
785
  log.append(f"Created dataset with {len(train_dataset)} training examples and {len(val_dataset)} validation examples")
786
 
787
  except Exception as e:
788
+ import traceback
789
  error_msg = f"Error processing dataset: {str(e)}\n{traceback.format_exc()}"
790
  log.append(error_msg)
791
  return "\n".join(log)