Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -24,6 +24,7 @@ import subprocess
|
|
24 |
import sys
|
25 |
import json
|
26 |
import shutil
|
|
|
27 |
|
28 |
# --- Configuration ---
|
29 |
YOUR_HF_USERNAME = "Twelve2five"
|
@@ -653,95 +654,108 @@ def train_model(
|
|
653 |
downloaded_files = glob.glob(f"{local_dataset_path}/**/*.pt", recursive=True)
|
654 |
log.append(f"Found {len(downloaded_files)} .pt files in the dataset directory")
|
655 |
|
656 |
-
|
657 |
-
log.append("No .pt files found. Checking for other file types...")
|
658 |
-
all_files = glob.glob(f"{local_dataset_path}/**/*.*", recursive=True)
|
659 |
-
log.append(f"All files found: {', '.join(all_files[:10])}")
|
660 |
-
if len(all_files) > 10:
|
661 |
-
log.append(f"...and {len(all_files) - 10} more files")
|
662 |
-
|
663 |
-
# Look for the pairs directory
|
664 |
pairs_dir = os.path.join(local_dataset_path, "final_rvq_pairs")
|
665 |
-
|
666 |
-
|
667 |
-
|
668 |
-
|
669 |
-
|
670 |
-
|
671 |
-
|
672 |
-
if glob.glob(f"{dir_path}/*.pt"):
|
673 |
-
pairs_dir = dir_path
|
674 |
-
log.append(f"Using {pairs_dir} as the pairs directory.")
|
675 |
-
break
|
676 |
|
677 |
-
|
678 |
-
|
679 |
-
|
680 |
-
|
681 |
-
log.append(f"
|
682 |
|
683 |
-
|
684 |
-
|
685 |
-
|
|
|
|
|
|
|
|
|
686 |
|
687 |
-
|
688 |
-
|
689 |
-
|
690 |
-
for pt_file in tqdm(pt_files, desc="Loading .pt files"):
|
691 |
-
pair_data = torch.load(pt_file)
|
692 |
-
pairs.append(pair_data)
|
693 |
-
|
694 |
-
log.append(f"Loaded {len(pairs)} conversation pairs")
|
695 |
-
|
696 |
-
# Create a dataset from the pairs
|
697 |
-
dataset = Dataset.from_dict({
|
698 |
-
"input_ids": [pair[0].tolist() for pair in pairs],
|
699 |
-
"labels": [pair[1].tolist() for pair in pairs]
|
700 |
-
})
|
701 |
-
|
702 |
-
# Split into training and validation sets
|
703 |
-
train_test_split = dataset.train_test_split(test_size=0.05)
|
704 |
-
train_dataset = train_test_split["train"]
|
705 |
-
|
706 |
-
log.append(f"Created dataset with {len(train_dataset)} training examples")
|
707 |
|
708 |
-
|
709 |
-
|
710 |
-
|
711 |
-
# Try an alternative approach - look for JSON or other formats
|
712 |
-
log.append("Attempting alternative dataset loading approaches...")
|
713 |
|
714 |
-
#
|
715 |
-
|
716 |
-
|
717 |
-
|
718 |
-
|
719 |
-
|
720 |
-
|
721 |
-
|
722 |
-
|
723 |
-
|
724 |
-
|
725 |
-
|
726 |
-
|
727 |
-
|
728 |
-
|
729 |
-
|
730 |
-
|
731 |
-
|
732 |
-
|
733 |
-
|
734 |
-
|
735 |
-
|
|
|
|
|
|
|
|
|
736 |
else:
|
737 |
-
log.append("
|
738 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
739 |
|
740 |
-
|
741 |
-
|
|
|
742 |
|
743 |
-
|
744 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
745 |
return "\n".join(log)
|
746 |
|
747 |
except Exception as e:
|
|
|
24 |
import sys
|
25 |
import json
|
26 |
import shutil
|
27 |
+
import traceback
|
28 |
|
29 |
# --- Configuration ---
|
30 |
YOUR_HF_USERNAME = "Twelve2five"
|
|
|
654 |
downloaded_files = glob.glob(f"{local_dataset_path}/**/*.pt", recursive=True)
|
655 |
log.append(f"Found {len(downloaded_files)} .pt files in the dataset directory")
|
656 |
|
657 |
+
# Look for the pairs directory (we know this exists from the log)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
658 |
pairs_dir = os.path.join(local_dataset_path, "final_rvq_pairs")
|
659 |
+
log.append(f"Using pairs directory: {pairs_dir}")
|
660 |
+
pt_files = glob.glob(f"{pairs_dir}/*.pt")
|
661 |
+
log.append(f"Found {len(pt_files)} .pt files in pairs directory")
|
662 |
+
|
663 |
+
# Load the dataset from the files
|
664 |
+
progress(0.5, desc="Loading pairs from dataset files...")
|
665 |
+
log.append("Loading dataset pairs...")
|
|
|
|
|
|
|
|
|
666 |
|
667 |
+
try:
|
668 |
+
# Load a single file first to understand its structure
|
669 |
+
sample_file = pt_files[0]
|
670 |
+
sample_data = torch.load(sample_file)
|
671 |
+
log.append(f"Sample data type: {type(sample_data)}")
|
672 |
|
673 |
+
if isinstance(sample_data, dict):
|
674 |
+
log.append(f"Sample data is a dictionary with keys: {list(sample_data.keys())}")
|
675 |
+
# Print a few sample values to understand the structure
|
676 |
+
for key in list(sample_data.keys())[:3]:
|
677 |
+
log.append(f"Key '{key}' has value of type {type(sample_data[key])}")
|
678 |
+
if isinstance(sample_data[key], torch.Tensor):
|
679 |
+
log.append(f" Shape: {sample_data[key].shape}, Dtype: {sample_data[key].dtype}")
|
680 |
|
681 |
+
# Load all files with the appropriate structure
|
682 |
+
input_ids_list = []
|
683 |
+
labels_list = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
684 |
|
685 |
+
for pt_file in tqdm(pt_files, desc="Loading .pt files"):
|
686 |
+
data = torch.load(pt_file)
|
|
|
|
|
|
|
687 |
|
688 |
+
# Handling dictionary structure
|
689 |
+
if isinstance(data, dict):
|
690 |
+
# Assume dictionary contains input_ids and labels keys
|
691 |
+
if 'input_ids' in data and 'labels' in data:
|
692 |
+
input_ids_list.append(data['input_ids'])
|
693 |
+
labels_list.append(data['labels'])
|
694 |
+
# Or maybe it has other keys that we need to convert
|
695 |
+
elif 'prompt' in data and 'response' in data:
|
696 |
+
input_ids_list.append(data['prompt'])
|
697 |
+
labels_list.append(data['response'])
|
698 |
+
# Or maybe it has source and target keys
|
699 |
+
elif 'source' in data and 'target' in data:
|
700 |
+
input_ids_list.append(data['source'])
|
701 |
+
labels_list.append(data['target'])
|
702 |
+
# If none of these patterns match, try to figure out the structure
|
703 |
+
else:
|
704 |
+
log.append(f"Unknown dictionary structure in {pt_file} with keys: {list(data.keys())}")
|
705 |
+
# Try the first two keys as input/output
|
706 |
+
keys = list(data.keys())
|
707 |
+
if len(keys) >= 2:
|
708 |
+
input_ids_list.append(data[keys[0]])
|
709 |
+
labels_list.append(data[keys[1]])
|
710 |
+
# Handling tuple/list structure - the original expected format
|
711 |
+
elif isinstance(data, (tuple, list)) and len(data) >= 2:
|
712 |
+
input_ids_list.append(data[0])
|
713 |
+
labels_list.append(data[1])
|
714 |
else:
|
715 |
+
log.append(f"Unsupported data format in {pt_file}: {type(data)}")
|
716 |
+
|
717 |
+
log.append(f"Processed {len(input_ids_list)} input/label pairs")
|
718 |
+
|
719 |
+
# Process tensors to ensure they're the right format
|
720 |
+
processed_inputs = []
|
721 |
+
processed_labels = []
|
722 |
+
|
723 |
+
for i, (inputs, labels) in enumerate(zip(input_ids_list, labels_list)):
|
724 |
+
# Convert to tensor if not already
|
725 |
+
if not isinstance(inputs, torch.Tensor):
|
726 |
+
inputs = torch.tensor(inputs)
|
727 |
+
if not isinstance(labels, torch.Tensor):
|
728 |
+
labels = torch.tensor(labels)
|
729 |
|
730 |
+
# Ensure they're integer tensors
|
731 |
+
inputs = inputs.long()
|
732 |
+
labels = labels.long()
|
733 |
|
734 |
+
# Append to lists, converting to standard Python lists for the Dataset
|
735 |
+
processed_inputs.append(inputs.tolist())
|
736 |
+
processed_labels.append(labels.tolist())
|
737 |
+
|
738 |
+
# Log some diagnostics for the first few pairs
|
739 |
+
if i < 3:
|
740 |
+
log.append(f"Pair {i}: Input shape: {inputs.shape}, Label shape: {labels.shape}")
|
741 |
+
|
742 |
+
# Create the dataset
|
743 |
+
log.append("Creating dataset from processed pairs...")
|
744 |
+
dataset = Dataset.from_dict({
|
745 |
+
"input_ids": processed_inputs,
|
746 |
+
"labels": processed_labels
|
747 |
+
})
|
748 |
+
|
749 |
+
# Split into training and validation
|
750 |
+
train_test_split = dataset.train_test_split(test_size=0.05)
|
751 |
+
train_dataset = train_test_split["train"]
|
752 |
+
val_dataset = train_test_split["test"]
|
753 |
+
|
754 |
+
log.append(f"Created dataset with {len(train_dataset)} training examples and {len(val_dataset)} validation examples")
|
755 |
+
|
756 |
+
except Exception as e:
|
757 |
+
error_msg = f"Error processing dataset: {str(e)}\n{traceback.format_exc()}"
|
758 |
+
log.append(error_msg)
|
759 |
return "\n".join(log)
|
760 |
|
761 |
except Exception as e:
|