"""Gradio demo: conditional MNIST generation, comparing diffusion and flow matching."""

import os
import sys
import time

import cv2
import torch
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download

from src.model import ConditionalUNet

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# MNIST images: 1 channel, 28x28 pixels
img_shape = (1, 28, 28)


def resize(image, size=(200, 200)):
    """Resize an image for display in the Gradio image panels."""
    return cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)


# Download the pretrained diffusion model from the Hugging Face Hub.
model_diff = ConditionalUNet().to(device)
model_path = hf_hub_download(repo_id="CristianLazoQuispe/MNIST_Diff_Flow_matching",
                             filename="outputs/diffusion/diffusion_model.pth",
                             cache_dir="models")
print("Diff Downloaded!")
model_diff.load_state_dict(torch.load(model_path, map_location=device))
model_diff.eval()

# Same architecture, trained with the flow matching objective.
model_flow = ConditionalUNet().to(device)
model_path = hf_hub_download(repo_id="CristianLazoQuispe/MNIST_Diff_Flow_matching",
                             filename="outputs/flow_matching/flow_model.pth",
                             cache_dir="models")
print("Flow Downloaded!")
model_flow.load_state_dict(torch.load(model_path, map_location=device))
model_flow.eval()

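
# Streaming generator for the "Diffusion" tab: yields a 13-tuple of images so the
# Gradio UI can refresh its panels while sampling is still in progress.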
@torch.no_grad()
def generate_diffusion_intermediates_streaming(label):
    # Linear beta (noise) schedule.
    timesteps = 500
    betas = torch.linspace(1e-4, 0.02, timesteps)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0).to(device)

    # Start from pure Gaussian noise, conditioned on the requested digit label.
    x = torch.randn(1, *img_shape).to(device)
    y = torch.tensor([label], dtype=torch.long, device=device)

    img_np = ((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy()

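    # `outputs` feeds the 13 gr.Image components wired up below:
    # [0] initial noise, [1-5] intermediate samples, [6-11] predicted-noise heatmaps,
    # [12] the final result at larger resolution.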
    outputs = [None] * 13
    yield tuple(outputs)
    outputs[0] = resize(img_np)
    yield tuple(outputs)
    time.sleep(0.2)

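    # Reverse diffusion: walk t = 499 ... 0. Each step removes the predicted noise via
    # x_{t-1} = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta) / sqrt(alpha_t)
    # and, while t > 0, adds fresh Gaussian noise scaled by the posterior variance.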
    for t in reversed(range(timesteps)):
        t_tensor = torch.full((x.size(0),), t, device=device, dtype=torch.float)
        noise_pred = model_diff(x, t_tensor, y)
        x = (1 / alphas[t].sqrt()) * (x - noise_pred * betas[t] / (1 - alphas_cumprod[t]).sqrt())
        if t > 0:
            noise = torch.randn(1, *img_shape).to(device)
            v = (1 - alphas_cumprod[t - 1]) / (1 - alphas_cumprod[t]) * betas[t]
            x += v.sqrt() * noise
        x = x.clamp(-1, 1)

        # Predicted-noise heatmaps at selected timesteps (panels 6-11).
        if t in [499, 399, 299, 199, 99, 0]:
            step_idx = {499: 6, 399: 7, 299: 8, 199: 9, 99: 10, 0: 11}[t]
            v_mag = noise_pred[0, 0].abs().clamp(0, 3).cpu().numpy()
            v_mag = (v_mag - v_mag.min()) / (v_mag.max() - v_mag.min() + 1e-5)
            vel_colored = plt.get_cmap("coolwarm")(v_mag)[:, :, :3]
            vel_colored = (vel_colored * 255).astype(np.uint8)
            outputs[step_idx] = resize(vel_colored)
            yield tuple(outputs)

        # Keep the final-result panel (index 12) updated on every step.
        outputs[12] = resize(((x + 1) / 2.0)[0, 0].cpu().numpy(), (300, 300))

        # Intermediate samples at selected timesteps (panels 1-5, plus 12 at t=0).
        if t in [400, 300, 200, 100, 1, 0]:
            step_idx = {400: 1, 300: 2, 200: 3, 100: 4, 1: 5, 0: 12}[t]
            if t == 0:
                outputs[step_idx] = resize(((x + 1) / 2.0)[0, 0].cpu().numpy(), (300, 300))
            else:
                outputs[step_idx] = resize(((x + 1) / 2.0)[0, 0].cpu().numpy())
            yield tuple(outputs)
        if t % 10 == 0:
            yield tuple(outputs)
            time.sleep(0.06)

    yield tuple(outputs)


def generate_localized_noise(shape, radius=5):
    """Generate an image whose noise is confined to a circle at the center.

    Note: this helper is not used by the demo below.
    """
    B, C, H, W = shape
    assert C == 1, "Only grayscale images are supported."

    yy, xx = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
    center_y, center_x = H // 2, W // 2
    mask = ((yy - center_y) ** 2 + (xx - center_x) ** 2) <= radius ** 2
    mask = mask.float().unsqueeze(0).unsqueeze(0)

    noise = torch.randn(B, C, H, W)
    # Gaussian noise inside the circle, constant background (-1) outside.
    localized_noise = noise * mask + -1 * (1 - mask)

    return localized_noise

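
# Streaming generator for the "Flow Matching" tab; mirrors the panel layout of the
# diffusion sampler above.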
@torch.no_grad()
def generate_flow_intermediates_streaming(label):
    # Start from pure Gaussian noise, conditioned on the requested digit label.
    x = torch.randn(1, *img_shape).to(device)
    y = torch.full((1,), label, dtype=torch.long, device=device)
    steps = 50
    dt = 1.0 / steps

    img_np = ((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy()

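    # Same 13-panel layout as the diffusion generator: [0] noise, [1-5] intermediate
    # samples, [6-11] velocity heatmaps, [12] the final result.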
    outputs = [None] * 13
    yield tuple(outputs)
    outputs[0] = resize(img_np)
    yield tuple(outputs)
    time.sleep(0.2)

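    # Flow matching sampling: integrate the learned ODE dx/dt = v(x, t, y) from t=0
    # to t=1 with forward Euler steps of size dt.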
    for i in range(steps):
        t = torch.full((1,), i * dt, device=device)
        v = model_flow(x, t, y)
        x = x + v * dt

        # Keep the final-result panel (index 12) updated on every step.
        outputs[12] = resize(((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy(), (300, 300))

        # Intermediate samples at selected steps (panels 1-5, plus 12 at the last step).
        if i in [10, 20, 30, 40, 48, 49]:
            step_idx = {10: 1, 20: 2, 30: 3, 40: 4, 48: 5, 49: 12}[i]
            if i == 49:
                outputs[step_idx] = resize(((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy(), (300, 300))
            else:
                outputs[step_idx] = resize(((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy())
            yield tuple(outputs)

        # Velocity-magnitude heatmaps at selected steps (panels 6-11).
        if i in [0, 11, 21, 31, 41, 49]:
            v_mag = dt * v[0, 0].abs().clamp(0, 3).cpu().numpy()
            v_mag = (v_mag - v_mag.min()) / (v_mag.max() - v_mag.min() + 1e-5)
            vel_colored = plt.get_cmap("coolwarm")(v_mag)[:, :, :3]
            vel_colored = (vel_colored * 255).astype(np.uint8)
            step_idx = {0: 6, 11: 7, 21: 8, 31: 9, 41: 10, 49: 11}[i]
            outputs[step_idx] = resize(vel_colored)
            yield tuple(outputs)
        # Throttle streaming updates every 10 steps.
        if i % 10 == 0:
            yield tuple(outputs)
            time.sleep(0.06)

    yield tuple(outputs)


with gr.Blocks() as demo:
    gr.Markdown("# Conditional MNIST Generation: Diffusion vs Flow Matching")

    with gr.Tab("Diffusion"):
        label_d = gr.Slider(0, 9, step=1, label="Digit Label")
        btn_d = gr.Button("Generate")
        with gr.Row():
            outs_d = [
                gr.Image(label="Noise", streaming=True),
                gr.Image(label="Diffusion t=400", streaming=True),
                gr.Image(label="Diffusion t=300", streaming=True),
                gr.Image(label="Diffusion t=200", streaming=True),
                gr.Image(label="Diffusion t=100", streaming=True),
                gr.Image(label="Diffusion t=1", streaming=True),
            ]
        with gr.Row():
            diff_noise_imgs = [
                gr.Image(label="Noise pred t=500", streaming=True),
                gr.Image(label="Noise pred t=400", streaming=True),
                gr.Image(label="Noise pred t=300", streaming=True),
                gr.Image(label="Noise pred t=200", streaming=True),
                gr.Image(label="Noise pred t=100", streaming=True),
                gr.Image(label="Noise pred t=1", streaming=True),
            ]
        with gr.Row():
            diff_result_imgs = [
                gr.Image(label="Diffusion t=0", streaming=True),
            ]
        btn_d.click(fn=generate_diffusion_intermediates_streaming, inputs=label_d,
                    outputs=outs_d + diff_noise_imgs + diff_result_imgs)

    with gr.Tab("Flow Matching"):
        label_f = gr.Slider(0, 9, step=1, label="Digit Label")
        btn_f = gr.Button("Generate")
        with gr.Row():
            outs_f = [
                gr.Image(label="Noise"),
                gr.Image(label="Flow step=10"),
                gr.Image(label="Flow step=20"),
                gr.Image(label="Flow step=30"),
                gr.Image(label="Flow step=40"),
                gr.Image(label="Flow step=48"),
            ]
        with gr.Row():
            flow_vel_imgs = [
                gr.Image(label="Velocity step=0"),
                gr.Image(label="Velocity step=10"),
                gr.Image(label="Velocity step=20"),
                gr.Image(label="Velocity step=30"),
                gr.Image(label="Velocity step=40"),
                gr.Image(label="Velocity step=48")
            ]
        with gr.Row():
            flow_result_imgs = [
                gr.Image(label="Flow step=49", streaming=True),
            ]
        btn_f.click(fn=generate_flow_intermediates_streaming, inputs=label_f,
                    outputs=outs_f + flow_vel_imgs + flow_result_imgs)

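# Note: streaming generator outputs may require the request queue on older Gradio
# versions (e.g. `demo.queue().launch()`); recent releases enable the queue by default.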
demo.launch()