CristianLazoQuispe's picture
logging
74ec8db
raw
history blame
7.41 kB
import time
import torch
import logging
import gradio as gr
logging.basicConfig(level=logging.INFO)
from src.utils import generate_centered_gaussian_noise
from src.demo import resize,plot_flow,load_models,plot_diff
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Shape of a single sample; sampling draws tensors of shape (1, *img_shape).
img_shape = (1, 28, 28)
# Deployment mode. "DEPLOY" launches with Gradio defaults; any other value
# launches with a share link on port 9071, and "LOCAL" also adds an extra
# per-step delay during generation. The value is also passed to load_models.
ENV = "DEPLOY"
# Extra delay (seconds) per sampling step, applied only when ENV == "LOCAL".
TIME_SLEEP = 0.05

# Linear DDPM noise schedule over `timesteps` steps (beta from 1e-4 to 0.02).
timesteps = 500
betas = torch.linspace(1e-4, 0.02, timesteps)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
# Keep the whole schedule on the sampling device instead of leaving betas and
# alphas on the CPU while alphas_cumprod lives on `device`. The values are
# computed on the CPU first, so they are identical regardless of device.
betas, alphas, alphas_cumprod = (
    schedule.to(device) for schedule in (betas, alphas, alphas_cumprod)
)

model_diff, model_flow_standard, model_flow_localized = load_models(ENV, device=device)
@torch.no_grad()
def generate_diffusion_intermediates_streaming(label):
    """Run DDPM reverse sampling for one digit and stream intermediate panels.

    Starts from Gaussian noise and iterates t = timesteps-1 ... 0 using the
    conditional noise-prediction model ``model_diff``. After every step the
    panel list is refreshed by ``plot_diff``. Snapshots are yielded at selected
    timesteps so the Gradio UI can animate progress.

    Args:
        label: Digit class to condition on; converted to a long tensor.

    Yields:
        A 13-element tuple of images (``None`` for panels not yet filled),
        one per output component wired to this function in the UI.
    """
    logging.info("🚀 Starting Diffusion Generation")
    total_start = time.perf_counter()

    x = torch.randn(1, *img_shape).to(device)
    y = torch.tensor([label], dtype=torch.long, device=device)

    outputs = [None] * 13
    # Map the sample from [-1, 1] to [0, 1] for display.
    outputs[0] = resize(((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy())
    yield tuple(outputs)
    time.sleep(0.2)

    # Besides every 20th timestep, also stream these: the first step, the step
    # just before each hundred boundary, and t=1. Built once, outside the loop.
    extra_yield_steps = {499, 399, 299, 199, 99, 1}

    for t in reversed(range(timesteps)):
        step_start = time.perf_counter()
        t_tensor = torch.full((x.size(0),), t, device=device, dtype=torch.float)

        # Forward pass. NOTE: on CUDA kernels run asynchronously, so this
        # interval may mostly measure launch time; the actual wait can be
        # attributed to a later phase that copies data to the host.
        model_start = time.perf_counter()
        noise_pred = model_diff(x, t_tensor, y)
        model_time = time.perf_counter() - model_start

        # Denoising step: mean = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - abar_t) * eps).
        step_compute_start = time.perf_counter()
        x = (1 / alphas[t].sqrt()) * (x - noise_pred * betas[t] / (1 - alphas_cumprod[t]).sqrt())
        if t > 0:
            # Add noise with variance (1 - abar_{t-1}) / (1 - abar_t) * beta_t.
            noise = torch.randn_like(x)
            v = (1 - alphas_cumprod[t - 1]) / (1 - alphas_cumprod[t]) * betas[t]
            x += v.sqrt() * noise
        # NOTE(review): the state is clamped to [-1, 1] after every step, not
        # only at the end; presumably matches how the model was trained — confirm.
        x = x.clamp(-1, 1)
        step_compute_time = time.perf_counter() - step_compute_start

        # Plotting
        plot_start = time.perf_counter()
        outputs = plot_diff(outputs, x, t, noise_pred)
        plot_time = time.perf_counter() - plot_start

        # Logging (every 50th timestep)
        step_time = time.perf_counter() - step_start
        total_time = time.perf_counter() - total_start
        if t % 50 == 0:
            logging.info(f"Diff [t={t:03d}] total={total_time:.3f}s | total_step={step_time:.3f}s | model={model_time:.3f}s | step={step_compute_time:.3f}s | plot={plot_time:.3f}s")

        if t % 20 == 0 or t in extra_yield_steps:
            yield tuple(outputs)
            # Short pause so the UI has time to render the streamed frame.
            time.sleep(0.06)
        if ENV == "LOCAL":
            time.sleep(TIME_SLEEP)

    total_time = time.perf_counter() - total_start
    logging.info(f" Finished diffusion in {total_time:.2f}s")
    yield tuple(outputs)
# NOTE(review): redundant. `logging` is already imported and configured near
# the top of the file, and a second basicConfig() call does nothing once the
# root logger already has handlers. Safe to delete.
import logging
logging.basicConfig(level=logging.INFO)
@torch.no_grad()
def generate_flow_intermediates_streaming(label, noise_type):
    """Integrate a conditional flow-matching model and stream intermediate panels.

    Draws an initial sample from the chosen prior and integrates the predicted
    velocity field from t=0 to t=1 with 50 explicit Euler steps. After every
    step the panel list is refreshed by ``plot_flow``. Every second step is
    yielded so the Gradio UI can animate progress.

    Args:
        label: Digit class to condition on; converted to a long tensor.
        noise_type: ``"Localized"`` selects the centered-Gaussian prior and
            its model; any other value uses standard Gaussian noise.

    Yields:
        A 13-element tuple of images (``None`` for panels not yet filled),
        one per output component wired to this function in the UI.
    """
    logging.info("🚀 Starting Flow Matching Generation")
    total_start = time.perf_counter()

    # Select the prior and the model paired with it.
    if noise_type == "Localized":
        x = generate_centered_gaussian_noise((1, *img_shape)).to(device)
        model_flow = model_flow_localized
    else:
        x = torch.randn(1, *img_shape).to(device)
        model_flow = model_flow_standard
    y = torch.full((1,), label, dtype=torch.long, device=device)

    steps = 50
    dt = 1.0 / steps

    outputs = [None] * 13
    # Map the sample from [-1, 1] to [0, 1] for display.
    outputs[0] = resize(((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy())
    yield tuple(outputs)
    time.sleep(0.2)

    for i in range(steps):
        step_start = time.perf_counter()
        t = torch.full((1,), i * dt, device=device)

        # Forward pass. NOTE: on CUDA kernels run asynchronously, so this
        # interval may mostly measure launch time.
        model_start = time.perf_counter()
        v = model_flow(x, t, y)
        model_time = time.perf_counter() - model_start

        # Explicit Euler step along the predicted velocity.
        flow_step_start = time.perf_counter()
        x = x + v * dt
        flow_step_time = time.perf_counter() - flow_step_start

        # Plotting
        plot_start = time.perf_counter()
        outputs = plot_flow(outputs, i, x, dt, v)
        plot_time = time.perf_counter() - plot_start

        # Logging (every 10th step plus steps 25 and 49)
        step_time = time.perf_counter() - step_start
        total_time = time.perf_counter() - total_start
        if i % 10 == 0 or i in (25, 49):
            logging.info(f"Flow [step={i:02d}] total={total_time:.3f}s | total_step={step_time:.3f}s | model={model_time:.3f}s | step={flow_step_time:.3f}s | plot={plot_time:.3f}s")

        if i % 2 == 0:
            yield tuple(outputs)
            # Short pause so the UI has time to render the streamed frame.
            time.sleep(0.15)
        if ENV == "LOCAL":
            time.sleep(TIME_SLEEP)

    total_time = time.perf_counter() - total_start
    logging.info(f"Finished flow matching in {total_time:.2f}s")
    # Final yield also delivers the last step, which the even-step rule skips.
    yield tuple(outputs)
with gr.Blocks() as demo:
    gr.Markdown("# Conditional MNIST Generation: Diffusion vs Flow Matching")

    # --- Tab 1: DDPM reverse sampling ---------------------------------------
    with gr.Tab("Diffusion"):
        label_d = gr.Slider(0, 9, step=1, label="Digit Label")
        btn_d = gr.Button("Generate")

        # Row 1: the sample at selected timesteps.
        with gr.Row():
            outs_d = [
                gr.Image(label=caption, streaming=True)
                for caption in (
                    "Noise",
                    "Diffusion t=400",
                    "Diffusion t=300",
                    "Diffusion t=200",
                    "Diffusion t=100",
                    "Diffusion t=1",
                )
            ]

        # Row 2: the model's noise prediction at the matching timesteps.
        with gr.Row():
            diff_noise_imgs = [
                gr.Image(label=caption, streaming=True)
                for caption in (
                    "Noise pred t=500",
                    "Noise pred t=400",
                    "Noise pred t=300",
                    "Noise pred t=200",
                    "Noise pred t=100",
                    "Noise pred t=1",
                )
            ]

        # Row 3: the final sample.
        with gr.Row():
            diff_result_imgs = [gr.Image(label="Diffusion t=0", streaming=True)]

        btn_d.click(
            fn=generate_diffusion_intermediates_streaming,
            inputs=label_d,
            outputs=outs_d + diff_noise_imgs + diff_result_imgs,
        )

    # --- Tab 2: flow-matching Euler integration ------------------------------
    with gr.Tab("Flow Matching"):
        with gr.Row():
            noise_selector_f = gr.Radio(
                ["Standard", "Localized"],
                label="Noise Type:",
                value="Standard",  # default prior; switch to "Localized" if preferred
            )
            label_f = gr.Slider(0, 9, step=1, label="Digit Label")
        btn_f = gr.Button("Generate")

        # Row 1: the sample at selected integration steps.
        with gr.Row():
            outs_f = [
                gr.Image(label=caption)
                for caption in (
                    "Noise",
                    "Flow step=10",
                    "Flow step=20",
                    "Flow step=30",
                    "Flow step=40",
                    "Flow step=48",
                )
            ]

        # Row 2: the predicted velocity field at steps 0, 10, 20, 30, 40 and 48.
        with gr.Row():
            flow_vel_imgs = [
                gr.Image(label=caption)
                for caption in (
                    "Velocity step=0",
                    "Velocity step=10",
                    "Velocity step=20",
                    "Velocity step=30",
                    "Velocity step=40",
                    "Velocity step=48",
                )
            ]

        # Row 3: the final sample.
        with gr.Row():
            flow_result_imgs = [gr.Image(label="Flow step=49", streaming=True)]

        btn_f.click(
            fn=generate_flow_intermediates_streaming,
            inputs=[label_f, noise_selector_f],
            outputs=outs_f + flow_vel_imgs + flow_result_imgs,
        )
# Non-deployment runs get a public share link on a fixed port; the hosted
# deployment uses Gradio's default launch settings.
if ENV != "DEPLOY":
    demo.launch(share=True, server_port=9071)
else:
    demo.launch()