Twelve2five commited on
Commit
192b89f
·
verified ·
1 Parent(s): 19ba848

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -80
app.py CHANGED
@@ -24,6 +24,7 @@ import subprocess
24
  import sys
25
  import json
26
  import shutil
 
27
 
28
  # --- Configuration ---
29
  YOUR_HF_USERNAME = "Twelve2five"
@@ -653,95 +654,108 @@ def train_model(
653
  downloaded_files = glob.glob(f"{local_dataset_path}/**/*.pt", recursive=True)
654
  log.append(f"Found {len(downloaded_files)} .pt files in the dataset directory")
655
 
656
- if len(downloaded_files) == 0:
657
- log.append("No .pt files found. Checking for other file types...")
658
- all_files = glob.glob(f"{local_dataset_path}/**/*.*", recursive=True)
659
- log.append(f"All files found: {', '.join(all_files[:10])}")
660
- if len(all_files) > 10:
661
- log.append(f"...and {len(all_files) - 10} more files")
662
-
663
- # Look for the pairs directory
664
  pairs_dir = os.path.join(local_dataset_path, "final_rvq_pairs")
665
- if not os.path.exists(pairs_dir):
666
- log.append(f"final_rvq_pairs directory not found. Looking for other possible directories...")
667
- possible_dirs = [d for d in glob.glob(f"{local_dataset_path}/**/") if os.path.isdir(d)]
668
- log.append(f"Available directories: {', '.join(possible_dirs)}")
669
-
670
- # Try to find any directory containing .pt files
671
- for dir_path in possible_dirs:
672
- if glob.glob(f"{dir_path}/*.pt"):
673
- pairs_dir = dir_path
674
- log.append(f"Using {pairs_dir} as the pairs directory.")
675
- break
676
 
677
- # If we found the pairs directory, we're good to go
678
- if pairs_dir and os.path.exists(pairs_dir):
679
- log.append(f"Using pairs directory: {pairs_dir}")
680
- pt_files = glob.glob(f"{pairs_dir}/*.pt")
681
- log.append(f"Found {len(pt_files)} .pt files in pairs directory")
682
 
683
- # Load the dataset from the files
684
- progress(0.5, desc="Loading pairs from dataset files...")
685
- log.append("Loading dataset pairs...")
 
 
 
 
686
 
687
- try:
688
- # Load pairs from .pt files
689
- pairs = []
690
- for pt_file in tqdm(pt_files, desc="Loading .pt files"):
691
- pair_data = torch.load(pt_file)
692
- pairs.append(pair_data)
693
-
694
- log.append(f"Loaded {len(pairs)} conversation pairs")
695
-
696
- # Create a dataset from the pairs
697
- dataset = Dataset.from_dict({
698
- "input_ids": [pair[0].tolist() for pair in pairs],
699
- "labels": [pair[1].tolist() for pair in pairs]
700
- })
701
-
702
- # Split into training and validation sets
703
- train_test_split = dataset.train_test_split(test_size=0.05)
704
- train_dataset = train_test_split["train"]
705
-
706
- log.append(f"Created dataset with {len(train_dataset)} training examples")
707
 
708
- except Exception as e:
709
- log.append(f"Error loading pair data: {e}")
710
-
711
- # Try an alternative approach - look for JSON or other formats
712
- log.append("Attempting alternative dataset loading approaches...")
713
 
714
- # Search for JSON files
715
- json_files = glob.glob(f"{local_dataset_path}/**/*.json", recursive=True)
716
- if json_files:
717
- log.append(f"Found {len(json_files)} JSON files. Trying to load from these...")
718
-
719
- # Load from JSON
720
- combined_data = []
721
- for json_file in json_files[:5]: # Start with a few files
722
- try:
723
- with open(json_file, 'r') as f:
724
- file_data = json.load(f)
725
- log.append(f"Successfully loaded {json_file}")
726
- # Print sample of the data structure
727
- log.append(f"Sample data structure: {str(file_data)[:500]}...")
728
- combined_data.append(file_data)
729
- except Exception as je:
730
- log.append(f"Error loading {json_file}: {je}")
731
-
732
- # If we loaded any data, try to create a dataset from it
733
- if combined_data:
734
- log.append("Attempting to create dataset from JSON data...")
735
- # This will need adapting based on the actual JSON structure
 
 
 
 
736
  else:
737
- log.append("No JSON files found. Looking for other formats...")
738
- # Add code for other formats if needed
 
 
 
 
 
 
 
 
 
 
 
 
739
 
740
- log.append("Failed to load dataset after multiple attempts.")
741
- return "\n".join(log)
 
742
 
743
- else:
744
- log.append("Could not locate pairs directory or any directory with .pt files.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
745
  return "\n".join(log)
746
 
747
  except Exception as e:
 
24
  import sys
25
  import json
26
  import shutil
27
+ import traceback
28
 
29
  # --- Configuration ---
30
  YOUR_HF_USERNAME = "Twelve2five"
 
654
  downloaded_files = glob.glob(f"{local_dataset_path}/**/*.pt", recursive=True)
655
  log.append(f"Found {len(downloaded_files)} .pt files in the dataset directory")
656
 
657
+ # Look for the pairs directory (we know this exists from the log)
 
 
 
 
 
 
 
658
  pairs_dir = os.path.join(local_dataset_path, "final_rvq_pairs")
659
+ log.append(f"Using pairs directory: {pairs_dir}")
660
+ pt_files = glob.glob(f"{pairs_dir}/*.pt")
661
+ log.append(f"Found {len(pt_files)} .pt files in pairs directory")
662
+
663
+ # Load the dataset from the files
664
+ progress(0.5, desc="Loading pairs from dataset files...")
665
+ log.append("Loading dataset pairs...")
 
 
 
 
666
 
667
+ try:
668
+ # Load a single file first to understand its structure
669
+ sample_file = pt_files[0]
670
+ sample_data = torch.load(sample_file)
671
+ log.append(f"Sample data type: {type(sample_data)}")
672
 
673
+ if isinstance(sample_data, dict):
674
+ log.append(f"Sample data is a dictionary with keys: {list(sample_data.keys())}")
675
+ # Print a few sample values to understand the structure
676
+ for key in list(sample_data.keys())[:3]:
677
+ log.append(f"Key '{key}' has value of type {type(sample_data[key])}")
678
+ if isinstance(sample_data[key], torch.Tensor):
679
+ log.append(f" Shape: {sample_data[key].shape}, Dtype: {sample_data[key].dtype}")
680
 
681
+ # Load all files with the appropriate structure
682
+ input_ids_list = []
683
+ labels_list = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
684
 
685
+ for pt_file in tqdm(pt_files, desc="Loading .pt files"):
686
+ data = torch.load(pt_file)
 
 
 
687
 
688
+ # Handling dictionary structure
689
+ if isinstance(data, dict):
690
+ # Assume dictionary contains input_ids and labels keys
691
+ if 'input_ids' in data and 'labels' in data:
692
+ input_ids_list.append(data['input_ids'])
693
+ labels_list.append(data['labels'])
694
+ # Or maybe it has other keys that we need to convert
695
+ elif 'prompt' in data and 'response' in data:
696
+ input_ids_list.append(data['prompt'])
697
+ labels_list.append(data['response'])
698
+ # Or maybe it has source and target keys
699
+ elif 'source' in data and 'target' in data:
700
+ input_ids_list.append(data['source'])
701
+ labels_list.append(data['target'])
702
+ # If none of these patterns match, try to figure out the structure
703
+ else:
704
+ log.append(f"Unknown dictionary structure in {pt_file} with keys: {list(data.keys())}")
705
+ # Try the first two keys as input/output
706
+ keys = list(data.keys())
707
+ if len(keys) >= 2:
708
+ input_ids_list.append(data[keys[0]])
709
+ labels_list.append(data[keys[1]])
710
+ # Handling tuple/list structure - the original expected format
711
+ elif isinstance(data, (tuple, list)) and len(data) >= 2:
712
+ input_ids_list.append(data[0])
713
+ labels_list.append(data[1])
714
  else:
715
+ log.append(f"Unsupported data format in {pt_file}: {type(data)}")
716
+
717
+ log.append(f"Processed {len(input_ids_list)} input/label pairs")
718
+
719
+ # Process tensors to ensure they're the right format
720
+ processed_inputs = []
721
+ processed_labels = []
722
+
723
+ for i, (inputs, labels) in enumerate(zip(input_ids_list, labels_list)):
724
+ # Convert to tensor if not already
725
+ if not isinstance(inputs, torch.Tensor):
726
+ inputs = torch.tensor(inputs)
727
+ if not isinstance(labels, torch.Tensor):
728
+ labels = torch.tensor(labels)
729
 
730
+ # Ensure they're integer tensors
731
+ inputs = inputs.long()
732
+ labels = labels.long()
733
 
734
+ # Append to lists, converting to standard Python lists for the Dataset
735
+ processed_inputs.append(inputs.tolist())
736
+ processed_labels.append(labels.tolist())
737
+
738
+ # Log some diagnostics for the first few pairs
739
+ if i < 3:
740
+ log.append(f"Pair {i}: Input shape: {inputs.shape}, Label shape: {labels.shape}")
741
+
742
+ # Create the dataset
743
+ log.append("Creating dataset from processed pairs...")
744
+ dataset = Dataset.from_dict({
745
+ "input_ids": processed_inputs,
746
+ "labels": processed_labels
747
+ })
748
+
749
+ # Split into training and validation
750
+ train_test_split = dataset.train_test_split(test_size=0.05)
751
+ train_dataset = train_test_split["train"]
752
+ val_dataset = train_test_split["test"]
753
+
754
+ log.append(f"Created dataset with {len(train_dataset)} training examples and {len(val_dataset)} validation examples")
755
+
756
+ except Exception as e:
757
+ error_msg = f"Error processing dataset: {str(e)}\n{traceback.format_exc()}"
758
+ log.append(error_msg)
759
  return "\n".join(log)
760
 
761
  except Exception as e: