Spaces:
Runtime error
Runtime error
| from typing import Any, Optional, Union, Tuple, Dict, List | |
| import os | |
| import random | |
| import math | |
| import time | |
| import numpy as np | |
| from tqdm.auto import tqdm, trange | |
| import torch | |
| from torch.utils.data import DataLoader | |
| import jax | |
| import jax.numpy as jnp | |
| import optax | |
| from flax import jax_utils, traverse_util | |
| from flax.core.frozen_dict import FrozenDict | |
| from flax.training.train_state import TrainState | |
| from flax.training.common_utils import shard | |
| # convert 2D -> 3D | |
| from diffusers import FlaxUNet2DConditionModel | |
| # inference test, run on these on cpu | |
| from diffusers import AutoencoderKL | |
| from diffusers.schedulers.scheduling_ddim_flax import FlaxDDIMScheduler, DDIMSchedulerState | |
| from transformers import CLIPTextModel, CLIPTokenizer | |
| from PIL import Image | |
| from .flax_unet_pseudo3d_condition import UNetPseudo3DConditionModel | |
def seed_all(seed: int) -> jax.Array:
    """Seed Python, NumPy, torch and JAX RNGs for reproducibility.

    Args:
        seed: integer seed shared by all four RNG sources.

    Returns:
        The JAX PRNG key derived from ``seed``.

    Note: the return annotation uses ``jax.Array`` because
    ``jax.random.PRNGKeyArray`` was deprecated and removed in recent
    JAX releases; the returned value is unchanged.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    rng = jax.random.PRNGKey(seed)
    return rng
def count_params(
    params: Union[Dict[str, Any], FrozenDict[str, Any]],
    filter_name: Optional[str] = None
) -> int:
    """Count the number of scalar parameters in a nested parameter tree.

    Args:
        params: nested mapping of arrays (plain dict or flax ``FrozenDict``).
        filter_name: if given, only leaves whose space-joined key path
            contains this substring are counted (e.g. ``"temporal"``).

    Returns:
        Total number of scalar elements over the selected leaves.
    """
    def _walk(node: Any, path: Tuple[str, ...]) -> int:
        # Recurse into sub-mappings; anything without .items() is a leaf array.
        if hasattr(node, 'items'):
            return sum(_walk(child, path + (str(key),)) for key, child in node.items())
        if filter_name is not None and filter_name not in ' '.join(path):
            return 0
        # `.size` counts elements without materializing a flattened copy
        # (the original used len(x.flatten()), which allocates).
        return int(node.size)
    return _walk(params, ())
def map_2d_to_pseudo3d(
    params2d: Dict[str, Any],
    params3d: Dict[str, Any],
    verbose: bool = True
) -> Dict[str, Any]:
    """Build a Pseudo3D parameter tree from pretrained 2D UNet weights.

    Spatial-conv leaves are pulled from the 2D tree (same key path minus
    the 'spatial_conv' component); leaves present in both trees are copied
    from the 2D tree; leaves only present in the 3D tree (e.g. temporal
    layers) keep their fresh initialization. Shapes are asserted to match.
    """
    flat2d = traverse_util.flatten_dict(params2d)
    flat3d = traverse_util.flatten_dict(params3d)
    merged = dict()
    for key, init_value in flat3d.items():
        if 'spatial_conv' in key:
            source_key = list(key)
            source_key.remove('spatial_conv')
            source_key = tuple(source_key)
            if verbose:
                tqdm.write(f'Spatial: {key} <- {source_key}')
            value = flat2d[source_key]
        elif key in flat2d:
            value = flat2d[key]
        else:
            if verbose:
                tqdm.write(f'Missing: {key}')
            value = init_value
        assert value.shape == init_value.shape, f'shape mismatch: {key}: {value.shape} != {init_value.shape}'
        merged[key] = value
    return traverse_util.unflatten_dict(merged)
class FlaxTrainerUNetPseudo3D:
    """Trainer for a Pseudo3D (video) UNet in JAX/Flax.

    Handles model loading (optionally converting 2D diffusers weights to
    Pseudo3D), marking temporal-only trainable parameters, pmap-based
    training with optax, periodic checkpointing, and DDIM sampling for
    visual checks. The auxiliary inference models (tokenizer, text encoder,
    VAE) are torch modules run on CPU.
    """

    def __init__(self,
        model_path: str,
        from_pt: bool = True,
        convert2d: bool = False,
        sample_size: Tuple[int, int] = (64, 64),
        seed: int = 0,
        dtype: str = 'float32',
        param_dtype: str = 'float32',
        only_temporal: bool = True,
        use_memory_efficient_attention = False,
        verbose: bool = True
    ) -> None:
        """Load the UNet, seed RNGs and mark trainable parameters.

        Args:
            model_path: diffusers-style pretrained model directory/repo.
            from_pt: load weights from a PyTorch checkpoint.
            convert2d: initialize a Pseudo3D model and map 2D weights into it.
            sample_size: latent spatial size (h, w) for the UNet.
            seed: seed for all RNG sources (python/numpy/torch/jax).
            dtype: computation dtype name ('float32'|'bfloat16'|'float16').
            param_dtype: parameter storage dtype name (same choices).
            only_temporal: train only layers whose path contains 'temporal'.
            use_memory_efficient_attention: enable memory-efficient attention.
            verbose: enable tqdm-based logging on the main process.
        """
        self.verbose = verbose
        # wandb run handle; stays None unless enable_wandb() is called and
        # train() then sets up the tracker.
        self.tracker: Optional['wandb.sdk.wandb_run.Run'] = None
        self._use_wandb: bool = False
        # Timing metadata for steps/sec reporting (see log_metrics):
        # t00 = training start, t0 = last log time, step0 = last logged step.
        self._tracker_meta: Dict[str, Union[float, int]] = {
            't00': 0.0,
            't0': 0.0,
            'step0': 0
        }
        self.log('Init JAX')
        self.num_devices = jax.device_count()
        self.log(f'Device count: {self.num_devices}')
        self.seed = seed
        self.rng: jax.random.PRNGKeyArray = seed_all(self.seed)
        self.sample_size = sample_size
        if dtype == 'float32':
            self.dtype = jnp.float32
        elif dtype == 'bfloat16':
            self.dtype = jnp.bfloat16
        elif dtype == 'float16':
            self.dtype = jnp.float16
        else:
            raise ValueError(f'unknown type: {dtype}')
        self.dtype_str: str = dtype
        if param_dtype not in ['float32', 'bfloat16', 'float16']:
            raise ValueError(f'unknown parameter type: {param_dtype}')
        self.param_dtype = param_dtype
        self._load_models(
            model_path = model_path,
            convert2d = convert2d,
            from_pt = from_pt,
            use_memory_efficient_attention = use_memory_efficient_attention
        )
        self._mark_parameters(only_temporal = only_temporal)
        # optionally for validation + sampling
        self.tokenizer: Optional[CLIPTokenizer] = None
        self.text_encoder: Optional[CLIPTextModel] = None
        self.vae: Optional[AutoencoderKL] = None
        self.ddim: Optional[Tuple[FlaxDDIMScheduler, DDIMSchedulerState]] = None

    def log(self, message: Any) -> None:
        """Print a message via tqdm on the main process when verbose."""
        if self.verbose and jax.process_index() == 0:
            tqdm.write(str(message))

    def log_metrics(self, metrics: dict, step: int, epoch: int) -> None:
        """Log training metrics (stdout and, if set up, wandb).

        Computes steps/sec from the delta since the previous call.
        NOTE(review): if called twice for the same step, the time delta can
        be ~0 and steps_per_sec divides by it — confirm callers avoid this.
        """
        if jax.process_index() > 0 or (not self.verbose and self.tracker is None):
            return
        now = time.monotonic()
        log_data = {
            'train/step': step,
            'train/epoch': epoch,
            'train/steps_per_sec': (step - self._tracker_meta['step0']) / (now - self._tracker_meta['t0']),
            **{ f'train/{k}': v for k, v in metrics.items() }
        }
        self._tracker_meta['t0'] = now
        self._tracker_meta['step0'] = step
        self.log(log_data)
        if self.tracker is not None:
            self.tracker.log(log_data, step = step)

    def enable_wandb(self, enable: bool = True) -> None:
        """Toggle wandb tracking; actual setup happens inside train()."""
        self._use_wandb = enable

    def _setup_wandb(self, config: Dict[str, Any] = dict()) -> None:
        """Initialize a wandb run with identifying metadata anonymized."""
        import wandb
        import wandb.sdk
        self.tracker: wandb.sdk.wandb_run.Run = wandb.init(
            config = config,
            settings = wandb.sdk.Settings(
                username = 'anon',
                host = 'anon',
                email = 'anon',
                root_dir = 'anon',
                _executable = 'anon',
                _disable_stats = True,
                _disable_meta = True,
                disable_code = True,
                disable_git = True
            ) # deliberately anonymized: do not log sensitive data such as system user names
        )

    def _init_tracker_meta(self) -> None:
        """Reset the steps/sec timing baseline to 'now' at step 0."""
        now = time.monotonic()
        self._tracker_meta = {
            't00': now,
            't0': now,
            'step0': 0
        }

    def _load_models(self,
        model_path: str,
        convert2d: bool,
        from_pt: bool,
        use_memory_efficient_attention: bool
    ) -> None:
        """Load the Pseudo3D UNet, optionally converting from a 2D UNet.

        When convert2d is set, a fresh Pseudo3D model is initialized and the
        pretrained 2D weights are mapped into it (temporal layers keep their
        fresh init). Parameters are then cast to the requested param dtype.
        Sets self.pretrained_model, self.model and self.params.
        """
        self.log(f'Load pretrained from {model_path}')
        if convert2d:
            self.log('  Convert 2D model to Pseudo3D')
            self.log('  Initiate Pseudo3D model')
            config = UNetPseudo3DConditionModel.load_config(model_path, subfolder = 'unet')
            model = UNetPseudo3DConditionModel.from_config(
                config,
                sample_size = self.sample_size,
                dtype = self.dtype,
                param_dtype = self.param_dtype,
                use_memory_efficient_attention = use_memory_efficient_attention
            )
            params: Dict[str, Any] = model.init_weights(self.rng).unfreeze()
            self.log('  Load 2D model')
            model2d, params2d = FlaxUNet2DConditionModel.from_pretrained(
                model_path,
                subfolder = 'unet',
                dtype = self.dtype,
                from_pt = from_pt
            )
            self.log('  Map 2D -> 3D')
            params = map_2d_to_pseudo3d(params2d, params, verbose = self.verbose)
            del params2d
            del model2d
            del config
        else:
            model, params = UNetPseudo3DConditionModel.from_pretrained(
                model_path,
                subfolder = 'unet',
                from_pt = from_pt,
                sample_size = self.sample_size,
                dtype = self.dtype,
                param_dtype = self.param_dtype,
                use_memory_efficient_attention = use_memory_efficient_attention
            )
        self.log(f'Cast parameters to {model.param_dtype}')
        if model.param_dtype == 'float32':
            params = model.to_fp32(params)
        elif model.param_dtype == 'float16':
            params = model.to_fp16(params)
        elif model.param_dtype == 'bfloat16':
            params = model.to_bf16(params)
        self.pretrained_model = model_path
        self.model: UNetPseudo3DConditionModel = model
        self.params: FrozenDict[str, Any] = FrozenDict(params)

    def _mark_parameters(self, only_temporal: bool) -> None:
        """Partition parameters into 'trainable'/'frozen' for multi_transform.

        With only_temporal, any leaf whose path contains 'temporal' is
        trainable and everything else is frozen; otherwise all leaves train.
        """
        self.log('Mark training parameters')
        if only_temporal:
            self.log('Only training temporal layers')
        if only_temporal:
            param_partitions = traverse_util.path_aware_map(
                lambda path, _: 'trainable' if 'temporal' in ' '.join(path) else 'frozen', self.params
            )
        else:
            param_partitions = traverse_util.path_aware_map(
                lambda *_: 'trainable', self.params
            )
        self.only_temporal = only_temporal
        self.param_partitions: FrozenDict[str, Any] = FrozenDict(param_partitions)
        self.log(f'Total parameters: {count_params(self.params)}')
        self.log(f'Temporal parameters: {count_params(self.params, "temporal")}')

    def _load_inference_models(self) -> None:
        """Lazily load tokenizer, text encoder, VAE and DDIM scheduler.

        Torch models (text encoder, VAE) are used on CPU for sampling only.
        Main process only.
        """
        assert jax.process_index() == 0, 'not main process'
        if self.text_encoder is None:
            self.log('Load text encoder')
            self.text_encoder = CLIPTextModel.from_pretrained(
                self.pretrained_model,
                subfolder = 'text_encoder'
            )
        if self.tokenizer is None:
            self.log('Load tokenizer')
            self.tokenizer = CLIPTokenizer.from_pretrained(
                self.pretrained_model,
                subfolder = 'tokenizer'
            )
        if self.vae is None:
            self.log('Load vae')
            self.vae = AutoencoderKL.from_pretrained(
                self.pretrained_model,
                subfolder = 'vae'
            )
        if self.ddim is None:
            self.log('Load ddim scheduler')
            # tuple(scheduler , scheduler state)
            self.ddim = FlaxDDIMScheduler.from_pretrained(
                self.pretrained_model,
                subfolder = 'scheduler',
                from_pt = True
            )

    def _unload_inference_models(self) -> None:
        """Drop references to the inference models so they can be GC'd."""
        self.text_encoder = None
        self.tokenizer = None
        self.vae = None
        self.ddim = None

    def sample(self,
        params: Union[Dict[str, Any], FrozenDict[str, Any]],
        prompt: str,
        image_path: str,
        num_frames: int,
        replicate_params: bool = True,
        neg_prompt: str = '',
        steps: int = 50,
        cfg: float = 9.0,
        unload_after_usage: bool = False
    ) -> List[Image.Image]:
        """Run DDIM sampling conditioned on a prompt and a hint image.

        Args:
            params: UNet parameters (unreplicated unless replicate_params
                is False, in which case they must already be replicated).
            prompt: text prompt; encoded with the CLIP text encoder.
            image_path: path to the hint image, encoded via the VAE and
                repeated over the frame axis.
            num_frames: number of video frames to generate.
            replicate_params: replicate params across devices before pmap.
            neg_prompt: negative prompt for classifier-free guidance.
            steps: number of DDIM steps.
            cfg: classifier-free guidance scale (1.0 disables the neg pass).
            unload_after_usage: free the inference models afterwards.

        Returns:
            List of decoded PIL images (frames, one list across devices).
        """
        assert jax.process_index() == 0, 'not main process'
        self.log('Sample')
        self._load_inference_models()
        with torch.no_grad():
            tokens = self.tokenizer(
                [ prompt ],
                truncation = True,
                return_overflowing_tokens = False,
                padding = 'max_length',
                return_tensors = 'pt'
            ).input_ids
            neg_tokens = self.tokenizer(
                [ neg_prompt ],
                truncation = True,
                return_overflowing_tokens = False,
                padding = 'max_length',
                return_tensors = 'pt'
            ).input_ids
            encoded_prompt = self.text_encoder(input_ids = tokens).last_hidden_state
            encoded_neg_prompt = self.text_encoder(input_ids = neg_tokens).last_hidden_state
            # Image -> [-1, 1] float tensor, NCHW with batch dim.
            hint_latent = torch.tensor(np.asarray(Image.open(image_path))).permute(2,0,1).to(torch.float32).div(255).mul(2).sub(1).unsqueeze(0)
            hint_latent = self.vae.encode(hint_latent).latent_dist.mean * self.vae.config.scaling_factor #0.18215 # deterministic
            # Repeat the single hint latent across the frame axis: b c f h w.
            hint_latent = hint_latent.unsqueeze(2).repeat_interleave(num_frames, 2)
            mask = torch.zeros_like(hint_latent[:,0:1,:,:,:]) # zero mask, e.g. skip masking for now
            init_latent = torch.randn_like(hint_latent)
        # move to devices
        encoded_prompt = jnp.array(encoded_prompt.numpy())
        encoded_neg_prompt = jnp.array(encoded_neg_prompt.numpy())
        hint_latent = jnp.array(hint_latent.numpy())
        mask = jnp.array(mask.numpy())
        # One independent noise latent per device (sharded below).
        init_latent = init_latent.repeat(jax.device_count(), 1, 1, 1, 1)
        init_latent = jnp.array(init_latent.numpy())
        # Rebuild the scheduler state with the requested number of steps.
        self.ddim = (self.ddim[0], self.ddim[0].set_timesteps(self.ddim[1], steps))
        timesteps = self.ddim[1].timesteps
        if replicate_params:
            params = jax_utils.replicate(params)
        ddim_state = jax_utils.replicate(self.ddim[1])
        encoded_prompt = jax_utils.replicate(encoded_prompt)
        encoded_neg_prompt = jax_utils.replicate(encoded_neg_prompt)
        hint_latent = jax_utils.replicate(hint_latent)
        mask = jax_utils.replicate(mask)
        # sampling fun
        def sample_loop(init_latent, ddim_state, t, params, encoded_prompt, encoded_neg_prompt, hint_latent, mask):
            # UNet input channels: noisy latent + mask + hint latent.
            latent_model_input = jnp.concatenate([init_latent, mask, hint_latent], axis = 1)
            pred = self.model.apply(
                { 'params': params },
                latent_model_input,
                t,
                encoded_prompt
            ).sample
            if cfg != 1.0:
                # Classifier-free guidance: second pass with the neg prompt.
                neg_pred = self.model.apply(
                    { 'params': params },
                    latent_model_input,
                    t,
                    encoded_neg_prompt
                ).sample
                pred = neg_pred + cfg * (pred - neg_pred)
            # TODO check if noise is added at the right dimension
            init_latent, ddim_state = self.ddim[0].step(ddim_state, pred, t, init_latent).to_tuple()
            return init_latent, ddim_state
        p_sample_loop = jax.pmap(sample_loop, 'sample', donate_argnums = ())
        pbar_sample = trange(len(timesteps), desc = 'Sample', dynamic_ncols = True, smoothing = 0.1, disable = not self.verbose)
        init_latent = shard(init_latent)
        for i in pbar_sample:
            # Same timestep broadcast to every device.
            t = timesteps[i].repeat(self.num_devices)
            t = shard(t)
            init_latent, ddim_state = p_sample_loop(init_latent, ddim_state, t, params, encoded_prompt, encoded_neg_prompt, hint_latent, mask)
        # decode
        self.log('Decode')
        init_latent = torch.tensor(np.array(init_latent))
        init_latent = init_latent / self.vae.config.scaling_factor
        # d:0 b:1 c:2 f:3 h:4 w:5 -> d b f c h w
        init_latent = init_latent.permute(0, 1, 3, 2, 4, 5)
        images = []
        pbar_decode = trange(len(init_latent), desc = 'Decode', dynamic_ncols = True)
        for sample in init_latent:
            # Decode each device's frames as a batch through the VAE.
            ims = self.vae.decode(sample.squeeze()).sample
            ims = ims.add(1).div(2).mul(255).round().clamp(0, 255).to(torch.uint8).permute(0,2,3,1).numpy()
            ims = [ Image.fromarray(x) for x in ims ]
            for im in ims:
                images.append(im)
            pbar_decode.update(1)
        if unload_after_usage:
            self._unload_inference_models()
        return images

    def get_params_from_state(self, state: TrainState) -> FrozenDict[Any, str]:
        """Fetch unreplicated parameters (device 0's copy) from a TrainState."""
        return FrozenDict(jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)))

    def train(self,
        dataloader: DataLoader,
        lr: float,
        num_frames: int,
        log_every_step: int = 10,
        save_every_epoch: int = 1,
        sample_every_epoch: int = 1,
        output_dir: str = 'output',
        warmup: float = 0,
        decay: float = 0,
        epochs: int = 10,
        weight_decay: float = 1e-2
    ) -> None:
        """Train the UNet with AdamW + optional warmup/cosine-decay schedule.

        Args:
            dataloader: yields dict batches with 'latent_model_input',
                'encoder_hidden_states', 'timesteps' and 'noise'.
            lr: peak learning rate.
            num_frames: frames per sample (used for logging and sampling).
            log_every_step: metric logging interval in steps.
            save_every_epoch: checkpoint interval in epochs.
            sample_every_epoch: sampling interval in epochs.
            output_dir: checkpoint root directory.
            warmup: warmup fraction of total steps (0 disables).
            decay: decay fraction of total steps (0 disables).
            epochs: number of epochs.
            weight_decay: AdamW weight decay.
        """
        eps = 1e-8
        total_steps = len(dataloader) * epochs
        warmup_steps = math.ceil(warmup * total_steps) if warmup > 0 else 0
        decay_steps = math.ceil(decay * total_steps) + warmup_steps if decay > 0 else warmup_steps + 1
        self.log(f'Total steps: {total_steps}')
        self.log(f'Warmup steps: {warmup_steps}')
        self.log(f'Decay steps: {decay_steps - warmup_steps}')
        if warmup > 0 or decay > 0:
            if not decay > 0:
                # only warmup, keep peak lr until end
                self.log('Warmup schedule')
                end_lr = lr
            else:
                # warmup + annealing to end lr
                self.log('Warmup + cosine annealing schedule')
                end_lr = eps
            lr_schedule = optax.warmup_cosine_decay_schedule(
                init_value = 0.0,
                peak_value = lr,
                warmup_steps = warmup_steps,
                decay_steps = decay_steps,
                end_value = end_lr
            )
        else:
            # no warmup or decay -> constant lr
            self.log('constant schedule')
            lr_schedule = optax.constant_schedule(value = lr)
        adamw = optax.adamw(
            learning_rate = lr_schedule,
            b1 = 0.9,
            b2 = 0.999,
            eps = eps,
            weight_decay = weight_decay #0.01 # 0.0001
        )
        optim = optax.chain(
            optax.clip_by_global_norm(max_norm = 1.0),
            adamw
        )
        # Frozen partition gets set_to_zero so only 'trainable' leaves update.
        partition_optimizers = {
            'trainable': optim,
            'frozen': optax.set_to_zero()
        }
        tx = optax.multi_transform(partition_optimizers, self.param_partitions)
        state = TrainState.create(
            apply_fn = self.model.__call__,
            params = self.params,
            tx = tx
        )
        validation_rng, train_rngs = jax.random.split(self.rng)
        train_rngs = jax.random.split(train_rngs, jax.local_device_count())
        def train_step(state: TrainState, batch: Dict[str, jax.Array], train_rng: jax.random.PRNGKeyArray):
            # One pmapped optimization step; returns new state, metrics, rng.
            def compute_loss(
                params: Dict[str, Any],
                batch: Dict[str, jax.Array],
                sample_rng: jax.random.PRNGKeyArray # unused, dataloader provides everything
            ) -> jax.Array:
                # 'latent_model_input': latent_model_input
                # 'encoder_hidden_states': encoder_hidden_states
                # 'timesteps': timesteps
                # 'noise': noise
                latent_model_input = batch['latent_model_input']
                encoder_hidden_states = batch['encoder_hidden_states']
                timesteps = batch['timesteps']
                noise = batch['noise']
                model_pred = self.model.apply(
                    { 'params': params },
                    latent_model_input,
                    timesteps,
                    encoder_hidden_states
                ).sample
                # MSE against the dataloader-provided noise target.
                loss = (noise - model_pred) ** 2
                loss = loss.mean()
                return loss
            grad_fn = jax.value_and_grad(compute_loss)
            def loss_and_grad(
                train_rng: jax.random.PRNGKeyArray
            ) -> Tuple[jax.Array, Any, jax.random.PRNGKeyArray]:
                sample_rng, train_rng = jax.random.split(train_rng, 2)
                loss, grad = grad_fn(state.params, batch, sample_rng)
                return loss, grad, train_rng
            loss, grad, new_train_rng = loss_and_grad(train_rng)
            # self.log(grad) # NOTE uncomment to visualize gradient
            grad = jax.lax.pmean(grad, axis_name = 'batch')
            new_state = state.apply_gradients(grads = grad)
            metrics: Dict[str, Any] = { 'loss': loss }
            metrics = jax.lax.pmean(metrics, axis_name = 'batch')
            def l2(xs) -> jax.Array:
                # Global L2 norm over all leaves of a pytree.
                return jnp.sqrt(sum([jnp.vdot(x, x) for x in jax.tree_util.tree_leaves(xs)]))
            metrics['l2_grads'] = l2(jax.tree_util.tree_leaves(grad))
            return new_state, metrics, new_train_rng
        p_train_step = jax.pmap(fun = train_step, axis_name = 'batch', donate_argnums = (0, ))
        state = jax_utils.replicate(state)
        train_metrics = []
        train_metric = None
        global_step: int = 0
        if jax.process_index() == 0:
            self._init_tracker_meta()
            hyper_params = {
                'lr': lr,
                'lr_warmup': warmup,
                'lr_decay': decay,
                'weight_decay': weight_decay,
                'total_steps': total_steps,
                # NOTE(review): presumably per-device batch size — confirm
                # dataloader.batch_size is the global batch size.
                'batch_size': dataloader.batch_size // self.num_devices,
                'num_frames': num_frames,
                'sample_size': self.sample_size,
                'num_devices': self.num_devices,
                'seed': self.seed,
                'use_memory_efficient_attention': self.model.use_memory_efficient_attention,
                'only_temporal': self.only_temporal,
                'dtype': self.dtype_str,
                'param_dtype': self.param_dtype,
                'pretrained_model': self.pretrained_model,
                'model_config': self.model.config
            }
            if self._use_wandb:
                self.log('Setting up wandb')
                self._setup_wandb(hyper_params)
            self.log(hyper_params)
            # Save the initial (step 0) checkpoint before training starts.
            output_path = os.path.join(output_dir, str(global_step), 'unet')
            self.log(f'saving checkpoint to {output_path}')
            self.model.save_pretrained(
                save_directory = output_path,
                params = self.get_params_from_state(state),#jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)),
                is_main_process = True
            )
        pbar_epoch = tqdm(
            total = epochs,
            desc = 'Epochs',
            smoothing = 1,
            position = 0,
            dynamic_ncols = True,
            leave = True,
            disable = jax.process_index() > 0
        )
        steps_per_epoch = len(dataloader) # TODO dataloader
        for epoch in range(epochs):
            pbar_steps = tqdm(
                total = steps_per_epoch,
                desc = 'Steps',
                position = 1,
                smoothing = 0.1,
                dynamic_ncols = True,
                leave = True,
                disable = jax.process_index() > 0
            )
            for batch in dataloader:
                # keep input + gt as float32, results in fp32 loss and grad
                # otherwise uncomment the following to cast to the model dtype
                # batch = { k: (v.astype(self.dtype) if v.dtype == np.float32 else v) for k,v in batch.items() }
                batch = shard(batch)
                state, train_metric, train_rngs = p_train_step(
                    state, batch, train_rngs
                )
                train_metrics.append(train_metric)
                if global_step % log_every_step == 0 and jax.process_index() == 0:
                    # Average the accumulated per-step metrics since last log.
                    train_metrics = jax_utils.unreplicate(train_metrics)
                    train_metrics = jax.tree_util.tree_map(lambda *m: jnp.array(m).mean(), *train_metrics)
                    if global_step == 0:
                        self.log(f'grad dtype: {train_metrics["l2_grads"].dtype}')
                        self.log(f'loss dtype: {train_metrics["loss"].dtype}')
                    train_metrics_dict = { k: v.item() for k, v in train_metrics.items() }
                    train_metrics_dict['lr'] = lr_schedule(global_step).item()
                    self.log_metrics(train_metrics_dict, step = global_step, epoch = epoch)
                    train_metrics = []
                pbar_steps.update(1)
                global_step += 1
            if epoch % save_every_epoch == 0 and jax.process_index() == 0:
                output_path = os.path.join(output_dir, str(global_step), 'unet')
                self.log(f'saving checkpoint to {output_path}')
                self.model.save_pretrained(
                    save_directory = output_path,
                    params = self.get_params_from_state(state),#jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)),
                    is_main_process = True
                )
                self.log(f'checkpoint saved ')
            if epoch % sample_every_epoch == 0 and jax.process_index() == 0:
                # Visual check with the still-replicated training params.
                images = self.sample(
                    params = state.params,
                    replicate_params = False,
                    prompt = 'dancing person',
                    image_path = 'testimage.png',
                    num_frames = num_frames,
                    steps = 50,
                    cfg = 9.0,
                    unload_after_usage = False
                )
                os.makedirs(os.path.join('image_output', str(epoch)), exist_ok = True)
                for i, im in enumerate(images):
                    im.save(os.path.join('image_output', str(epoch), str(i).zfill(5) + '.png'), optimize = True)
            pbar_epoch.update(1)