import os import torch import glob import gc from transformers import ( AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, Trainer, DataCollatorForLanguageModeling, AutoTokenizer, LlamaConfig ) from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training from datasets import Dataset from huggingface_hub import snapshot_download from tqdm import tqdm import gradio as gr import math from accelerate import Accelerator import subprocess import sys import json # --- Configuration --- YOUR_HF_USERNAME = "Twelve2five" MODEL_REPO_NAME = "llama-3-8b-rvq-resized" DATASET_REPO_NAME = "podcast-dialogue-rvq-pairs-3items" hf_model_repo_id = f"{YOUR_HF_USERNAME}/{MODEL_REPO_NAME}" hf_dataset_repo_id = f"{YOUR_HF_USERNAME}/{DATASET_REPO_NAME}" # Output directories OUTPUT_TRAINING_DIR = "./llama3-8b-rvq-qlora-finetuned-run" LOGGING_DIR = "./llama3-8b-rvq-qlora-logs-run" local_download_path = "./downloaded_dataset_files" # Training parameters NUM_EPOCHS = 1 BATCH_SIZE_PER_DEVICE = 1 GRAD_ACCUMULATION_STEPS = 64 LEARNING_RATE = 1e-4 WEIGHT_DECAY = 0.01 WARMUP_RATIO = 0.03 LR_SCHEDULER = "cosine" OPTIMIZER = "paged_adamw_8bit" MAX_SEQ_LENGTH = 256 MICRO_BATCH_SIZE = 1 # Multi-GPU configuration accelerator = Accelerator() # Configure environment for multi-GPU os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32" # Print GPU information print(f"Available GPUs: {torch.cuda.device_count()}") for i in range(torch.cuda.device_count()): print(f"GPU {i}: {torch.cuda.get_device_name(i)} with {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB") def seq2seq_causal_collator(features): """ Collator that concatenates context (input_ids) and target (labels) for Causal LM sequence-to-sequence training. Masks the loss for the context part of the sequence. Pads sequences to the maximum length in the batch. """ batch = {} concatenated_input_ids = [] concatenated_labels = [] max_len = 0 # --- First pass: Concatenate, create masked labels, find max length --- for feature in features: # Dataset transform should provide tensors here input_ids = feature['input_ids'] labels = feature['labels'] # Ensure tensors are 1D (handle potential extra dims if any) if input_ids.dim() > 1: input_ids = input_ids.squeeze() if labels.dim() > 1: labels = labels.squeeze() context_len = input_ids.shape[0] target_len = labels.shape[0] # Concatenate context and target for input combined_ids = torch.cat([input_ids, labels], dim=0) concatenated_input_ids.append(combined_ids) # Create labels: -100 for context, actual labels for target masked_labels = torch.cat([ torch.full((context_len,), -100, dtype=torch.long, device=input_ids.device), labels ], dim=0) concatenated_labels.append(masked_labels) # Track max length for padding if combined_ids.shape[0] > max_len: max_len = combined_ids.shape[0] # --- Second pass: Pad to max length --- padded_input_ids = [] padded_labels = [] input_pad_token_id = 0 label_pad_token_id = -100 for i in range(len(features)): ids = concatenated_input_ids[i] lbls = concatenated_labels[i] padding_len = max_len - ids.shape[0] # Pad on the right side padded_input_ids.append(torch.nn.functional.pad( ids, (0, padding_len), value=input_pad_token_id )) padded_labels.append(torch.nn.functional.pad( lbls, (0, padding_len), value=label_pad_token_id )) # --- Stack and create final batch --- batch['input_ids'] = torch.stack(padded_input_ids) batch['labels'] = torch.stack(padded_labels) # Create attention mask (1 for real tokens, 0 for padding) batch['attention_mask'] = batch['input_ids'].ne(input_pad_token_id).long() return batch def prepare_for_dataset(batch): output = {'input_ids': [], 'labels': []} for item in batch: output['input_ids'].append(item['input_ids'].cpu().tolist()) output['labels'].append(item['labels'].cpu().tolist()) return output def load_model(): print(f"Loading base model architecture from: {hf_model_repo_id}") # Get information about GPU with most free memory gpu_id = 0 # Default to first GPU max_free_memory = 0 for i in range(torch.cuda.device_count()): free_memory = torch.cuda.get_device_properties(i).total_memory - torch.cuda.memory_allocated(i) if free_memory > max_free_memory: max_free_memory = free_memory gpu_id = i print(f"Loading model on GPU {gpu_id} with {max_free_memory / 1e9:.2f}GB free memory") # Configure quantization bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16 ) # Load the model try: # First update transformers to make sure we have latest version subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", "transformers"]) # Now try loading with explicit config class to avoid auto-detection issues from transformers import LlamaConfig # Load config first config = LlamaConfig.from_pretrained( hf_model_repo_id, trust_remote_code=True ) # Then load model with explicit config model = AutoModelForCausalLM.from_pretrained( hf_model_repo_id, config=config, quantization_config=bnb_config, device_map="auto", trust_remote_code=True ) log.append(f"Loaded model vocab size: {model.config.vocab_size}") log.append(f"Input embedding shape: {model.get_input_embeddings().weight.shape}") except Exception as e: error_msg = f"Error loading model from Hub: {e}" log.append(error_msg) # Try with a fallback method try: log.append("Attempting alternative loading method...") # Try loading without auto detection model = AutoModelForCausalLM.from_pretrained( hf_model_repo_id, quantization_config=bnb_config, device_map="auto", trust_remote_code=True, torch_dtype=torch.bfloat16, # Add these to help with the loading revision="main", low_cpu_mem_usage=True, ) log.append("Alternative loading successful!") log.append(f"Loaded model vocab size: {model.config.vocab_size}") except Exception as e2: log.append(f"Alternative loading also failed: {e2}") return "\n".join(log) # Load the official Meta tokenizer for LLaMA 3 tokenizer = AutoTokenizer.from_pretrained( "meta-llama/Llama-3-8B", # Use the official Meta tokenizer use_auth_token=os.environ.get("HF_TOKEN", None) # In case it's needed ) if tokenizer is None: # Fallback to another common foundation model tokenizer print("Falling back to another tokenizer as Meta tokenizer requires auth token") tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") print(f"Loaded tokenizer vocabulary size: {len(tokenizer)}") # Print information about input embeddings print(f"Input embedding shape: {model.get_input_embeddings().weight.shape}") # Prepare model for k-bit training model = prepare_model_for_kbit_training(model) # Define LoRA configuration lora_config = LoraConfig( r=16, lora_alpha=32, target_modules=[ "q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", ], lora_dropout=0.05, bias="none", task_type=TaskType.CAUSAL_LM ) # Apply LoRA to model model = get_peft_model(model, lora_config) model.print_trainable_parameters() return model, tokenizer # Return both model and tokenizer def load_dataset(): # --- Download the dataset repository files --- try: os.makedirs(local_download_path, exist_ok=True) downloaded_repo_root = snapshot_download( repo_id=hf_dataset_repo_id, repo_type="dataset", local_dir=local_download_path, local_dir_use_symlinks=False ) print(f"Dataset repository content downloaded to: {downloaded_repo_root}") except Exception as e: print(f"Error downloading dataset: {e}") return None # --- Load .pt files into a Hugging Face Dataset object --- pairs_dir = os.path.join(downloaded_repo_root, "final_rvq_pairs") all_pair_files = glob.glob(os.path.join(pairs_dir, "*_rvq_pairs.pt")) if not all_pair_files: all_pair_files = glob.glob(os.path.join(downloaded_repo_root, "*_rvq_pairs.pt")) if not all_pair_files: print("No RVQ pair files found!") return None print(f"Found {len(all_pair_files)} RVQ pair files.") # Load data from .pt files into memory all_data_pairs = [] for file_path in tqdm(all_pair_files, desc="Loading pair files"): try: episode_pairs = torch.load(file_path, map_location='cpu') all_data_pairs.extend(episode_pairs) except Exception as e: print(f"Warning: Could not load file {file_path}: {e}") if not all_data_pairs: return None print(f"Loaded {len(all_data_pairs)} training pairs.") # Convert to Hugging Face Dataset chunk_size = 1000 processed_data = {'input_ids': [], 'labels': []} for i in tqdm(range(0, len(all_data_pairs), chunk_size), desc="Preparing data"): batch = all_data_pairs[i:i + chunk_size] prepared_batch = prepare_for_dataset(batch) processed_data['input_ids'].extend(prepared_batch['input_ids']) processed_data['labels'].extend(prepared_batch['labels']) hf_dataset = Dataset.from_dict(processed_data) # Transform to get tensors back hf_dataset.set_transform(lambda batch: { 'input_ids': [torch.tensor(ids, dtype=torch.long) for ids in batch['input_ids']], 'labels': [torch.tensor(lbls, dtype=torch.long) for lbls in batch['labels']] }) # Cleanup del all_data_pairs del processed_data gc.collect() return hf_dataset # Memory cleaning function def clean_memory(): gc.collect() if torch.cuda.is_available(): for i in range(torch.cuda.device_count()): with torch.cuda.device(f'cuda:{i}'): torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() def train_model( hf_username, model_repo_name, dataset_repo_name, epochs=1, batch_size=1, grad_accum_steps=4, learning_rate=1e-4, progress=gr.Progress() ): progress(0, desc="Installing dependencies...") # Install required packages if needed try: import transformers import accelerate import bitsandbytes import peft import deepspeed except ImportError: subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "-U", "transformers", "accelerate", "bitsandbytes", "peft", "torch", "datasets", "huggingface_hub", "deepspeed"]) # --- Configuration --- progress(0.05, desc="Setting up configuration...") hf_model_repo_id = f"{hf_username}/{model_repo_name}" hf_dataset_repo_id = f"{hf_username}/{dataset_repo_name}" log = [] log.append(f"Model repo: {hf_model_repo_id}") log.append(f"Dataset repo: {hf_dataset_repo_id}") # Check if running on multiple GPUs n_gpus = torch.cuda.device_count() log.append(f"Number of GPUs available: {n_gpus}") # --- Quantization Configuration --- bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, ) # --- Load Base Model (with quantization) --- progress(0.1, desc="Loading base model...") try: # First update transformers to make sure we have latest version subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", "transformers"]) # Now try loading with explicit config class to avoid auto-detection issues from transformers import LlamaConfig # Load config first config = LlamaConfig.from_pretrained( hf_model_repo_id, trust_remote_code=True ) # Then load model with explicit config model = AutoModelForCausalLM.from_pretrained( hf_model_repo_id, config=config, quantization_config=bnb_config, device_map="auto", trust_remote_code=True ) log.append(f"Loaded model vocab size: {model.config.vocab_size}") log.append(f"Input embedding shape: {model.get_input_embeddings().weight.shape}") except Exception as e: error_msg = f"Error loading model from Hub: {e}" log.append(error_msg) # Try with a fallback method try: log.append("Attempting alternative loading method...") # Try loading without auto detection model = AutoModelForCausalLM.from_pretrained( hf_model_repo_id, quantization_config=bnb_config, device_map="auto", trust_remote_code=True, torch_dtype=torch.bfloat16, # Add these to help with the loading revision="main", low_cpu_mem_usage=True, ) log.append("Alternative loading successful!") log.append(f"Loaded model vocab size: {model.config.vocab_size}") except Exception as e2: log.append(f"Alternative loading also failed: {e2}") return "\n".join(log) # Load the official Meta tokenizer for LLaMA 3 tokenizer = AutoTokenizer.from_pretrained( "meta-llama/Llama-3-8B", # Use the official Meta tokenizer use_auth_token=os.environ.get("HF_TOKEN", None) # In case it's needed ) if tokenizer is None: # Fallback to another common foundation model tokenizer print("Falling back to another tokenizer as Meta tokenizer requires auth token") tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") print(f"Loaded tokenizer vocabulary size: {len(tokenizer)}") # Print information about input embeddings print(f"Input embedding shape: {model.get_input_embeddings().weight.shape}") # Prepare model for k-bit training model = prepare_model_for_kbit_training(model) # Define LoRA configuration lora_config = LoraConfig( task_type=TaskType.CAUSAL_LM, r=16, lora_alpha=32, lora_dropout=0.05, bias="none", target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] ) peft_model = get_peft_model(model, lora_config) trainable_params = peft_model.print_trainable_parameters() log.append(f"Trainable parameters: {trainable_params}") model_to_train = peft_model # Cleanup gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() # --- Load Dataset from Hub --- progress(0.2, desc="Downloading dataset...") local_download_path = "./downloaded_dataset_files" try: downloaded_repo_root = snapshot_download( repo_id=hf_dataset_repo_id, repo_type="dataset", local_dir=local_download_path, local_dir_use_symlinks=False ) log.append(f"Dataset repository content downloaded to: {downloaded_repo_root}") except Exception as e: error_msg = f"Error downloading dataset repository from Hub: {e}" log.append(error_msg) return "\n".join(log) # --- Find and load the .pt files --- progress(0.25, desc="Finding dataset files...") pairs_dir = os.path.join(downloaded_repo_root, "final_rvq_pairs") all_pair_files = glob.glob(os.path.join(pairs_dir, "*_rvq_pairs.pt")) if not all_pair_files: all_pair_files = glob.glob(os.path.join(downloaded_repo_root, "*_rvq_pairs.pt")) if not all_pair_files: error_msg = "No RVQ pair files found in expected directories" log.append(error_msg) return "\n".join(log) log.append(f"Found {len(all_pair_files)} RVQ pair files.") # --- Load data from .pt files --- progress(0.3, desc="Loading dataset files...") all_data_pairs = [] for i, file_path in enumerate(all_pair_files): progress(0.3 + (0.1 * i / len(all_pair_files)), desc=f"Loading file {i+1}/{len(all_pair_files)}") try: episode_pairs = torch.load(file_path, map_location='cpu') all_data_pairs.extend(episode_pairs) except Exception as e: log.append(f"Warning: Could not load file {file_path}: {e}") if not all_data_pairs: error_msg = "No valid data pairs were loaded" log.append(error_msg) return "\n".join(log) log.append(f"Loaded a total of {len(all_data_pairs)} training pairs into memory.") # --- Convert to HF Dataset --- progress(0.45, desc="Converting to Hugging Face Dataset...") def prepare_for_dataset(batch): output = {'input_ids': [], 'labels': []} for item in batch: output['input_ids'].append(item['input_ids'].cpu().tolist()) output['labels'].append(item['labels'].cpu().tolist()) return output chunk_size = 1000 processed_data = {'input_ids': [], 'labels': []} total_chunks = len(range(0, len(all_data_pairs), chunk_size)) for i in range(0, len(all_data_pairs), chunk_size): chunk_idx = i // chunk_size progress(0.45 + (0.1 * chunk_idx / total_chunks), desc=f"Processing chunk {chunk_idx+1}/{total_chunks}") batch = all_data_pairs[i:i + chunk_size] prepared_batch = prepare_for_dataset(batch) processed_data['input_ids'].extend(prepared_batch['input_ids']) processed_data['labels'].extend(prepared_batch['labels']) hf_dataset = Dataset.from_dict(processed_data) # Transform to get tensors back hf_dataset.set_transform(lambda batch: { 'input_ids': [torch.tensor(ids, dtype=torch.long) for ids in batch['input_ids']], 'labels': [torch.tensor(lbls, dtype=torch.long) for lbls in batch['labels']] }) train_dataset = hf_dataset # Cleanup del all_data_pairs del processed_data gc.collect() # --- Define Data Collator --- progress(0.55, desc="Defining data collator...") def seq2seq_causal_collator(features): batch = {} concatenated_input_ids = [] concatenated_labels = [] max_len = 0 # First pass: Concatenate, create masked labels, find max length for feature in features: input_ids = feature['input_ids'] labels = feature['labels'] if input_ids.dim() > 1: input_ids = input_ids.squeeze() if labels.dim() > 1: labels = labels.squeeze() context_len = input_ids.shape[0] target_len = labels.shape[0] combined_ids = torch.cat([input_ids, labels], dim=0) concatenated_input_ids.append(combined_ids) masked_labels = torch.cat([ torch.full((context_len,), -100, dtype=torch.long, device=input_ids.device), labels ], dim=0) concatenated_labels.append(masked_labels) if combined_ids.shape[0] > max_len: max_len = combined_ids.shape[0] # Second pass: Pad to max length padded_input_ids = [] padded_labels = [] input_pad_token_id = 0 label_pad_token_id = -100 for i in range(len(features)): ids = concatenated_input_ids[i] lbls = concatenated_labels[i] padding_len = max_len - ids.shape[0] padded_input_ids.append(torch.nn.functional.pad( ids, (0, padding_len), value=input_pad_token_id )) padded_labels.append(torch.nn.functional.pad( lbls, (0, padding_len), value=label_pad_token_id )) # Stack and create final batch batch['input_ids'] = torch.stack(padded_input_ids) batch['labels'] = torch.stack(padded_labels) batch['attention_mask'] = batch['input_ids'].ne(input_pad_token_id).long() return batch data_collator = seq2seq_causal_collator # --- Define Training Arguments and Initialize Trainer --- progress(0.65, desc="Setting up training configuration...") # Output directories OUTPUT_TRAINING_DIR = "./llama3-8b-rvq-qlora-finetuned-run" LOGGING_DIR = "./llama3-8b-rvq-qlora-logs-run" # Training parameters - adjusted for 4x T4 GPUs NUM_EPOCHS = int(epochs) BATCH_SIZE_PER_DEVICE = int(batch_size) # Smaller per-device batch size to avoid OOM GRAD_ACCUMULATION_STEPS = int(grad_accum_steps) LEARNING_RATE = float(learning_rate) WEIGHT_DECAY = 0.01 WARMUP_RATIO = 0.03 LR_SCHEDULER = "cosine" OPTIMIZER = "paged_adamw_8bit" # Calculate total steps and warmup steps # Total batch size is now batch_size × num_gpus × grad_accum_steps total_train_batch_size = BATCH_SIZE_PER_DEVICE * n_gpus * GRAD_ACCUMULATION_STEPS num_training_steps = math.ceil((len(train_dataset) * NUM_EPOCHS) / total_train_batch_size) num_warmup_steps = int(num_training_steps * WARMUP_RATIO) # Logging/Saving frequency steps_per_epoch = math.ceil(len(train_dataset) / total_train_batch_size) LOGGING_STEPS = max(10, steps_per_epoch // 15) SAVE_STEPS = max(50, steps_per_epoch // 10) log.append(f"Dataset size: {len(train_dataset)}") log.append(f"Number of GPUs: {n_gpus}") log.append(f"Batch size per device: {BATCH_SIZE_PER_DEVICE}") log.append(f"Gradient Accumulation steps: {GRAD_ACCUMULATION_STEPS}") log.append(f"Total train batch size (effective): {total_train_batch_size}") log.append(f"Total optimization steps: {num_training_steps}") log.append(f"Warmup steps: {num_warmup_steps}") # --- Create DeepSpeed configuration file --- progress(0.7, desc="Creating DeepSpeed configuration...") # DeepSpeed ZeRO-3 config optimized for T4 GPUs ds_config = { "fp16": { "enabled": "auto", "loss_scale": 0, "loss_scale_window": 1000, "initial_scale_power": 16, "hysteresis": 2, "min_loss_scale": 1 }, "bf16": { "enabled": "auto" }, "zero_optimization": { "stage": 3, "offload_optimizer": { "device": "cpu", "pin_memory": True }, "offload_param": { "device": "cpu", "pin_memory": True }, "overlap_comm": True, "contiguous_gradients": True, "reduce_bucket_size": "auto", "stage3_prefetch_bucket_size": "auto", "stage3_param_persistence_threshold": "auto", "gather_16bit_weights_on_model_save": True, "stage3_max_live_parameters": 1e9, "stage3_max_reuse_distance": 1e9 }, "gradient_accumulation_steps": GRAD_ACCUMULATION_STEPS, "gradient_clipping": "auto", "steps_per_print": 10, "train_batch_size": "auto", "train_micro_batch_size_per_gpu": "auto", "wall_clock_breakdown": False } with open("ds_config.json", "w") as f: json.dump(ds_config, f, indent=4) # Configure for multi-GPU training using DeepSpeed progress(0.75, desc="Setting up training arguments...") training_args = TrainingArguments( output_dir=OUTPUT_TRAINING_DIR, num_train_epochs=NUM_EPOCHS, per_device_train_batch_size=BATCH_SIZE_PER_DEVICE, gradient_accumulation_steps=GRAD_ACCUMULATION_STEPS, optim=OPTIMIZER, logging_dir=LOGGING_DIR, logging_strategy="steps", logging_steps=LOGGING_STEPS, save_strategy="steps", save_steps=SAVE_STEPS, save_total_limit=2, learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY, warmup_steps=num_warmup_steps, lr_scheduler_type=LR_SCHEDULER, report_to="tensorboard", bf16=True if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else False, gradient_checkpointing=True, gradient_checkpointing_kwargs={'use_reentrant': False}, # Multi-GPU specific settings deepspeed="ds_config.json", ddp_find_unused_parameters=False, ) # --- Initialize Trainer --- progress(0.8, desc="Initializing trainer...") trainer = Trainer( model=model_to_train, args=training_args, train_dataset=train_dataset, data_collator=data_collator, ) log.append("Trainer initialized with DeepSpeed for multi-GPU training.") # --- Start Training --- # Clear cache before starting gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() try: progress(0.85, desc="Starting training...") log.append("Starting distributed training on multiple GPUs...") train_result = trainer.train() progress(0.95, desc="Saving model...") # Save final model (adapter weights) and training state final_save_path = os.path.join(training_args.output_dir, "final_checkpoint") log.append(f"Saving final model checkpoint to {final_save_path}...") trainer.save_model(final_save_path) trainer.save_state() # Log metrics metrics = train_result.metrics trainer.log_metrics("train", metrics) trainer.save_metrics("train", metrics) for key, value in metrics.items(): log.append(f"{key}: {value}") except Exception as e: error_msg = f"An error occurred during training: {e}" log.append(error_msg) return "\n".join(log) progress(1.0, desc="Training complete!") log.append("Multi-GPU training process complete.") return "\n".join(log) # Define the Gradio interface def create_interface(): with gr.Blocks(title="Llama 3 8B RVQ Fine-tuning") as demo: gr.Markdown("# Llama 3 8B RVQ LoRA Fine-tuning") gr.Markdown("Fine-tune a Llama 3 8B model with RVQ token embeddings using LoRA on multiple GPUs") with gr.Row(): with gr.Column(): hf_username = gr.Textbox(label="HuggingFace Username", value="Twelve2five") model_repo = gr.Textbox(label="Model Repository Name", value="llama-3-8b-rvq-resized") dataset_repo = gr.Textbox(label="Dataset Repository Name", value="podcast-dialogue-rvq-pairs-3items") with gr.Column(): epochs = gr.Number(label="Number of Epochs", value=1, minimum=1, maximum=10) batch_size = gr.Number(label="Batch Size per Device", value=1, minimum=1, maximum=8) grad_accum = gr.Number(label="Gradient Accumulation Steps", value=4, minimum=1, maximum=16) lr = gr.Number(label="Learning Rate", value=1e-4) start_btn = gr.Button("Start Training") output = gr.Textbox(label="Training Log", lines=20) start_btn.click( fn=train_model, inputs=[hf_username, model_repo, dataset_repo, epochs, batch_size, grad_accum, lr], outputs=output ) return demo # Create and launch the interface demo = create_interface() if __name__ == "__main__": demo.launch()