import os
from argparse import Namespace

import yaml


class DataConfig:
    def __init__(
        self,
        train_data_path: str,
        valid_data_path: str,
        batch_size: int,
        num_data_workers: int,
        prefetch_factor: int,
        time_delta_input_minutes: list[int],
        n_input_timestamps: int | None = None,
        pooling: int | None = None,
        random_vert_flip: bool = False,
        **kwargs,
    ):
        # Additional data parameters (e.g. `channels`, `scalers`) are passed
        # through as keyword arguments.
        self.__dict__.update(kwargs)

        self.train_data_path = train_data_path
        self.valid_data_path = valid_data_path
        self.batch_size = batch_size
        self.num_data_workers = num_data_workers
        self.prefetch_factor = prefetch_factor
        self.time_delta_input_minutes = sorted(time_delta_input_minutes)
        self.n_input_timestamps = n_input_timestamps
        self.pooling = pooling
        self.random_vert_flip = random_vert_flip

        # By default, use all configured input timestamps.
        if self.n_input_timestamps is None:
            self.n_input_timestamps = len(self.time_delta_input_minutes)

        assert (
            self.n_input_timestamps > 0
        ), "Number of input timestamps must be greater than 0."
        assert self.n_input_timestamps <= len(self.time_delta_input_minutes), (
            f"Cannot sample {self.n_input_timestamps} from list of "
            f"{self.time_delta_input_minutes} input timestamps."
        )

    def to_dict(self):
        return self.__dict__

    @staticmethod
    def from_argparse(args: Namespace):
        return DataConfig(**args.__dict__)

    def __str__(self):
        return (
            f"Training index: {self.train_data_path}, "
            f"Validation index: {self.valid_data_path}"
        )

    __repr__ = __str__


class ModelConfig:
    def __init__(self, **kwargs):
        # All model hyperparameters (enc_num_layers, enc_num_heads,
        # enc_embed_size, dec_num_layers, dec_num_heads, dec_embed_size,
        # mlp_ratio, mask_ratio, in_channels, time_embedding, ...) are
        # passed through as keyword arguments.
        self.__dict__.update(kwargs)

    def to_dict(self):
        return self.__dict__

    @staticmethod
    def from_argparse(args: Namespace):
        return ModelConfig(**args.__dict__)

    @property
    def encoder_d_ff(self):
        return int(self.enc_embed_size * self.mlp_ratio)

    @property
    def decoder_d_ff(self):
        return int(self.dec_embed_size * self.mlp_ratio)

    def __str__(self):
        return (
            f"Input channels: {self.in_channels}, "
            f"Encoder (L, H, E): {[self.enc_num_layers, self.enc_num_heads, self.enc_embed_size]}, "
            f"Decoder (L, H, E): {[self.dec_num_layers, self.dec_num_heads, self.dec_embed_size]}"
        )

    __repr__ = __str__


class OptimizerConfig:
    def __init__(
        self,
        warm_up_steps: int,
        max_epochs: int,
        learning_rate: float,
        min_lr: float,
        **kwargs,
    ):
        self.__dict__.update(kwargs)
        self.warm_up_steps = warm_up_steps
        self.max_epochs = max_epochs
        self.learning_rate = learning_rate
        self.min_lr = min_lr

    def to_dict(self):
        return self.__dict__

    @staticmethod
    def from_argparse(args: Namespace):
        return OptimizerConfig(**args.__dict__)

    def __str__(self):
        return (
            f"Epochs: {self.max_epochs}, "
            f"LR: {[self.learning_rate, self.min_lr]}, "
            f"Warm up: {self.warm_up_steps}"
        )

    __repr__ = __str__
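# The channel check in `ExperimentConfig` below parses metric definitions of
# the form "<metric>:<...>:<channels>", where the optional third
# colon-separated field lists channel names joined by "..." or "&".
# Illustrative examples (metric and channel names are placeholders, not
# values shipped with this module):
#
#   "rmse:pool:bx...by"  -> restricts the metric to channels ["bx", "by"]
#   "mae:raw"            -> no channel restriction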
class ExperimentConfig:
    def __init__(
        self,
        job_id: str,
        data_config: DataConfig,
        model_config: ModelConfig,
        optimizer_config: OptimizerConfig,
        path_experiment: str,
        parallelism: str,
        from_checkpoint: str | None = None,
        **kwargs,
    ):
        # Additional experiment parameters used in downstream tasks.
        self.__dict__.update(kwargs)

        self.job_id = job_id
        self.data = data_config
        self.model = model_config
        self.optimizer = optimizer_config
        self.path_experiment = path_experiment
        self.from_checkpoint = from_checkpoint
        self.parallelism = parallelism

        assert self.model.in_channels == len(self.data.channels), (
            f"Number of model input channels ({self.model.in_channels}) must be "
            f"equal to number of input variables ({len(self.data.channels)})."
        )

        if self.model.time_embedding["type"] == "linear":
            assert (
                self.model.time_embedding["time_dim"] == self.data.n_input_timestamps
            ), "Time dimension of linear embedding must be equal to number of input timestamps."

        if self.rollout_steps > 0:
            assert self.data.n_input_timestamps == len(
                self.data.time_delta_input_minutes
            ), "Rollout does not support randomly sampled input timestamps."

        # Collect every channel referenced by a metric definition and check
        # that it is a known data input channel. Channel lists may be joined
        # by "..." or "&"; both separators are handled identically for the
        # train and validation configs.
        metrics_channels = []
        for metrics_config in (
            self.metrics["train_metrics_config"],
            self.metrics["validation_metrics_config"],
        ):
            for group_config in metrics_config.values():
                for metric_definition in group_config.get("metrics", []):
                    split_metric_definition = metric_definition.split(":")
                    if len(split_metric_definition) > 2:
                        channels = split_metric_definition[2]
                        metrics_channels += channels.replace("...", "&").split("&")

        assert set(metrics_channels).issubset(self.data.channels), (
            f"{set(metrics_channels).difference(self.data.channels)} "
            f"not part of data input channels."
        )

        assert self.parallelism in [
            "ddp",
            "fsdp",
        ], 'Valid choices for `parallelism` are "ddp" and "fsdp".'
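    # On-disk layout implied by the path properties below (assuming the
    # default `make_suffix_path`, which is just the job id):
    #   <path_experiment>/<job_id>/weights  - model weights
    #   <path_experiment>/<job_id>/states   - training states
    # `path_checkpoint` resolves under `path_weights` when `path_experiment`
    # is empty, and under the parent directory of `path_experiment` otherwise.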
    @property
    def path_checkpoint(self) -> str:
        if self.path_experiment == "":
            return os.path.join(self.path_weights, "train", "checkpoint.pt")
        else:
            return os.path.join(
                os.path.dirname(self.path_experiment),
                "weights",
                "train",
                "checkpoint.pt",
            )

    @property
    def path_weights(self) -> str:
        return os.path.join(self.path_experiment, self.make_suffix_path(), "weights")

    @property
    def path_states(self) -> str:
        return os.path.join(self.path_experiment, self.make_suffix_path(), "states")

    def to_dict(self):
        d = self.__dict__.copy()
        d["model"] = self.model.to_dict()
        d["data"] = self.data.to_dict()
        return d

    @staticmethod
    def from_argparse(args: Namespace):
        return ExperimentConfig(
            data_config=DataConfig.from_argparse(args),
            model_config=ModelConfig.from_argparse(args),
            optimizer_config=OptimizerConfig.from_argparse(args),
            **args.__dict__,
        )

    @staticmethod
    def from_dict(params: dict):
        return ExperimentConfig(
            data_config=DataConfig(**params["data"]),
            model_config=ModelConfig(**params["model"]),
            optimizer_config=OptimizerConfig(**params["optimizer"]),
            **params,
        )

    def make_folder_name(self) -> str:
        return "wpt-c1-s1"

    def make_suffix_path(self) -> str:
        return os.path.join(self.job_id)

    def __str__(self):
        return (
            f"ID: {self.job_id}, "
            f"Epochs: {self.optimizer.max_epochs}, "
            f"Batch size: {self.data.batch_size}, "
            f"LR: {[self.optimizer.learning_rate, self.optimizer.min_lr]}, "
            f"Warm up: {self.optimizer.warm_up_steps}, "
            f"DL workers: {self.data.num_data_workers}, "
            f"Parallelism: {self.parallelism}"
        )

    __repr__ = __str__


def get_config(config_path: str) -> ExperimentConfig:
    with open(config_path, "r") as f:
        cfg = yaml.safe_load(f)
    # Resolve the scaler definitions referenced by the main config file.
    with open(cfg["data"]["scalers_path"], "r") as f:
        cfg["data"]["scalers"] = yaml.safe_load(f)
    return ExperimentConfig.from_dict(params=cfg)
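
if __name__ == "__main__":
    # Minimal end-to-end sketch of building an ExperimentConfig from a dict,
    # mirroring what `get_config` does after loading the YAML files. All
    # values below are illustrative placeholders, not defaults shipped with
    # this module.
    params = {
        "job_id": "demo",
        "path_experiment": "experiments",
        "parallelism": "ddp",
        "rollout_steps": 0,
        "metrics": {
            "train_metrics_config": {"all": {"metrics": ["rmse:raw:bx...by"]}},
            "validation_metrics_config": {"all": {"metrics": ["rmse:raw:bx&by"]}},
        },
        "data": {
            "train_data_path": "data/train",
            "valid_data_path": "data/valid",
            "batch_size": 2,
            "num_data_workers": 1,
            "prefetch_factor": 2,
            "time_delta_input_minutes": [-60, 0],
            "channels": ["bx", "by"],
        },
        "model": {
            "in_channels": 2,
            "time_embedding": {"type": "linear", "time_dim": 2},
        },
        "optimizer": {
            "warm_up_steps": 100,
            "max_epochs": 10,
            "learning_rate": 1e-4,
            "min_lr": 1e-6,
        },
    }
    cfg = ExperimentConfig.from_dict(params)
    print(cfg)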