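"""Configuration objects for training runs.

Defines `DataConfig`, `ModelConfig`, `OptimizerConfig`, and the composite
`ExperimentConfig`, each constructible from an argparse `Namespace` or a
plain dict, plus `get_config` for loading an `ExperimentConfig` from YAML.
"""
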
import os
from argparse import Namespace

import yaml

|
class DataConfig:
    def __init__(
        self,
        train_data_path: str,
        valid_data_path: str,
        batch_size: int,
        num_data_workers: int,
        prefetch_factor: int,
        time_delta_input_minutes: list[int],
        n_input_timestamps: int | None = None,
        pooling: int | None = None,
        random_vert_flip: bool = False,
        **kwargs,
    ):
        self.__dict__.update(kwargs)

        self.train_data_path = train_data_path
        self.valid_data_path = valid_data_path
        self.batch_size = batch_size
        self.num_data_workers = num_data_workers
        self.prefetch_factor = prefetch_factor
        self.time_delta_input_minutes = sorted(time_delta_input_minutes)
        self.n_input_timestamps = n_input_timestamps
        self.pooling = pooling
        self.random_vert_flip = random_vert_flip

        if self.n_input_timestamps is None:
            self.n_input_timestamps = len(self.time_delta_input_minutes)

        assert (
            self.n_input_timestamps > 0
        ), "Number of input timestamps must be greater than 0."
        assert self.n_input_timestamps <= len(self.time_delta_input_minutes), (
            f"Cannot sample {self.n_input_timestamps} timestamps from only "
            f"{len(self.time_delta_input_minutes)} input time deltas."
        )

    def to_dict(self):
        return self.__dict__

    @staticmethod
    def from_argparse(args: Namespace):
        return DataConfig(**args.__dict__)

    def __str__(self):
        return (
            f"Training index: {self.train_data_path}, "
            f"Validation index: {self.valid_data_path}"
        )

    def __repr__(self):
        return self.__str__()
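
# A minimal usage sketch for DataConfig; the paths and values below are
# hypothetical placeholders, not defaults shipped with this module:
#
#   data_cfg = DataConfig(
#       train_data_path="data/train_index.csv",
#       valid_data_path="data/valid_index.csv",
#       batch_size=8,
#       num_data_workers=4,
#       prefetch_factor=2,
#       time_delta_input_minutes=[-60, -30, 0],
#       n_input_timestamps=2,  # randomly sample 2 of the 3 time deltas
#   )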
|

class ModelConfig:
    def __init__(
        self,
        **kwargs,
    ):
        # Model hyperparameters (e.g. in_channels, enc_num_layers,
        # enc_embed_size, mlp_ratio) arrive via **kwargs and are stored
        # directly on the instance.
        self.__dict__.update(kwargs)

    def to_dict(self):
        return self.__dict__
|
    @staticmethod
    def from_argparse(args: Namespace):
        return ModelConfig(**args.__dict__)

    @property
    def encoder_d_ff(self):
        return int(self.enc_embed_size * self.mlp_ratio)

    @property
    def decoder_d_ff(self):
        return int(self.dec_embed_size * self.mlp_ratio)
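
    # Both feed-forward widths are derived rather than stored:
    # d_ff = int(embed_size * mlp_ratio). For example, with hypothetical
    # values enc_embed_size=1024 and mlp_ratio=4.0, encoder_d_ff == 4096.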
|
    def __str__(self):
        return (
            f"Input channels: {self.in_channels}, "
            f"Encoder (L, H, E): {[self.enc_num_layers, self.enc_num_heads, self.enc_embed_size]}, "
            f"Decoder (L, H, E): {[self.dec_num_layers, self.dec_num_heads, self.dec_embed_size]}"
        )

    def __repr__(self):
        return self.__str__()
|

class OptimizerConfig:
    def __init__(
        self,
        warm_up_steps: int,
        max_epochs: int,
        learning_rate: float,
        min_lr: float,
    ):
        self.warm_up_steps = warm_up_steps
        self.max_epochs = max_epochs
        self.learning_rate = learning_rate
        self.min_lr = min_lr

    def to_dict(self):
        return self.__dict__

    @staticmethod
    def from_argparse(args: Namespace):
        return OptimizerConfig(**args.__dict__)

    def __str__(self):
        return (
            f"Epochs: {self.max_epochs}, "
            f"LR: {[self.learning_rate, self.min_lr]}, "
            f"Warm up: {self.warm_up_steps}"
        )

    def __repr__(self):
        return self.__str__()
|

class ExperimentConfig:
    def __init__(
        self,
        job_id: str,
        data_config: DataConfig,
        model_config: ModelConfig,
        optimizer_config: OptimizerConfig,
        path_experiment: str,
        parallelism: str,
        from_checkpoint: str | None = None,
        **kwargs,
    ):
        self.__dict__.update(kwargs)

        self.job_id = job_id
        self.data = data_config
        self.model = model_config
        self.optimizer = optimizer_config
        self.path_experiment = path_experiment
        self.from_checkpoint = from_checkpoint
        self.parallelism = parallelism

        assert self.model.in_channels == len(self.data.channels), (
            f"Number of model input channels ({self.model.in_channels}) must "
            f"equal the number of input variables ({len(self.data.channels)})."
        )
        if self.model.time_embedding["type"] == "linear":
            assert (
                self.model.time_embedding["time_dim"] == self.data.n_input_timestamps
            ), "Time dimension of linear embedding must equal the number of input timestamps."
        if self.rollout_steps > 0:
            assert self.data.n_input_timestamps == len(
                self.data.time_delta_input_minutes
            ), "Rollout does not support randomly sampled input timestamps."
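
        # Collect every channel referenced by the metric definitions and check
        # that each one is a known data channel. A metric definition is a
        # ":"-separated string whose optional third field lists channels joined
        # by "&" (or the legacy "..." separator), e.g. a hypothetical
        # "mse:global:channel_a&channel_b".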
        metrics_channels = []
        for group_name, group_config in self.metrics["train_metrics_config"].items():
            for field, value in group_config.items():
                if field == "metrics":
                    for metric_definition in value:
                        split_metric_definition = metric_definition.split(":")
                        channels = (
                            split_metric_definition[2]
                            if len(split_metric_definition) > 2
                            else None
                        )
                        if channels is not None:
                            metrics_channels += channels.replace("...", "&").split("&")

        for group_name, group_config in self.metrics["validation_metrics_config"].items():
            for field, value in group_config.items():
                if field == "metrics":
                    for metric_definition in value:
                        split_metric_definition = metric_definition.split(":")
                        channels = (
                            split_metric_definition[2]
                            if len(split_metric_definition) > 2
                            else None
                        )
                        if channels is not None:
                            metrics_channels += channels.replace("...", "&").split("&")

        assert set(metrics_channels).issubset(self.data.channels), (
            f"{set(metrics_channels).difference(self.data.channels)} "
            f"not part of data input channels."
        )

        assert self.parallelism in [
            "ddp",
            "fsdp",
        ], 'Valid choices for `parallelism` are "ddp" and "fsdp".'
|
    @property
    def path_checkpoint(self) -> str:
        if self.path_experiment == "":
            return os.path.join(self.path_weights, "train", "checkpoint.pt")
        else:
            return os.path.join(
                os.path.dirname(self.path_experiment),
                "weights",
                "train",
                "checkpoint.pt",
            )

    @property
    def path_weights(self) -> str:
        return os.path.join(self.path_experiment, self.make_suffix_path(), "weights")

    @property
    def path_states(self) -> str:
        return os.path.join(self.path_experiment, self.make_suffix_path(), "states")
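
    # Resulting directory layout, with hypothetical path_experiment="experiments"
    # and job_id="exp-001" (make_suffix_path returns the job id):
    #
    #   experiments/exp-001/weights   <- path_weights
    #   experiments/exp-001/states    <- path_states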
|
    def to_dict(self):
        d = self.__dict__.copy()
        d["model"] = self.model.to_dict()
        d["data"] = self.data.to_dict()
        d["optimizer"] = self.optimizer.to_dict()

        return d

    @staticmethod
    def from_argparse(args: Namespace):
        return ExperimentConfig(
            data_config=DataConfig.from_argparse(args),
            model_config=ModelConfig.from_argparse(args),
            optimizer_config=OptimizerConfig.from_argparse(args),
            **args.__dict__,
        )

    @staticmethod
    def from_dict(params: dict):
        return ExperimentConfig(
            data_config=DataConfig(**params["data"]),
            model_config=ModelConfig(**params["model"]),
            optimizer_config=OptimizerConfig(**params["optimizer"]),
            **params,
        )

    def make_folder_name(self) -> str:
        param_folder = "wpt-c1-s1"
        return param_folder

    def make_suffix_path(self) -> str:
        return os.path.join(self.job_id)
|
    def __str__(self):
        return (
            f"ID: {self.job_id}, "
            f"Epochs: {self.optimizer.max_epochs}, "
            f"Batch size: {self.data.batch_size}, "
            f"LR: {[self.optimizer.learning_rate, self.optimizer.min_lr]}, "
            f"Warm up: {self.optimizer.warm_up_steps}, "
            f"DL workers: {self.data.num_data_workers}, "
            f"Parallelism: {self.parallelism}"
        )

    def __repr__(self):
        return self.__str__()
|

def get_config(
    config_path: str,
) -> ExperimentConfig:
    with open(config_path, "r") as f:
        cfg = yaml.safe_load(f)
    with open(cfg["data"]["scalers_path"], "r") as f:
        cfg["data"]["scalers"] = yaml.safe_load(f)
    return ExperimentConfig.from_dict(params=cfg)
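

# A minimal, hypothetical usage sketch. The YAML layout is inferred from
# `ExperimentConfig.from_dict` and `get_config`: top-level keys `data`,
# `model`, and `optimizer` (plus fields such as `job_id`, `parallelism`,
# `path_experiment`, `rollout_steps`, and `metrics`), with
# `data.scalers_path` pointing at a second YAML file of scalers. The file
# name below is a placeholder.
#
#   cfg = get_config("configs/experiment.yaml")
#   print(cfg)                 # one-line experiment summary
#   print(cfg.path_weights)    # derived weights directory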
|
|