Upload 2 files
Browse files
- finetune.py +111 -0
- optimize_lr.py +401 -0
finetune.py
ADDED
@@ -0,0 +1,111 @@
+import os
+
+import torch
+from datasets import load_dataset
+from transformers import (
+    AutoTokenizer, AutoModelForCausalLM, TrainingArguments,
+    Trainer, DataCollatorForLanguageModeling
+)
+
+CONTEXT_WINDOW = 1024  # has to fit on a 4090
+HF_TOKEN = os.getenv("HF_TOKEN")
+
+# Set up the tokenizer
+tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-1.2B-instruct", token=HF_TOKEN)
+if tokenizer.pad_token is None:
+    tokenizer.pad_token = tokenizer.eos_token
+tokenizer.padding_side = "left"  # better for inference
+
+# Initialize the model with automatic device mapping
+model = AutoModelForCausalLM.from_pretrained(
+    "Zyphra/Zamba2-1.2B-instruct",
+    torch_dtype=torch.bfloat16,
+    device_map="auto"  # handles multi-GPU/CPU placement
+)
+model.config.pad_token_id = tokenizer.pad_token_id
+
+# Load the Dutch Dolly dataset
+dataset = load_dataset("BramVanroy/dolly-15k-dutch", split="train_sft")
+
+def prepare_chat_format(examples):
+    chats = []
+    for messages in examples["messages"]:
+        try:
+            chat = tokenizer.apply_chat_template(
+                messages,
+                tokenize=True,
+                max_length=CONTEXT_WINDOW,
+                truncation=True,
+                return_tensors=None
+            )
+        except Exception as e:
+            print(f"Error applying chat template: {e}")
+            # Fall back to a manual chat format
+            text = ""
+            for message in messages:
+                role = message["role"]
+                content = message["content"]
+                text += f"<|{role}|>\n{content}</s>\n"
+
+            chat = tokenizer(
+                text,
+                max_length=CONTEXT_WINDOW,
+                truncation=True,
+                return_tensors=None
+            )["input_ids"]
+
+        chats.append(chat)
+    return {"input_ids": chats}
+
+# Tokenize the dataset
+tokenized_dataset = dataset.map(
+    prepare_chat_format,
+    batched=True,
+    remove_columns=dataset.column_names
+)
+
+# Training configuration
+training_args = TrainingArguments(
+    output_dir="./zamba2-finetuned",
+    num_train_epochs=2,
+    per_device_train_batch_size=4,
+    save_steps=500,
+    save_total_limit=2,
+    logging_steps=100,
+    learning_rate=2e-5,
+    weight_decay=0.01,
+    fp16=False,
+    bf16=True,
+    gradient_accumulation_steps=8,
+    dataloader_num_workers=4,
+    gradient_checkpointing=True,
+    max_grad_norm=1.0,
+    warmup_steps=100
+)
+
+data_collator = DataCollatorForLanguageModeling(
+    tokenizer=tokenizer,
+    mlm=False  # causal LM objective, not masked LM
+)
+
+# Custom trainer that skips device placement: the model has already been
+# dispatched across devices by device_map="auto"
+class CustomTrainer(Trainer):
+    def _move_model_to_device(self, model, device):
+        pass  # model already mapped to devices
+
+trainer = CustomTrainer(
+    model=model,
+    args=training_args,
+    train_dataset=tokenized_dataset,
+    data_collator=data_collator
+)
+
+# Train and save the final model
+trainer.train()
+model.save_pretrained("./zamba2-finetuned-final")
+tokenizer.save_pretrained("./zamba2-finetuned-final")
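For context, a minimal sketch (not part of the commit) of loading the checkpoint that finetune.py writes and generating from it; the checkpoint path comes from the script above, but the Dutch prompt, greedy decoding, and max_new_tokens value are illustrative assumptions.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the finetuned checkpoint saved by finetune.py
tokenizer = AutoTokenizer.from_pretrained("./zamba2-finetuned-final")
model = AutoModelForCausalLM.from_pretrained(
    "./zamba2-finetuned-final",
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

# Build a chat prompt with the tokenizer's chat template
messages = [{"role": "user", "content": "Wat is de hoofdstad van Nederland?"}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

# Greedy decoding; sampling settings are a matter of taste
output_ids = model.generate(input_ids, max_new_tokens=128, do_sample=False)
print(tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True))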
optimize_lr.py
ADDED
@@ -0,0 +1,401 @@
+import gc
+import json
+
+import matplotlib.pyplot as plt
+import numpy as np
+import optuna
+import torch
+from datasets import load_dataset
+from scipy.stats import norm
+from sklearn.gaussian_process import GaussianProcessRegressor
+from sklearn.gaussian_process.kernels import ConstantKernel, Matern
+from transformers import (
+    AutoTokenizer, AutoModelForCausalLM, TrainingArguments,
+    Trainer, DataCollatorForLanguageModeling, TrainerCallback
+)
+
+import warnings
+warnings.filterwarnings('ignore', category=UserWarning)
+
+# Configuration parameters
+num_trials = 10  # Adjust this value to control the number of optimization trials
+DATASET = load_dataset("BramVanroy/dolly-15k-dutch", split="train_sft[:1000]")
+CONTEXT_WINDOW = 1024
+
+# Initialize the tokenizer once; each trial gets a fresh model
+tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-1.2B")
+if tokenizer.pad_token is None:
+    tokenizer.pad_token = tokenizer.eos_token
+tokenizer.padding_side = "left"
+
+def prepare_chat_format(examples):
+    chats = []
+    for messages in examples["messages"]:
+        try:
+            chat = tokenizer.apply_chat_template(
+                messages,
+                tokenize=True,
+                max_length=CONTEXT_WINDOW,
+                truncation=True,
+                return_tensors=None
+            )
+            chats.append(chat)
+        except Exception as e:
+            print(f"Error applying chat template: {e}")
+            print("Falling back to a manual chat format")
+            text = ""
+            for message in messages:
+                role = message["role"]
+                content = message["content"]
+                text += f"<|{role}|>\n{content}</s>\n"
+
+            chat = tokenizer(
+                text,
+                max_length=CONTEXT_WINDOW,
+                truncation=True,
+                return_tensors=None
+            )["input_ids"]
+
+            chats.append(chat)
+    return {"input_ids": chats}
+
+# Tokenize the dataset once, shared across all trials
+tokenized_dataset = DATASET.map(
+    prepare_chat_format,
+    batched=True,
+    remove_columns=DATASET.column_names
+)
+
+def clear_memory():
+    """Clear GPU memory between trials."""
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+class LossCallback(TrainerCallback):
+    """Collect the training losses that the Trainer logs."""
+    def __init__(self):
+        self.losses = []
+
+    def on_log(self, args, state, control, logs=None, **kwargs):
+        if logs is not None and "loss" in logs:
+            self.losses.append(logs["loss"])
+
+def objective(trial):
+    # Clear memory left over from the previous trial
+    clear_memory()
+
+    lr = trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True)
+
+    # Initialize the model from the same seed for every trial
+    torch.manual_seed(42)
+    model = AutoModelForCausalLM.from_pretrained(
+        "Zyphra/Zamba2-1.2B",
+        torch_dtype=torch.bfloat16,
+        device_map="auto"
+    )
+    model.config.pad_token_id = tokenizer.pad_token_id
+
+    # Effective batch size of 32: batch size 4 with 8 accumulation steps
+    batch_size = 4
+    grad_accum_steps = 8
+    effective_batch_size = batch_size * grad_accum_steps
+    total_steps = len(tokenized_dataset) // effective_batch_size
+
+    # Training arguments
+    training_args = TrainingArguments(
+        output_dir=f"./optuna_runs/trial_{trial.number}",
+        num_train_epochs=1,
+        per_device_train_batch_size=batch_size,
+        gradient_accumulation_steps=grad_accum_steps,
+        logging_steps=max(total_steps // 20, 1),
+        learning_rate=lr,
+        weight_decay=0.01,
+        fp16=False,
+        bf16=True,
+        warmup_steps=total_steps // 10,
+        save_steps=1000000,  # effectively disables checkpointing
+        save_total_limit=None,
+        report_to="none",
+        seed=42,
+        dataloader_num_workers=4,     # faster data loading
+        gradient_checkpointing=True,  # trade compute for memory
+        max_grad_norm=1.0             # clip gradients for stability
+    )
+
+    print(f"\nTrial {trial.number}:")
+    print(f"Learning rate: {lr}")
+    print(f"Total steps: {total_steps}")
+    print(f"Logging every {training_args.logging_steps} steps")
+
+    data_collator = DataCollatorForLanguageModeling(
+        tokenizer=tokenizer,
+        mlm=False
+    )
+
+    # Custom trainer that skips device placement: the model has already
+    # been dispatched across devices by device_map="auto"
+    class CustomTrainer(Trainer):
+        def _move_model_to_device(self, model, device):
+            pass
+
+    # Initialize the loss-collecting callback
+    loss_callback = LossCallback()
+
+    trainer = CustomTrainer(
+        model=model,
+        args=training_args,
+        train_dataset=tokenized_dataset,
+        data_collator=data_collator,
+        callbacks=[loss_callback]
+    )
+
+    try:
+        trainer.train()
+
+        # Score the trial by the mean of the last 20% of logged losses
+        losses = loss_callback.losses
+        n_losses = max(len(losses) // 5, 1)
+        final_losses = losses[-n_losses:]
+        mean_loss = np.mean(final_losses) if final_losses else float('inf')
+
+        # Clean up
+        del model
+        del trainer
+        clear_memory()
+
+        return mean_loss
+
+    except Exception as e:
+        print(f"Trial failed with error: {e}")
+        # Clean up on failure
+        del model
+        del trainer
+        clear_memory()
+        return float('inf')
+
+# Create and run the study
+study = optuna.create_study(
+    direction="minimize",
+    sampler=optuna.samplers.TPESampler(seed=42),
+    study_name="learning_rate_optimization"
+)
+
+study.optimize(objective, n_trials=num_trials)
+
+# Print results
+print(f"\nOptimization Results ({num_trials} trials):")
+print("Best learning rate:", study.best_params["learning_rate"])
+print("Best loss:", study.best_value)
+print("\nAll trials:")
+for trial in study.trials:
+    print(f"Learning rate: {trial.params['learning_rate']:.2e}, Loss: {trial.value:.4f}")
+
+# Save results
+results = {
+    "best_learning_rate": study.best_params["learning_rate"],
+    "best_loss": study.best_value,
+    "all_trials": [(trial.params["learning_rate"], trial.value) for trial in study.trials]
+}
+with open("lr_optimization_results.json", "w") as f:
+    json.dump(results, f, indent=4)
+
+# Plot the optimization history
+try:
+    fig = optuna.visualization.plot_optimization_history(study)
+    fig.show()
+except Exception as e:
+    print(f"Could not create visualization: {e}")
+
+# Refine the result with Gaussian process regression over the trials
+def optimize_final_lr(study):
+    try:
+        # Extract learning rates and losses
+        X = np.array([[trial.params['learning_rate']] for trial in study.trials])
+        y = np.array([trial.value for trial in study.trials])
+
+        # Check whether we have any valid results
+        valid_mask = np.isfinite(y)
+        if not np.any(valid_mask):
+            print("No valid trials found. Returning default learning rate.")
+            return {
+                'gpr_optimal_lr': 2e-5,  # default fallback
+                'ei_optimal_lr': 2e-5,
+                'predicted_loss': float('inf'),
+                'uncertainty': float('inf')
+            }
+
+        # Filter out infinite values
+        X = X[valid_mask]
+        y = y[valid_mask]
+
+        # Make sure we have enough points for fitting
+        if len(X) < 2:
+            print("Not enough valid trials for GPR. Returning best observed value.")
+            best_idx = np.argmin(y)
+            return {
+                'gpr_optimal_lr': float(X[best_idx][0]),
+                'ei_optimal_lr': float(X[best_idx][0]),
+                'predicted_loss': float(y[best_idx]),
+                'uncertainty': float('inf')
+            }
+
+        # Work in log10 learning-rate space
+        X_log = np.log10(X)
+
+        # Normalize the losses
+        y_mean = np.mean(y)
+        y_std = np.std(y)
+        if y_std == 0:
+            y_std = 1
+        y_normalized = (y - y_mean) / y_std
+
+        # Define the kernel
+        kernel = ConstantKernel(1.0) * Matern(length_scale=1.0, nu=2.5)
+
+        # Fit the Gaussian process
+        gpr = GaussianProcessRegressor(
+            kernel=kernel,
+            n_restarts_optimizer=10,
+            random_state=42,
+            normalize_y=False  # we normalize manually
+        )
+
+        try:
+            gpr.fit(X_log, y_normalized)
+        except np.linalg.LinAlgError:
+            print("GPR fitting failed. Returning best observed value.")
+            best_idx = np.argmin(y)
+            return {
+                'gpr_optimal_lr': float(X[best_idx][0]),
+                'ei_optimal_lr': float(X[best_idx][0]),
+                'predicted_loss': float(y[best_idx]),
+                'uncertainty': float('inf')
+            }
+
+        # Predict on a fine grid of learning rates
+        X_pred_log = np.linspace(np.log10(X.min()), np.log10(X.max()), 1000).reshape(-1, 1)
+        y_pred_normalized, sigma = gpr.predict(X_pred_log, return_std=True)
+
+        # Denormalize the predictions
+        y_pred = y_pred_normalized * y_std + y_mean
+        sigma = sigma * y_std
+
+        # Learning rate with the lowest predicted loss
+        best_idx = np.argmin(y_pred)
+        optimal_lr = 10 ** X_pred_log[best_idx, 0]
+
+        # Expected Improvement acquisition function
+        best_f = np.min(y)
+        Z = (best_f - y_pred) / (sigma + 1e-9)  # small constant prevents division by zero
+        ei = sigma * (Z * norm.cdf(Z) + norm.pdf(Z))
+
+        # Learning rate with the highest expected improvement
+        ei_best_idx = np.argmax(ei)
+        ei_optimal_lr = 10 ** X_pred_log[ei_best_idx, 0]
+
+        return {
+            'gpr_optimal_lr': float(optimal_lr),
+            'ei_optimal_lr': float(ei_optimal_lr),
+            'predicted_loss': float(y_pred[best_idx]),
+            'uncertainty': float(sigma[best_idx])
+        }
+
+    except Exception as e:
+        print(f"Optimization failed with error: {e}")
+        return {
+            'gpr_optimal_lr': 2e-5,  # default fallback
+            'ei_optimal_lr': 2e-5,
+            'predicted_loss': float('inf'),
+            'uncertainty': float('inf')
+        }
+
+# Run the final optimization, falling back to defaults on failure
+try:
+    final_optimization = optimize_final_lr(study)
+    print("\nAdvanced Optimization Results:")
+    print(f"GPR Optimal Learning Rate: {final_optimization['gpr_optimal_lr']:.2e}")
+    print(f"Expected Improvement Optimal Learning Rate: {final_optimization['ei_optimal_lr']:.2e}")
+    print(f"Predicted Loss: {final_optimization['predicted_loss']:.4f}")
+    print(f"Uncertainty: {final_optimization['uncertainty']:.4f}")
+except Exception as e:
+    print(f"Final optimization failed: {e}")
+    final_optimization = {
+        'gpr_optimal_lr': 2e-5,
+        'ei_optimal_lr': 2e-5,
+        'predicted_loss': float('inf'),
+        'uncertainty': float('inf')
+    }
+
+# Extend the saved results
+results.update({
+    "gpr_optimal_lr": float(final_optimization['gpr_optimal_lr']),
+    "ei_optimal_lr": float(final_optimization['ei_optimal_lr']),
+    "predicted_loss": float(final_optimization['predicted_loss']),
+    "uncertainty": float(final_optimization['uncertainty'])
+})
+
+# Visualize the GPR fit
+def plot_gpr_results(study, final_optimization):
+    # Extract data and drop infinite losses
+    X = np.array([[trial.params['learning_rate']] for trial in study.trials])
+    y = np.array([trial.value for trial in study.trials])
+
+    finite_mask = np.isfinite(y)
+    X = X[finite_mask]
+    y = y[finite_mask]
+
+    # Check that we have enough valid points
+    if len(X) < 2:
+        print("Not enough valid points for GPR visualization")
+        return
+
+    # Prediction grid
+    X_pred = np.logspace(np.log10(X.min()), np.log10(X.max()), 100).reshape(-1, 1)
+    X_pred_log = np.log10(X_pred)
+
+    # Fit a GPR for plotting
+    kernel = ConstantKernel(1.0) * Matern(length_scale=1.0, nu=2.5)
+    gpr = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=10, random_state=42)
+    gpr.fit(np.log10(X), y)
+
+    # Predict mean and standard deviation
+    y_pred, sigma = gpr.predict(X_pred_log, return_std=True)
+
+    plt.figure(figsize=(12, 6))
+    plt.semilogx(X, y, 'ko', label='Valid Trials', markersize=8)
+    plt.semilogx(X_pred, y_pred, 'b-', label='GPR Mean')
+    plt.fill_between(X_pred.ravel(),
+                     y_pred - 2 * sigma,
+                     y_pred + 2 * sigma,
+                     color='blue',
+                     alpha=0.2,
+                     label='95% Confidence')
+
+    # Only plot the optimal lines if they are finite
+    if np.isfinite(final_optimization['gpr_optimal_lr']):
+        plt.axvline(final_optimization['gpr_optimal_lr'], color='r', linestyle='--',
+                    label='GPR Optimal LR')
+    if np.isfinite(final_optimization['ei_optimal_lr']):
+        plt.axvline(final_optimization['ei_optimal_lr'], color='g', linestyle='--',
+                    label='EI Optimal LR')
+
+    plt.xlabel('Learning Rate')
+    plt.ylabel('Loss')
+    plt.title('Learning Rate Optimization Results with GPR')
+    plt.legend()
+    plt.grid(True)
+    plt.savefig('lr_optimization_plot.png', dpi=300, bbox_inches='tight')
+    plt.close()
+
+plot_gpr_results(study, final_optimization)
+
+# Save all results, including the GPR refinement
+with open("lr_optimization_results.json", "w") as f:
+    json.dump(results, f, indent=4)
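The acquisition function in optimize_final_lr is the standard Expected Improvement for minimization: with GPR predictive mean mu and standard deviation sigma at a candidate point and incumbent best loss f_best, Z = (f_best - mu) / sigma and EI = sigma * (Z * Phi(Z) + phi(Z)), where Phi and phi are the standard normal CDF and PDF. A self-contained numeric sketch, not part of the commit; the function name and example values are assumptions:

import numpy as np
from scipy.stats import norm

def expected_improvement(mu, sigma, f_best):
    """Expected Improvement for minimization; larger means a more promising point."""
    z = (f_best - mu) / (sigma + 1e-9)  # same epsilon guard as optimize_lr.py
    return sigma * (z * norm.cdf(z) + norm.pdf(z))

# A point predicted slightly worse than the incumbent but with high
# uncertainty still has meaningful EI; a confident one does not:
print(expected_improvement(mu=2.1, sigma=0.5, f_best=2.0))   # ~0.15
print(expected_improvement(mu=2.1, sigma=0.01, f_best=2.0))  # ~0

This is why the script reports both an ei_optimal_lr and a gpr_optimal_lr: the former balances predicted loss against uncertainty, the latter simply minimizes the predicted mean.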