Spaces:
Running
Running
| import os | |
| import sys | |
| sys.path.append("..") | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| import hydra | |
| import omegaconf | |
| import jax | |
| import jax.numpy as jnp | |
| import optax | |
| from flax.training.train_state import TrainState | |
| from flax.serialization import from_bytes | |
| from huggingface_hub import snapshot_download | |
| # lpn imports | |
| from src.models.lpn import LPN | |
| from src.models.transformer import EncoderTransformer, DecoderTransformer | |
| from src.visualization import display_grid | |
| from utils import patch_target, ax_to_pil | |
# Flat index of the single cell set to 1 in the 16-cell (4x4) input grid,
# keyed by the image index passed to generate_image (13 -> row 3, col 1;
# 9 -> row 2, col 1 once reshaped row-major).
BLUE_LOCATION_INPUTS = {0: 13, 1: 9}

# Sub-folder of the Hub repository that holds the config and the weights.
checkpoint_name = "quiet-thunder-789--checkpoint:v0"

# Fetch only the files belonging to this checkpoint.
local_dir = snapshot_download(
    repo_id="clement-bonnet/lpn-2d",
    allow_patterns=f"{checkpoint_name}/*",
)

# Load the configuration stored next to the checkpoint.
with open(f"{local_dir}/{checkpoint_name}/config.yaml", "r") as f:
    cfg = omegaconf.OmegaConf.load(f)

# NOTE(review): patch_target is a project helper from utils; presumably it
# rewrites the config in place so hydra can instantiate it -- confirm.
patch_target(cfg)
# Build the encoder/decoder from their hydra-instantiated configs and wrap
# them in the LPN model. `encoder`, `decoder` and `lpn` stay module-level
# because the code below reads `decoder.config` and calls `lpn.init/apply`.
encoder = EncoderTransformer(
    hydra.utils.instantiate(cfg.encoder_transformer),
)
decoder = DecoderTransformer(
    hydra.utils.instantiate(cfg.decoder_transformer),
)
lpn = LPN(
    encoder=encoder,
    decoder=decoder,
)
# Placeholder inputs used only to initialise the parameter tree. The same
# key is used for every draw and for `init`; the resulting values only act
# as a template, since the parameters actually used for inference come from
# the checkpoint deserialised further down.
key = jax.random.PRNGKey(0)
_decoder_cfg = decoder.config

# Random token grids in [0, vocab_size).
# NOTE(review): axes look like (batch, pairs, rows, cols, input/output) --
# confirm against LPN.__call__.
grids = jax.random.randint(
    key,
    shape=(1, 3, _decoder_cfg.max_rows, _decoder_cfg.max_cols, 2),
    minval=0,
    maxval=_decoder_cfg.vocab_size,
)

# Random grid sizes in [1, min(max_rows, max_cols)] (maxval is exclusive).
shapes = jax.random.randint(
    key,
    shape=(1, 3, 2, 2),
    minval=1,
    maxval=min(_decoder_cfg.max_rows, _decoder_cfg.max_cols) + 1,
)

variables = lpn.init(
    key,
    grids,
    shapes,
    dropout_eval=False,
    prior_kl_coeff=0.0,
    pairwise_kl_coeff=0.0,
    mode="mean",
)
# Build a TrainState that serves as the target structure for deserialising
# the checkpoint. The learning rate is 0 and the optimizer is never stepped
# in this file.
# NOTE(review): the optimizer chain presumably mirrors the one used at
# training time so the optimizer state matches the checkpoint -- confirm.
learning_rate = 0
linear_warmup_steps = 0

linear_warmup_scheduler = optax.warmup_exponential_decay_schedule(
    init_value=learning_rate / (linear_warmup_steps + 1),
    peak_value=learning_rate,
    warmup_steps=linear_warmup_steps,
    transition_steps=1,
    decay_rate=1.0,
    end_value=learning_rate,
)

# Gradient clipping followed by AdamW, wrapped in MultiSteps with k=1.
optimizer = optax.MultiSteps(
    optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adamw(linear_warmup_scheduler),
    ),
    every_k_schedule=1,
)

train_state = TrainState.create(
    apply_fn=lpn.apply,
    tx=optimizer,
    params=variables["params"],
)
# Read the serialised training state and restore it, using `train_state`
# as the target structure for deserialisation.
_state_path = os.path.join(local_dir, checkpoint_name, "state.msgpack")
with open(_state_path, "rb") as data_file:
    byte_data = data_file.read()

loaded_state = from_bytes(train_state, byte_data)
@jax.jit
def generate_output_from_context(context, input, input_grid_shape):
    """Run the LPN's `_generate_output_from_context` method with the loaded weights.

    The parameter names are part of the interface: the caller passes all three
    by keyword, which is why `input` keeps its builtin-shadowing name.

    Args:
        context: forwarded unchanged as `context`.
        input: forwarded unchanged as `input`.
        input_grid_shape: forwarded unchanged as `input_grid_shape`.

    Returns:
        Whatever `lpn._generate_output_from_context` returns; the caller in this
        file unpacks it as a pair whose first element is indexed by batch.
    """
    return lpn.apply(
        {"params": loaded_state.params},
        context=context,
        input=input,
        input_grid_shape=input_grid_shape,
        dropout_eval=True,
        method=lpn._generate_output_from_context,
    )
def generate_image(image_idx: int, x: float, y: float, eps: float = 1e-4) -> Image.Image:
    """Render the model output for one fixed input grid and a 2D context point.

    Args:
        image_idx: key into BLUE_LOCATION_INPUTS selecting which cell of the
            4x4 input grid is set to 1 (all other cells are 0).
        x: first coordinate, nominally in [0, 1]; clamped to [eps, 1 - eps].
        y: second coordinate, nominally in [0, 1]; clamped to [eps, 1 - eps].
        eps: clamping margin that keeps the inverse normal CDF finite.

    Returns:
        The image produced by `ax_to_pil` for the rendered output grid.

    Raises:
        KeyError: if `image_idx` is not a key of BLUE_LOCATION_INPUTS.
    """
    # Create the input grid: a single 1 at the configured flat index.
    input_grid = jnp.zeros(16, int).at[BLUE_LOCATION_INPUTS[image_idx]].set(1).reshape(4, 4)
    grid_shape = jnp.array([4, 4])

    # Ensure x and y are in [eps, 1 - eps]; ppf(0) and ppf(1) are infinite.
    x = min(1 - eps, max(eps, x))
    y = min(1 - eps, max(eps, y))

    # Map the unit-square point to a context in R^2 via the standard normal
    # inverse CDF.
    context = jax.scipy.stats.norm.ppf(jnp.array([x, y]))

    # Add a leading batch axis of size 1 to every argument.
    output_grids, _ = generate_output_from_context(
        context=context[None],
        input=input_grid[None],
        input_grid_shape=grid_shape[None],
    )

    fig, ax = plt.subplots(1, 1, figsize=(4, 4))
    try:
        display_grid(ax=ax, grid=output_grids[0], grid_shape=grid_shape)
        return ax_to_pil(ax)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly, so
        # without this each call would leak one figure in a long-running app.
        plt.close(fig)