New logic
- app.py +72 -118
- model/__init__.py +2 -0
- model/convnext.py +55 -0
- model/edsr.py +122 -0
- model/hyper.py +41 -0
- model/init.py +24 -0
- model/rdn.py +72 -0
- model/swin_ir.py +532 -0
- model/tail.py +18 -0
- model/thera.py +175 -0
- requirements.txt +36 -5
- super_resolve.py +99 -0
- utils.py +36 -0
app.py
CHANGED
@@ -1,137 +1,91 @@
 import gradio as gr
 import torch
 import numpy as np
 from PIL import Image
-from
 from transformers import DPTFeatureExtractor, DPTForDepthEstimation
-from
-#
-#
-#
-)
-])
-    input_tensor = transform(image).unsqueeze(0).to(DEVICE)
-    with torch.no_grad():
-        output = model(input_tensor)
-    output_img = transforms.ToPILImage()(output.squeeze().cpu().clamp(-1, 1) * 0.5 + 0.5)
-    return output_img
-    inputs = feature_extractor(
     with torch.no_grad():
-        outputs =
-    prediction = torch.nn.functional.interpolate(
-        predicted_depth.unsqueeze(1),
-        size=image.size[::-1],
-        mode="bicubic",
-        align_corners=False,
-    )
-    return prediction.squeeze().cpu().numpy()
-
-
-def create_bas_relief(prompt, image, depth_map, pipe):
-    control_image = Image.fromarray((depth_map * 255).astype(np.uint8))
-
-    image = image.resize((1024, 1024))
-    control_image = control_image.resize((1024, 1024))
-
-    result = pipe(
-        prompt=prompt,
-        image=image,
-        control_image=control_image,
-        strength=0.8,
-        num_inference_steps=30
-    ).images[0]
-
-# --- Gradio interface ---
-
     with gr.Row():
         with gr.Column():
-            prompt = gr.Textbox("
-
         with gr.Column():
-
-    depth_model = load_depth_model()
-    feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
-    basrelief_pipe = load_controlnet()
-
-    # 1. Super-resolution
-    upscaled = run_thera(image, thera_model)
-
-    # 2. Depth map
-    depth = create_depth_map(upscaled, depth_model, feature_extractor)
-    depth_normalized = (depth - depth.min()) / (depth.max() - depth.min())
-
-    # 3. Bas-relief
-    basrelief = create_bas_relief(prompt, upscaled, depth_normalized, basrelief_pipe)
-
-    return upscaled, depth_normalized, basrelief
-
-
-    submit_btn.click(
-        process,
-        inputs=[input_image, prompt],
-        outputs=[upscaled_output, depth_output, basrelief_output]
     )
 
 if __name__ == "__main__":
 import gradio as gr
 import torch
+import jax
 import numpy as np
 from PIL import Image
+from diffusers import StableDiffusionXLImg2ImgPipeline
 from transformers import DPTFeatureExtractor, DPTForDepthEstimation
+from super_resolve import process as thera_process  # assumes the Thera imports
+
+# Configuration
+DEVICE = "cpu"  # or "cuda" if available
+JAX_DEVICE = jax.devices("cpu")[0]  # use CPU for JAX
+
+# 1. Load the Thera models (EDSR/RDN)
+# (implement as in the original Thera code)
+model_edsr, params_edsr = None, None  # load via pickle/HF Hub
+
+# 2. Load SDXL Img2Img + LoRA
+print("Loading SDXL Img2Img with LoRA...")
+pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    torch_dtype=torch.float32
+).to(DEVICE)
+pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")
+
+# 3. Load the depth model
+print("Loading DPT...")
+feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
+depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(DEVICE)
+
+
+def enhance_depth_map(depth_arr):
+    depth_normalized = (depth_arr - depth_arr.min()) / (depth_arr.max() - depth_arr.min() + 1e-8)
+    return Image.fromarray((depth_normalized * 255).astype(np.uint8))
+
+
+def full_pipeline(image, prompt, scale_factor=2.0):
+    # 1. Super-resolution with Thera
+    source = np.array(image) / 255.0
+    target_shape = (int(image.height * scale_factor), int(image.width * scale_factor))
+    upscaled = thera_process(source, model_edsr, params_edsr, target_shape, do_ensemble=True)
+    upscaled_pil = Image.fromarray((upscaled * 255).astype(np.uint8))
+
+    # 2. Generate the bas-relief with SDXL Img2Img
+    full_prompt = f"BAS-RELIEF {prompt}, intricate carving, marble relief"
+    bas_relief = pipe(
+        prompt=full_prompt,
+        image=upscaled_pil,
+        strength=0.7,
+        num_inference_steps=25,
+        guidance_scale=7.5
+    ).images[0]
 
+    # 3. Compute the depth map
+    inputs = feature_extractor(bas_relief, return_tensors="pt").to(DEVICE)
     with torch.no_grad():
+        outputs = depth_model(**inputs)
+        depth = outputs.predicted_depth
 
+    depth_map = torch.nn.functional.interpolate(
+        depth.unsqueeze(1),
+        size=bas_relief.size[::-1],
+        mode="bicubic"
+    ).squeeze().cpu().numpy()
 
+    return upscaled_pil, bas_relief, enhance_depth_map(depth_map)
 
 
+# Gradio interface
+with gr.Blocks(title="Super-Resolution + Bas-Relief") as app:
+    gr.Markdown("## 📈 Super-Resolution + 🗿 Bas-Relief + 🗺️ Depth Map")
 
     with gr.Row():
         with gr.Column():
+            img_input = gr.Image(type="pil", label="Input Image")
+            prompt = gr.Textbox("ancient sculpture, marble", label="Relief Description")
+            scale = gr.Slider(1.0, 4.0, value=2.0, label="Scale Factor")
+            btn = gr.Button("Process")
 
         with gr.Column():
+            img_upscaled = gr.Image(label="Super-Resolved Image")
+            img_basrelief = gr.Image(label="Sculptural Relief")
+            img_depth = gr.Image(label="Depth Map")
+
+    btn.click(
+        full_pipeline,
+        inputs=[img_input, prompt, scale],
+        outputs=[img_upscaled, img_basrelief, img_depth]
     )
 
 if __name__ == "__main__":
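Note: the commit leaves `model_edsr, params_edsr = None, None` as a placeholder. A minimal loader sketch in the spirit of super_resolve.main(), which reads a pickled checkpoint and rebuilds the model from its stored `backbone`/`size` strings; the repo id and filename below are assumptions for illustration, not part of this commit:

import pickle
from huggingface_hub import hf_hub_download
from model import build_thera

def load_thera_checkpoint(repo_id: str, filename: str):
    # download the pickled checkpoint (hypothetical repo/filename) and rebuild the model
    path = hf_hub_download(repo_id=repo_id, filename=filename)
    with open(path, 'rb') as fh:
        check = pickle.load(fh)
    # the checkpoint stores params plus the backbone/size strings needed by build_thera
    model = build_thera(3, check['backbone'], check['size'])
    return model, check['model']

# model_edsr, params_edsr = load_thera_checkpoint("<user>/<thera-repo>", "thera-edsr.pkl")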
model/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .hyper import Hypernetwork
from .thera import build_thera
model/convnext.py
ADDED
@@ -0,0 +1,55 @@
import flax.linen as nn
from jaxtyping import Array, ArrayLike


class ConvNeXtBlock(nn.Module):
    """ConvNext block. See Fig. 4 in "A ConvNet for the 2020s" by Liu et al.

    https://openaccess.thecvf.com/content/CVPR2022/papers/Liu_A_ConvNet_for_the_2020s_CVPR_2022_paper.pdf
    """
    n_dims: int = 64
    kernel_size: int = 3  # 7 in the paper's version
    group_features: bool = False

    def setup(self) -> None:
        self.residual = nn.Sequential([
            nn.Conv(self.n_dims, kernel_size=(self.kernel_size, self.kernel_size), use_bias=False,
                    feature_group_count=self.n_dims if self.group_features else 1),
            nn.LayerNorm(),
            nn.Conv(4 * self.n_dims, kernel_size=(1, 1)),
            nn.gelu,
            nn.Conv(self.n_dims, kernel_size=(1, 1)),
        ])

    def __call__(self, x: ArrayLike) -> Array:
        return x + self.residual(x)


class Projection(nn.Module):
    n_dims: int

    @nn.compact
    def __call__(self, x: ArrayLike) -> Array:
        x = nn.LayerNorm()(x)
        x = nn.Conv(self.n_dims, (1, 1))(x)
        return x


class ConvNeXt(nn.Module):
    block_defs: list[tuple]

    def setup(self) -> None:
        layers = []
        current_size = self.block_defs[0][0]
        for block_def in self.block_defs:
            if block_def[0] != current_size:
                layers.append(Projection(block_def[0]))
            layers.append(ConvNeXtBlock(*block_def))
            current_size = block_def[0]
        self.layers = layers

    def __call__(self, x: ArrayLike, _: bool) -> Array:
        for layer in self.layers:
            x = layer(x)
        return x
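A small shape sketch of how `block_defs` drives ConvNeXt: each tuple is unpacked as ConvNeXtBlock(n_dims, kernel_size, group_features), and a Projection is inserted whenever the width changes. Sizes and values here are illustrative only:

import jax
import jax.numpy as jnp
from model.convnext import ConvNeXt

blocks = [(64, 3, True)] * 2 + [(96, 3, True)] * 2   # a Projection(96) is inserted at the width change
net = ConvNeXt(blocks)
x = jnp.zeros((1, 48, 48, 64))                        # NHWC layout, as flax.linen.Conv expects
params = net.init(jax.random.PRNGKey(0), x, False)    # second argument is the unused training flag
y = net.apply(params, x, False)
print(y.shape)                                        # (1, 48, 48, 96)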
model/edsr.py
ADDED
@@ -0,0 +1,122 @@
# from https://github.com/isaaccorley/jax-enhance

from functools import partial
from typing import Any, Sequence, Callable

import jax.numpy as jnp
import flax.linen as nn
from flax.core.frozen_dict import freeze
import einops


class PixelShuffle(nn.Module):
    scale_factor: int

    def setup(self):
        self.layer = partial(
            einops.rearrange,
            pattern="b h w (c h2 w2) -> b (h h2) (w w2) c",
            h2=self.scale_factor,
            w2=self.scale_factor
        )

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return self.layer(x)


class ResidualBlock(nn.Module):
    channels: int
    kernel_size: Sequence[int]
    res_scale: float
    activation: Callable
    dtype: Any = jnp.float32

    def setup(self):
        self.body = nn.Sequential([
            nn.Conv(features=self.channels, kernel_size=self.kernel_size, dtype=self.dtype),
            self.activation,
            nn.Conv(features=self.channels, kernel_size=self.kernel_size, dtype=self.dtype),
        ])

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return x + self.body(x)


class UpsampleBlock(nn.Module):
    num_upsamples: int
    channels: int
    kernel_size: Sequence[int]
    dtype: Any = jnp.float32

    def setup(self):
        layers = []
        for _ in range(self.num_upsamples):
            layers.extend([
                nn.Conv(features=self.channels * 2 ** 2, kernel_size=self.kernel_size, dtype=self.dtype),
                PixelShuffle(scale_factor=2),
            ])
        self.layers = layers

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        for layer in self.layers:
            x = layer(x)
        return x


class EDSR(nn.Module):
    """Enhanced Deep Residual Networks for Single Image Super-Resolution https://arxiv.org/pdf/1707.02921v1.pdf"""
    scale_factor: int
    channels: int = 3
    num_blocks: int = 32
    num_feats: int = 256
    dtype: Any = jnp.float32

    def setup(self):
        # pre res blocks layer
        self.head = nn.Sequential([nn.Conv(features=self.num_feats, kernel_size=(3, 3), dtype=self.dtype)])

        # res blocks
        res_blocks = [
            ResidualBlock(channels=self.num_feats, kernel_size=(3, 3), res_scale=0.1, activation=nn.relu, dtype=self.dtype)
            for i in range(self.num_blocks)
        ]
        res_blocks.append(nn.Conv(features=self.num_feats, kernel_size=(3, 3), dtype=self.dtype))
        self.body = nn.Sequential(res_blocks)

    def __call__(self, x: jnp.ndarray, _=None) -> jnp.ndarray:
        x = self.head(x)
        x = x + self.body(x)
        return x


def convert_edsr_checkpoint(torch_dict, no_upsampling=True):
    def convert(in_dict):
        top_keys = set([k.split('.')[0] for k in in_dict.keys()])
        leaves = set([k for k in in_dict.keys() if '.' not in k])

        # convert leaves
        out_dict = {}
        for l in leaves:
            if l == 'weight':
                out_dict['kernel'] = jnp.asarray(in_dict[l]).transpose((2, 3, 1, 0))
            elif l == 'bias':
                out_dict[l] = jnp.asarray(in_dict[l])
            else:
                out_dict[l] = in_dict[l]

        for top_key in top_keys.difference(leaves):
            new_top_key = 'layers_' + top_key if top_key.isdigit() else top_key
            out_dict[new_top_key] = convert(
                {k[len(top_key) + 1:]: v for k, v in in_dict.items() if k.startswith(top_key)})
        return out_dict

    converted = convert(torch_dict)

    # remove unwanted keys
    if no_upsampling:
        del converted['tail']

    for k in ('add_mean', 'sub_mean'):
        del converted[k]

    return freeze(converted)
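A hedged sketch of the intended use of convert_edsr_checkpoint: turn a PyTorch EDSR state dict into a Flax parameter tree for the no-upsampling backbone. The checkpoint filename is an assumption, and a raw state dict (as saved by EDSR-PyTorch) is assumed:

import torch
from model.edsr import convert_edsr_checkpoint

state = torch.load('edsr_baseline_x2.pt', map_location='cpu')   # hypothetical checkpoint path
state = {k: v.numpy() for k, v in state.items()}
# drops tail/add_mean/sub_mean and transposes conv kernels OIHW -> HWIO
flax_params = convert_edsr_checkpoint(state)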
model/hyper.py
ADDED
@@ -0,0 +1,41 @@
import math

import jax
import jax.numpy as jnp
import flax.linen as nn
from jaxtyping import Array, ArrayLike, PyTreeDef
import numpy as np

from utils import interpolate_grid


class Hypernetwork(nn.Module):
    encoder: nn.Module
    refine: nn.Module
    output_params_shape: list[tuple]  # e.g. [(16,), (32, 32), ...]
    tree_def: PyTreeDef  # used to reconstruct the parameter sets

    def setup(self):
        # one layer 1x1 conv to calculate field params, as in SIREN paper
        output_size = sum(math.prod(s) for s in self.output_params_shape)
        self.out_conv = nn.Conv(output_size, kernel_size=(1, 1), use_bias=True)

    def get_encoding(self, source: ArrayLike, training=False) -> Array:
        """Convenience method for whole-image evaluation"""
        return self.refine(self.encoder(source, training), training)

    def get_params_at_coords(self, encoding: ArrayLike, coords: ArrayLike) -> Array:
        encoding = interpolate_grid(coords, encoding)
        phi_params = self.out_conv(encoding)

        # reshape to output params shape
        phi_params = jnp.split(
            phi_params, np.cumsum([math.prod(s) for s in self.output_params_shape[:-1]]), axis=-1)
        phi_params = [jnp.reshape(p, p.shape[:-1] + s) for p, s in
                      zip(phi_params, self.output_params_shape)]

        return jax.tree_util.tree_unflatten(self.tree_def, phi_params)

    def __call__(self, source: ArrayLike, target_coords: ArrayLike, training=False) -> Array:
        encoding = self.get_encoding(source, training)
        return self.get_params_at_coords(encoding, target_coords)
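A minimal numeric sketch of the Hypernetwork reshape logic above: the 1x1 conv emits one flat vector per pixel, which is split and reshaped into the per-field parameter tensors. The shapes are illustrative:

import math
import numpy as np
import jax.numpy as jnp

output_params_shape = [(32,), (2, 32)]                           # e.g. a bias and a 2->32 matrix
output_size = sum(math.prod(s) for s in output_params_shape)     # 96 channels for out_conv
flat = jnp.zeros((1, 4, 4, output_size))                         # (B, H, W, 96), as after out_conv
parts = jnp.split(flat, np.cumsum([math.prod(s) for s in output_params_shape[:-1]]), axis=-1)
tensors = [p.reshape(p.shape[:-1] + s) for p, s in zip(parts, output_params_shape)]
print([t.shape for t in tensors])                                # [(1, 4, 4, 32), (1, 4, 4, 2, 32)]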
model/init.py
ADDED
@@ -0,0 +1,24 @@
from typing import Callable

import jax
import jax.numpy as jnp
from jaxtyping import Array


def uniform_between(a: float, b: float, dtype=jnp.float32) -> Callable:
    def init(key, shape, dtype=dtype) -> Array:
        return jax.random.uniform(key, shape, dtype=dtype, minval=a, maxval=b)
    return init


def linear_up(scale: float) -> Callable:
    def init(key, shape, dtype=jnp.float32) -> Array:
        assert shape[-2] == 2
        keys = jax.random.split(key, 2)
        norm = jnp.pi * scale * (
            jax.random.uniform(keys[0], shape=(1, shape[-1])) ** .5)
        theta = 2 * jnp.pi * jax.random.uniform(keys[1], shape=(1, shape[-1]))
        x = norm * jnp.cos(theta)
        y = norm * jnp.sin(theta)
        return jnp.concatenate([x, y], axis=-2).astype(dtype)
    return init
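A quick check of linear_up's geometry: it samples the 2 x n first-layer weight columns uniformly from a disk of radius pi*scale (the square root on the uniform sample makes the density uniform in area rather than radius). Values are illustrative:

import jax
import jax.numpy as jnp
from model.init import linear_up

w = linear_up(scale=1.0)(jax.random.PRNGKey(0), (2, 4096))
norms = jnp.linalg.norm(w, axis=0)
print(w.shape, float(norms.max()))   # (2, 4096), max norm close to pi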
model/rdn.py
ADDED
@@ -0,0 +1,72 @@
# Residual Dense Network for Image Super-Resolution
# https://arxiv.org/abs/1802.08797
# modified from: https://github.com/thstkdgus35/EDSR-PyTorch

import jax.numpy as jnp
import flax.linen as nn


class RDB_Conv(nn.Module):
    growRate: int
    kSize: int = 3

    @nn.compact
    def __call__(self, x):
        out = nn.Sequential([
            nn.Conv(self.growRate, (self.kSize, self.kSize), padding=(self.kSize-1)//2),
            nn.activation.relu
        ])(x)
        return jnp.concatenate((x, out), -1)


class RDB(nn.Module):
    growRate0: int
    growRate: int
    nConvLayers: int

    @nn.compact
    def __call__(self, x):
        res = x

        for c in range(self.nConvLayers):
            x = RDB_Conv(self.growRate)(x)

        x = nn.Conv(self.growRate0, (1, 1))(x)

        return x + res


class RDN(nn.Module):
    G0: int = 64
    RDNkSize: int = 3
    RDNconfig: str = 'B'
    scale: int = 2
    n_colors: int = 3

    @nn.compact
    def __call__(self, x, _=None):
        D, C, G = {
            'A': (20, 6, 32),
            'B': (16, 8, 64),
        }[self.RDNconfig]

        # Shallow feature extraction
        f_1 = nn.Conv(self.G0, (self.RDNkSize, self.RDNkSize))(x)
        x = nn.Conv(self.G0, (self.RDNkSize, self.RDNkSize))(f_1)

        # Residual dense blocks and dense feature fusion
        RDBs_out = []
        for i in range(D):
            x = RDB(self.G0, G, C)(x)
            RDBs_out.append(x)

        x = jnp.concatenate(RDBs_out, -1)

        # Global Feature Fusion
        x = nn.Sequential([
            nn.Conv(self.G0, (1, 1)),
            nn.Conv(self.G0, (self.RDNkSize, self.RDNkSize))
        ])(x)

        x = x + f_1
        return x
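A shape sanity sketch for the RDN backbone: config 'B' stacks 16 RDBs of 8 conv layers with growth rate 64; the output keeps the spatial size and has G0=64 channels, since the upsampling tail is omitted here. Input size is illustrative:

import jax
import jax.numpy as jnp
from model.rdn import RDN

rdn = RDN()
x = jnp.zeros((1, 24, 24, 3))
params = rdn.init(jax.random.PRNGKey(0), x)
print(rdn.apply(params, x).shape)   # (1, 24, 24, 64)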
model/swin_ir.py
ADDED
@@ -0,0 +1,532 @@
import math
from typing import Callable, Optional, Iterable

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
from jaxtyping import Array


def trunc_normal(mean=0., std=1., a=-2., b=2., dtype=jnp.float32) -> Callable:
    """Truncated normal initialization function"""

    def init(key, shape, dtype=dtype) -> Array:
        # https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/weight_init.py
        def norm_cdf(x):
            # Computes standard normal cumulative distribution function
            return (1. + math.erf(x / math.sqrt(2.))) / 2.

        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)
        out = jax.random.uniform(key, shape, dtype=dtype, minval=2 * l - 1, maxval=2 * u - 1)
        out = jax.scipy.special.erfinv(out) * std * math.sqrt(2.) + mean
        return jnp.clip(out, a, b)

    return init


def Dense(features, use_bias=True, kernel_init=trunc_normal(std=.02), bias_init=nn.initializers.zeros):
    return nn.Dense(features, use_bias=use_bias, kernel_init=kernel_init, bias_init=bias_init)


def LayerNorm():
    """torch LayerNorm uses a larger epsilon by default"""
    return nn.LayerNorm(epsilon=1e-05)


class Mlp(nn.Module):

    in_features: int
    hidden_features: int = None
    out_features: int = None
    act_layer: Callable = nn.gelu
    drop: float = 0.0

    @nn.compact
    def __call__(self, x, training: bool):
        x = nn.Dense(self.hidden_features or self.in_features)(x)
        x = self.act_layer(x)
        x = nn.Dropout(self.drop, deterministic=not training)(x)
        x = nn.Dense(self.out_features or self.in_features)(x)
        x = nn.Dropout(self.drop, deterministic=not training)(x)
        return x


def window_partition(x, window_size: int):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.reshape((B, H // window_size, window_size, W // window_size, window_size, C))
    windows = x.transpose((0, 1, 3, 2, 4, 5)).reshape((-1, window_size, window_size, C))
    return windows


def window_reverse(windows, window_size: int, H: int, W: int):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.reshape((B, H // window_size, W // window_size, window_size, window_size, -1))
    x = x.transpose((0, 1, 3, 2, 4, 5)).reshape((B, H, W, -1))
    return x


class DropPath(nn.Module):
    """
    Implementation referred from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
    """

    dropout_prob: float = 0.1
    deterministic: Optional[bool] = None

    @nn.compact
    def __call__(self, input, training):
        if not training:
            return input
        keep_prob = 1 - self.dropout_prob
        shape = (input.shape[0],) + (1,) * (input.ndim - 1)
        rng = self.make_rng("dropout")
        random_tensor = keep_prob + jax.random.uniform(rng, shape)
        random_tensor = jnp.floor(random_tensor)
        return jnp.divide(input, keep_prob) * random_tensor


class WindowAttention(nn.Module):
    dim: int
    window_size: Iterable[int]
    num_heads: int
    qkv_bias: bool = True
    qk_scale: Optional[float] = None
    att_drop: float = 0.0
    proj_drop: float = 0.0

    def make_rel_pos_index(self):
        h_indices = np.arange(0, self.window_size[0])
        w_indices = np.arange(0, self.window_size[1])
        indices = np.stack(np.meshgrid(w_indices, h_indices, indexing="ij"))
        flatten_indices = np.reshape(indices, (2, -1))
        relative_indices = flatten_indices[:, :, None] - flatten_indices[:, None, :]
        relative_indices = np.transpose(relative_indices, (1, 2, 0))
        relative_indices[:, :, 0] += self.window_size[0] - 1
        relative_indices[:, :, 1] += self.window_size[1] - 1
        relative_indices[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_pos_index = np.sum(relative_indices, -1)
        return relative_pos_index

    @nn.compact
    def __call__(self, inputs, mask, training):
        rpbt = self.param(
            "relative_position_bias_table",
            trunc_normal(std=.02),
            (
                (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1),
                self.num_heads,
            ),
        )

        # relative_pos_index = self.variable(
        #     "variables", "relative_position_index", self.get_rel_pos_index
        # )

        batch, n, channels = inputs.shape
        qkv = nn.Dense(self.dim * 3, use_bias=self.qkv_bias, name="qkv")(inputs)
        qkv = qkv.reshape(batch, n, 3, self.num_heads, channels // self.num_heads)
        qkv = jnp.transpose(qkv, (2, 0, 3, 1, 4))
        q, k, v = qkv[0], qkv[1], qkv[2]

        scale = self.qk_scale or (self.dim // self.num_heads) ** -0.5
        q = q * scale
        att = q @ jnp.swapaxes(k, -2, -1)

        rel_pos_bias = jnp.reshape(
            rpbt[np.reshape(self.make_rel_pos_index(), (-1))],
            (
                self.window_size[0] * self.window_size[1],
                self.window_size[0] * self.window_size[1],
                -1,
            ),
        )
        rel_pos_bias = jnp.transpose(rel_pos_bias, (2, 0, 1))
        att += jnp.expand_dims(rel_pos_bias, 0)

        if mask is not None:
            att = jnp.reshape(
                att, (batch // mask.shape[0], mask.shape[0], self.num_heads, n, n)
            )
            att = att + jnp.expand_dims(jnp.expand_dims(mask, 1), 0)
            att = jnp.reshape(att, (-1, self.num_heads, n, n))
            att = jax.nn.softmax(att)
        else:
            att = jax.nn.softmax(att)

        att = nn.Dropout(self.att_drop)(att, deterministic=not training)

        x = jnp.reshape(jnp.swapaxes(att @ v, 1, 2), (batch, n, channels))
        x = nn.Dense(self.dim, name="proj")(x)
        x = nn.Dropout(self.proj_drop)(x, deterministic=not training)
        return x


class SwinTransformerBlock(nn.Module):

    dim: int
    input_resolution: tuple[int]
    num_heads: int
    window_size: int = 7
    shift_size: int = 0
    mlp_ratio: float = 4.
    qkv_bias: bool = True
    qk_scale: Optional[float] = None
    drop: float = 0.
    attn_drop: float = 0.
    drop_path: float = 0.
    act_layer: Callable = nn.activation.gelu
    norm_layer: Callable = LayerNorm

    @staticmethod
    def make_att_mask(shift_size, window_size, height, width):
        if shift_size > 0:
            mask = jnp.zeros([1, height, width, 1])
            h_slices = (
                slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None),
            )
            w_slices = (
                slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None),
            )

            count = 0
            for h in h_slices:
                for w in w_slices:
                    mask = mask.at[:, h, w, :].set(count)
                    count += 1

            mask_windows = window_partition(mask, window_size)
            mask_windows = jnp.reshape(mask_windows, (-1, window_size * window_size))
            att_mask = jnp.expand_dims(mask_windows, 1) - jnp.expand_dims(mask_windows, 2)
            att_mask = jnp.where(att_mask != 0.0, float(-100.0), att_mask)
            att_mask = jnp.where(att_mask == 0.0, float(0.0), att_mask)
        else:
            att_mask = None

        return att_mask

    @nn.compact
    def __call__(self, x, x_size, training):
        H, W = x_size
        B, L, C = x.shape

        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in 0-window_size"

        shortcut = x
        x = self.norm_layer()(x)
        x = x.reshape((B, H, W, C))

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = jnp.roll(x, (-self.shift_size, -self.shift_size), axis=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.reshape((-1, self.window_size * self.window_size, C))  # nW*B, window_size*window_size, C

        # attn_mask = self.variable(
        #     "variables",
        #     "attn_mask",
        #     self.get_att_mask,
        #     self.shift_size,
        #     self.window_size,
        #     self.input_resolution[0],
        #     self.input_resolution[1]
        # )

        attn_mask = self.make_att_mask(self.shift_size, self.window_size, *self.input_resolution)

        attn = WindowAttention(self.dim, (self.window_size, self.window_size), self.num_heads,
                               self.qkv_bias, self.qk_scale, self.attn_drop, self.drop)
        if self.input_resolution == x_size:
            attn_windows = attn(x_windows, attn_mask, training)  # nW*B, window_size*window_size, C
        else:
            # test time
            assert not training
            test_mask = self.make_att_mask(self.shift_size, self.window_size, *x_size)
            attn_windows = attn(x_windows, test_mask, training=False)

        # merge windows
        attn_windows = attn_windows.reshape((-1, self.window_size, self.window_size, C))
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = jnp.roll(shifted_x, (self.shift_size, self.shift_size), axis=(1, 2))
        else:
            x = shifted_x

        x = x.reshape((B, H * W, C))

        # FFN
        x = shortcut + DropPath(self.drop_path)(x, training)

        norm = self.norm_layer()(x)
        mlp = Mlp(in_features=self.dim, hidden_features=int(self.dim * self.mlp_ratio),
                  act_layer=self.act_layer, drop=self.drop)(norm, training)
        x = x + DropPath(self.drop_path)(mlp, training)

        return x


class PatchMerging(nn.Module):
    inp_res: Iterable[int]
    dim: int
    norm_layer: Callable = LayerNorm

    @nn.compact
    def __call__(self, inputs):
        batch, n, channels = inputs.shape
        height, width = self.inp_res[0], self.inp_res[1]
        x = jnp.reshape(inputs, (batch, height, width, channels))

        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]

        x = jnp.concatenate([x0, x1, x2, x3], axis=-1)
        x = jnp.reshape(x, (batch, -1, 4 * channels))
        x = self.norm_layer()(x)
        x = nn.Dense(2 * self.dim, use_bias=False)(x)
        return x


class BasicLayer(nn.Module):

    dim: int
    input_resolution: int
    depth: int
    num_heads: int
    window_size: int
    mlp_ratio: float = 4.
    qkv_bias: bool = True
    qk_scale: Optional[float] = None
    drop: float = 0.
    attn_drop: float = 0.
    drop_path: float = 0.
    norm_layer: Callable = LayerNorm
    downsample: Optional[Callable] = None

    @nn.compact
    def __call__(self, x, x_size, training):
        for i in range(self.depth):
            x = SwinTransformerBlock(
                self.dim,
                self.input_resolution,
                self.num_heads,
                self.window_size,
                0 if (i % 2 == 0) else self.window_size // 2,
                self.mlp_ratio,
                self.qkv_bias,
                self.qk_scale,
                self.drop,
                self.attn_drop,
                self.drop_path[i] if isinstance(self.drop_path, (list, tuple)) else self.drop_path,
                norm_layer=self.norm_layer
            )(x, x_size, training)

        if self.downsample is not None:
            x = self.downsample(self.input_resolution, dim=self.dim, norm_layer=self.norm_layer)(x)

        return x


class RSTB(nn.Module):

    dim: int
    input_resolution: int
    depth: int
    num_heads: int
    window_size: int
    mlp_ratio: float = 4.
    qkv_bias: bool = True
    qk_scale: Optional[float] = None
    drop: float = 0.
    attn_drop: float = 0.
    drop_path: float = 0.
    norm_layer: Callable = LayerNorm
    downsample: Optional[Callable] = None
    img_size: int = 224
    patch_size: int = 4
    resi_connection: str = '1conv'

    @nn.compact
    def __call__(self, x, x_size, training):
        res = x
        x = BasicLayer(dim=self.dim,
                       input_resolution=self.input_resolution,
                       depth=self.depth,
                       num_heads=self.num_heads,
                       window_size=self.window_size,
                       mlp_ratio=self.mlp_ratio,
                       qkv_bias=self.qkv_bias, qk_scale=self.qk_scale,
                       drop=self.drop, attn_drop=self.attn_drop,
                       drop_path=self.drop_path,
                       norm_layer=self.norm_layer,
                       downsample=self.downsample)(x, x_size, training)

        x = PatchUnEmbed(embed_dim=self.dim)(x, x_size)

        # resi_connection == '1conv':
        x = nn.Conv(self.dim, (3, 3))(x)

        x = PatchEmbed()(x)

        return x + res


class PatchEmbed(nn.Module):
    norm_layer: Optional[Callable] = None

    @nn.compact
    def __call__(self, x):
        x = x.reshape((x.shape[0], -1, x.shape[-1]))  # B Ph Pw C -> B Ph*Pw C
        if self.norm_layer is not None:
            x = self.norm_layer()(x)
        return x


class PatchUnEmbed(nn.Module):
    embed_dim: int = 96

    @nn.compact
    def __call__(self, x, x_size):
        B, HW, C = x.shape
        x = x.reshape((B, x_size[0], x_size[1], self.embed_dim))
        return x


class SwinIR(nn.Module):
    r""" SwinIR JAX implementation
    Args:
        img_size (int | tuple(int)): Input image size. Default 64
        patch_size (int | tuple(int)): Patch size. Default: 1
        in_chans (int): Number of input image channels. Default: 3
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
        img_range: Image range. 1. or 255.
    """

    img_size: int = 48
    patch_size: int = 1
    in_chans: int = 3
    embed_dim: int = 180
    depths: tuple = (6, 6, 6, 6, 6, 6)
    num_heads: tuple = (6, 6, 6, 6, 6, 6)
    window_size: int = 8
    mlp_ratio: float = 2.
    qkv_bias: bool = True
    qk_scale: Optional[float] = None
    drop_rate: float = 0.
    attn_drop_rate: float = 0.
    drop_path_rate: float = 0.1
    norm_layer: Callable = LayerNorm
    ape: bool = False
    patch_norm: bool = True
    upscale: int = 2
    img_range: float = 1.
    num_feat: int = 64

    def pad(self, x):
        _, h, w, _ = x.shape
        mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
        mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
        x = jnp.pad(x, ((0, 0), (0, mod_pad_h), (0, mod_pad_w), (0, 0)), 'reflect')
        return x

    @nn.compact
    def __call__(self, x, training):
        _, h_before, w_before, _ = x.shape
        x = self.pad(x)
        _, h, w, _ = x.shape
        patches_resolution = [self.img_size // self.patch_size] * 2
        num_patches = patches_resolution[0] * patches_resolution[1]

        # conv_first
        x = nn.Conv(self.embed_dim, (3, 3))(x)
        res = x

        # feature extraction
        x_size = (h, w)
        x = PatchEmbed(self.norm_layer if self.patch_norm else None)(x)

        if self.ape:
            absolute_pos_embed = \
                self.param('ape', trunc_normal(std=.02), (1, num_patches, self.embed_dim))
            x = x + absolute_pos_embed

        x = nn.Dropout(self.drop_rate, deterministic=not training)(x)

        dpr = [x.item() for x in np.linspace(0, self.drop_path_rate, sum(self.depths))]
        for i_layer in range(len(self.depths)):
            x = RSTB(
                dim=self.embed_dim,
                input_resolution=(patches_resolution[0], patches_resolution[1]),
                depth=self.depths[i_layer],
                num_heads=self.num_heads[i_layer],
                window_size=self.window_size,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=self.qkv_bias, qk_scale=self.qk_scale,
                drop=self.drop_rate, attn_drop=self.attn_drop_rate,
                drop_path=dpr[sum(self.depths[:i_layer]):sum(self.depths[:i_layer + 1])],
                norm_layer=self.norm_layer,
                downsample=None,
                img_size=self.img_size,
                patch_size=self.patch_size)(x, x_size, training)

        x = self.norm_layer()(x)  # B L C
        x = PatchUnEmbed(self.embed_dim)(x, x_size)

        # conv_after_body
        x = nn.Conv(self.embed_dim, (3, 3))(x)
        x = x + res

        # conv_before_upsample
        x = nn.activation.leaky_relu(nn.Conv(self.num_feat, (3, 3))(x))

        # revert padding
        x = x[:, :-(h - h_before) or None, :-(w - w_before) or None]
        return x
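A short sketch of the padding arithmetic used in SwinIR.pad and the crop at the end of __call__: H and W are reflect-padded up to multiples of window_size, and `:-(h - h_before) or None` handles the zero-padding case, since slicing with None keeps everything:

import jax.numpy as jnp

window_size = 8
h = 45
mod_pad_h = (window_size - h % window_size) % window_size   # 3
x = jnp.zeros((1, h, 45, 3))
x = jnp.pad(x, ((0, 0), (0, mod_pad_h), (0, 3), (0, 0)), 'reflect')
print(x.shape)                     # (1, 48, 48, 3)
print(x[:, :-3 or None].shape)     # crop back: (1, 45, 48, 3)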
model/tail.py
ADDED
@@ -0,0 +1,18 @@
import flax.linen as nn

from .convnext import ConvNeXt
from .swin_ir import SwinIR


def build_tail(size: str):
    """ Convenience function to build the three tails described in the paper. """
    if size == 'air':
        return lambda x, _: x
    elif size == 'plus':
        blocks = [(64, 3, True)] * 6 + [(96, 3, True)] * 7 + [(128, 3, True)] * 3
        return ConvNeXt(blocks)
    elif size == 'pro':
        return SwinIR(depths=[7, 6], num_heads=[6, 6])
    else:
        raise NotImplementedError('size: ' + size)
model/thera.py
ADDED
@@ -0,0 +1,175 @@
import math

import jax
from flax.core import unfreeze, freeze
import jax.numpy as jnp
import flax.linen as nn
from jaxtyping import Array, ArrayLike, PyTree

from .edsr import EDSR
from .rdn import RDN
from .hyper import Hypernetwork
from .tail import build_tail
from .init import uniform_between, linear_up
from utils import make_grid, interpolate_grid, repeat_vmap


class Thermal(nn.Module):
    w0_scale: float = 1.

    @nn.compact
    def __call__(self, x: ArrayLike, t, norm, k) -> Array:
        phase = self.param('phase', nn.initializers.uniform(.5), x.shape[-1:])
        return jnp.sin(self.w0_scale * x + phase) * jnp.exp(-(self.w0_scale * norm)**2 * k * t)


class TheraField(nn.Module):
    dim_hidden: int
    dim_out: int
    w0: float = 1.
    c: float = 6.

    @nn.compact
    def __call__(self, x: ArrayLike, t: ArrayLike, k: ArrayLike, components: ArrayLike) -> Array:
        # coordinate projection according to shared components ("first layer")
        x = x @ components

        # thermal activations
        norm = jnp.linalg.norm(components, axis=-2)
        x = Thermal(self.w0)(x, t, norm, k)

        # linear projection from hidden to output space ("second layer")
        w_std = math.sqrt(self.c / self.dim_hidden) / self.w0
        dense_init_fn = uniform_between(-w_std, w_std)
        x = nn.Dense(self.dim_out, kernel_init=dense_init_fn, use_bias=False)(x)

        return x


class Thera:

    def __init__(
        self,
        hidden_dim: int,
        out_dim: int,
        backbone: nn.Module,
        tail: nn.Module,
        k_init: float = None,
        components_init_scale: float = None
    ):
        self.hidden_dim = hidden_dim
        self.k_init = k_init
        self.components_init_scale = components_init_scale

        # single TheraField object whose `apply` method is used for all grid cells
        self.field = TheraField(hidden_dim, out_dim)

        # infer output size of the hypernetwork from a sample pass through the field;
        # the key doesn't matter as field params are only used for size inference
        sample_params = self.field.init(jax.random.PRNGKey(0),
                                        jnp.zeros((2,)), 0., 0., jnp.zeros((2, hidden_dim)))
        sample_params_flat, tree_def = jax.tree_util.tree_flatten(sample_params)
        param_shapes = [p.shape for p in sample_params_flat]

        self.hypernet = Hypernetwork(backbone, tail, param_shapes, tree_def)

    def init(self, key, sample_source) -> PyTree:
        keys = jax.random.split(key, 2)
        sample_coords = jnp.zeros(sample_source.shape[:-1] + (2,))
        params = unfreeze(self.hypernet.init(keys[0], sample_source, sample_coords))

        params['params']['k'] = jnp.array(self.k_init)
        params['params']['components'] = \
            linear_up(self.components_init_scale)(keys[1], (2, self.hidden_dim))

        return freeze(params)

    def apply_encoder(self, params: PyTree, source: ArrayLike, **kwargs) -> Array:
        """
        Performs a forward pass through the hypernetwork to obtain an encoding.
        """
        return self.hypernet.apply(
            params, source, method=self.hypernet.get_encoding, **kwargs)

    def apply_decoder(
        self,
        params: PyTree,
        encoding: ArrayLike,
        coords: ArrayLike,
        t: ArrayLike,
        return_jac: bool = False
    ) -> Array | tuple[Array, Array]:
        """
        Performs a forward prediction through a grid of HxW Thera fields,
        informed by `encoding`, at spatial and temporal coordinates
        `coords` and `t`, respectively.
        args:
            params: Field parameters, shape (B, H, W, N)
            encoding: Encoding tensor, shape (B, H, W, C)
            coords: Spatial coordinates in [-0.5, 0.5], shape (B, H, W, 2)
            t: Temporal coordinates, shape (B, 1)
        """
        phi_params: PyTree = self.hypernet.apply(
            params, encoding, coords, method=self.hypernet.get_params_at_coords)

        # create local coordinate systems
        source_grid = jnp.asarray(make_grid(encoding.shape[-3:-1]))
        source_coords = jnp.tile(source_grid, (encoding.shape[0], 1, 1, 1))
        interp_coords = interpolate_grid(coords, source_coords)
        rel_coords = (coords - interp_coords)
        rel_coords = rel_coords.at[..., 0].set(rel_coords[..., 0] * encoding.shape[-3])
        rel_coords = rel_coords.at[..., 1].set(rel_coords[..., 1] * encoding.shape[-2])

        # three maps over params, coords; one over t; don't map k and components
        in_axes = [(0, 0, None, None, None), (0, 0, None, None, None), (0, 0, 0, None, None)]
        apply_field = repeat_vmap(self.field.apply, in_axes)
        out = apply_field(phi_params, rel_coords, t, params['params']['k'],
                          params['params']['components'])

        if return_jac:
            apply_jac = repeat_vmap(jax.jacrev(self.field.apply, argnums=1), in_axes)
            jac = apply_jac(phi_params, rel_coords, jnp.zeros_like(t), params['params']['k'],
                            params['params']['components'])
            return out, jac

        return out

    def apply(
        self,
        params: ArrayLike,
        source: ArrayLike,
        coords: ArrayLike,
        t: ArrayLike,
        return_jac: bool = False,
        **kwargs
    ) -> Array:
        """
        Performs a forward pass through the Thera model.
        """
        encoding = self.apply_encoder(params, source, **kwargs)
        out = self.apply_decoder(params, encoding, coords, t, return_jac=return_jac)
        return out


def build_thera(
    out_dim: int,
    backbone: str,
    size: str,
    k_init: float = None,
    components_init_scale: float = None
):
    """
    Convenience function for building the three Thera sizes described in the paper.
    """
    hidden_dim = 32 if size == 'air' else 512

    if backbone == 'edsr-baseline':
        backbone_module = EDSR(None, num_blocks=16, num_feats=64)
    elif backbone == 'rdn':
        backbone_module = RDN()
    else:
        raise NotImplementedError(backbone)

    tail_module = build_tail(size)

    return Thera(hidden_dim, out_dim, backbone_module, tail_module, k_init, components_init_scale)
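A worked example of the time variable that Thermal consumes: super_resolve.py sets t = (target_h / source_h)**-2, so a 4x upscaling samples the field at t = 1/16, and each frequency component is damped by exp(-(w0*|nu|)^2 * k * t). Magnitudes below are illustrative:

import jax.numpy as jnp

source_h, target_h = 48, 192
t = jnp.float32((target_h / source_h) ** -2)     # 0.0625
w0_norm, k = 2.0, jnp.float32(0.1)               # illustrative |nu| scale and diffusivity
damping = jnp.exp(-(w0_norm ** 2) * k * t)
print(float(t), float(damping))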
requirements.txt
CHANGED
@@ -1,6 +1,37 @@
+-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+
+ConfigArgParse==1.7
+Pillow==10.0.0
+chex==0.1.7
 diffusers
+einops==0.6.1
+flax==0.6.10
+flaxmodels==0.1.3
+jax==0.4.11
+jaxlib==0.4.11+cuda11.cudnn86
+jaxtyping==0.2.20
+ml-dtypes==0.1.0
+numpy==1.24.1
+nvidia-cublas-cu11==11.11.3.6
+nvidia-cuda-cupti-cu11==11.8.87
+nvidia-cuda-nvcc-cu11==11.8.89
+nvidia-cuda-runtime-cu11==11.8.89
+nvidia-cudnn-cu11==8.9.2.26
+nvidia-cufft-cu11==10.9.0.58
+nvidia-cusolver-cu11==11.4.1.48
+nvidia-cusparse-cu11==11.7.5.86
+opt-einsum==3.3.0
+optax==0.2.0
+orbax-checkpoint==0.2.4
+peft
+scipy==1.10.1
+timm==0.9.6
+torch
+torchvision
+tqdm==4.65.0
+transformers==4.46.3
+wandb
+
+gradio==4.44.1
+gradio_imageslider==0.0.20
+spaces
super_resolve.py
ADDED
@@ -0,0 +1,99 @@
#!/usr/bin/env python

from argparse import ArgumentParser, Namespace
import pickle

import jax
from jax import jit
import jax.numpy as jnp
import numpy as np
from PIL import Image

from model import build_thera
from utils import make_grid, interpolate_grid

MEAN = jnp.array([.4488, .4371, .4040])
VAR = jnp.array([.25, .25, .25])
PATCH_SIZE = 256


def process_single(source, apply_encoder, apply_decoder, params, target_shape):
    t = jnp.float32((target_shape[0] / source.shape[1])**-2)[None]
    coords_nearest = jnp.asarray(make_grid(target_shape)[None])
    source_up = interpolate_grid(coords_nearest, source[None])
    source = jax.nn.standardize(source, mean=MEAN, variance=VAR)[None]

    encoding = apply_encoder(params, source)
    coords = jnp.asarray(make_grid(source_up.shape[1:3])[None])  # global sampling coords
    out = jnp.full_like(source_up, jnp.nan, dtype=jnp.float32)

    for h_min in range(0, coords.shape[1], PATCH_SIZE):
        h_max = min(h_min + PATCH_SIZE, coords.shape[1])
        for w_min in range(0, coords.shape[2], PATCH_SIZE):
            # apply decoder with one patch of coordinates
            w_max = min(w_min + PATCH_SIZE, coords.shape[2])
            coords_patch = coords[:, h_min:h_max, w_min:w_max]
            out_patch = apply_decoder(params, encoding, coords_patch, t)
            out = out.at[:, h_min:h_max, w_min:w_max].set(out_patch)

    out = out * jnp.sqrt(VAR)[None, None, None] + MEAN[None, None, None]
    out += source_up
    return out


def process(source, model, params, target_shape, do_ensemble=True):
    apply_encoder = jit(model.apply_encoder)
    apply_decoder = jit(model.apply_decoder)

    outs = []
    for i_rot in range(4 if do_ensemble else 1):
        source_ = jnp.rot90(source, k=i_rot, axes=(-3, -2))
        target_shape_ = tuple(reversed(target_shape)) if i_rot % 2 else target_shape
        out = process_single(source_, apply_encoder, apply_decoder, params, target_shape_)
        outs.append(jnp.rot90(out, k=i_rot, axes=(-2, -3)))

    out = jnp.stack(outs).mean(0).clip(0., 1.)
    return jnp.rint(out[0] * 255).astype(jnp.uint8)


def main(args: Namespace):
    source = np.asarray(Image.open(args.in_file)) / 255.

    if args.scale is not None:
        if args.size is not None:
            raise ValueError('Cannot specify both size and scale')
        target_shape = (
            round(source.shape[0] * args.scale),
            round(source.shape[1] * args.scale),
        )
    elif args.size is not None:
        target_shape = args.size
    else:
        raise ValueError('Must specify either size or scale')

    with open(args.checkpoint, 'rb') as fh:
        check = pickle.load(fh)
        params, backbone, size = check['model'], check['backbone'], check['size']

    model = build_thera(3, backbone, size)

    out = process(source, model, params, target_shape, not args.no_ensemble)

    Image.fromarray(np.asarray(out)).save(args.out_file)


def parse_args() -> Namespace:
    parser = ArgumentParser()
    parser.add_argument('in_file')
    parser.add_argument('out_file')
    parser.add_argument('--scale', type=float, help='Scale factor for super-resolution')
    parser.add_argument('--size', type=int, nargs=2,
                        help='Target size (h, w), mutually exclusive with --scale')
    parser.add_argument('--checkpoint', help='Path to checkpoint file')
    parser.add_argument('--no-ensemble', action='store_true', help='Disable geo-ensemble')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    main(args)
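Example invocations of the script, matching the argparse definition above (file names are placeholders):

python super_resolve.py in.png out.png --scale 3.14 --checkpoint thera-edsr.pkl
python super_resolve.py in.png out.png --size 1080 1920 --checkpoint thera-edsr.pkl --no-ensemble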
utils.py
ADDED
@@ -0,0 +1,36 @@
from functools import partial

import jax
import numpy as np


def repeat_vmap(fun, in_axes=[0]):
    for axes in in_axes:
        fun = jax.vmap(fun, in_axes=axes)
    return fun


def make_grid(patch_size: int | tuple[int, int]):
    if isinstance(patch_size, int):
        patch_size = (patch_size, patch_size)
    offset_h, offset_w = 1 / (2 * np.array(patch_size))
    space_h = np.linspace(-0.5 + offset_h, 0.5 - offset_h, patch_size[0])
    space_w = np.linspace(-0.5 + offset_w, 0.5 - offset_w, patch_size[1])
    return np.stack(np.meshgrid(space_h, space_w, indexing='ij'), axis=-1)  # [h, w]


def interpolate_grid(coords, grid, order=0):
    """
    args:
        coords: Tensor of shape (B, H, W, 2) with coordinates in [-0.5, 0.5]
        grid: Tensor of shape (B, H', W', C)
    returns:
        Tensor of shape (B, H, W, C) with interpolated values
    """
    # convert [-0.5, 0.5] -> [0, size], where pixel centers are expected at
    # [-0.5 + 1 / (2*size), ..., 0.5 - 1 / (2*size)]
    coords = coords.transpose((0, 3, 1, 2))
    coords = coords.at[:, 0].set(coords[:, 0] * grid.shape[-3] + (grid.shape[-3] - 1) / 2)
    coords = coords.at[:, 1].set(coords[:, 1] * grid.shape[-2] + (grid.shape[-2] - 1) / 2)
    map_coordinates = partial(jax.scipy.ndimage.map_coordinates, order=order, mode='nearest')
    return jax.vmap(jax.vmap(map_coordinates, in_axes=(2, None), out_axes=2))(grid, coords)