Commit
路
d36e58a
1
Parent(s):
29ff960
streaming outputs
Browse files
app.py
CHANGED
@@ -7,8 +7,10 @@ import gradio as gr
|
|
7 |
import matplotlib.pyplot as plt
|
8 |
from src.model import ConditionalUNet
|
9 |
from huggingface_hub import hf_hub_download
|
10 |
-
|
11 |
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
|
|
|
|
12 |
img_shape = (1, 28, 28)
|
13 |
|
14 |
|
@@ -33,17 +35,24 @@ model_flow.load_state_dict(torch.load(model_path, map_location=device))
|
|
33 |
model_flow.eval()
|
34 |
|
35 |
@torch.no_grad()
|
36 |
-
def
|
37 |
timesteps = 500
|
38 |
-
img_shape = (1, 28, 28)
|
39 |
betas = torch.linspace(1e-4, 0.02, timesteps)
|
40 |
alphas = 1.0 - betas
|
41 |
alphas_cumprod = torch.cumprod(alphas, dim=0).to(device)
|
42 |
|
43 |
x = torch.randn(1, *img_shape).to(device)
|
44 |
y = torch.tensor([label], dtype=torch.long, device=device)
|
45 |
-
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
for t in reversed(range(timesteps)):
|
49 |
t_tensor = torch.full((x.size(0),), t, device=device, dtype=torch.float)
|
@@ -53,22 +62,26 @@ def generate_diffusion_intermediates(label):
|
|
53 |
noise = torch.randn(1, *img_shape).to(device)
|
54 |
v = (1 - alphas_cumprod[t - 1]) / (1 - alphas_cumprod[t]) * betas[t]
|
55 |
x += v.sqrt() * noise
|
56 |
-
|
57 |
x = x.clamp(-1, 1)
|
58 |
-
if t in [400, 300, 200, 100,0]:
|
59 |
-
#print("t:",t)
|
60 |
-
img_np = ((x + 1) / 2)[0, 0].cpu().numpy()
|
61 |
-
intermediates.append(resize(img_np))
|
62 |
|
63 |
-
|
64 |
-
|
65 |
-
|
|
|
66 |
v_mag = (v_mag - v_mag.min()) / (v_mag.max() - v_mag.min() + 1e-5)
|
67 |
-
vel_colored = plt.get_cmap("coolwarm")(v_mag)[:, :, :3]
|
68 |
vel_colored = (vel_colored * 255).astype(np.uint8)
|
69 |
-
|
|
|
|
|
70 |
|
71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
|
73 |
|
74 |
def generate_localized_noise(shape, radius=5):
|
@@ -89,7 +102,7 @@ def generate_localized_noise(shape, radius=5):
|
|
89 |
|
90 |
|
91 |
@torch.no_grad()
|
92 |
-
def
|
93 |
x = torch.randn(1, *img_shape).to(device)
|
94 |
#x = generate_localized_noise((1, 1, 28, 28), radius=12).to(device)
|
95 |
y = torch.full((1,), label, dtype=torch.long, device=device)
|
@@ -98,23 +111,43 @@ def generate_flow_intermediates(label):
|
|
98 |
|
99 |
images = [(x + 1) / 2.0] # initial noise
|
100 |
vel_magnitudes = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
for i in range(steps):
|
102 |
|
103 |
t = torch.full((1,), i * dt, device=device)
|
104 |
v = model_flow(x, t, y)
|
105 |
x = x + v * dt
|
106 |
|
107 |
-
if i in [10,20,30,40,49]:
|
108 |
-
images.append((x + 1) / 2.0)
|
|
|
|
|
|
|
|
|
109 |
# Compute velocity magnitude and convert to numpy for visualization
|
110 |
-
if i in [0,
|
111 |
v_mag = dt*v[0, 0].abs().clamp(0, 3).cpu().numpy() # Clamp to max value for better contrast
|
112 |
v_mag = (v_mag - v_mag.min()) / (v_mag.max() - v_mag.min() + 1e-5)
|
113 |
vel_colored = plt.get_cmap("coolwarm")(v_mag)[:, :, :3] # (H,W,3)
|
114 |
vel_colored = (vel_colored * 255).astype(np.uint8)
|
115 |
-
|
|
|
|
|
|
|
|
|
|
|
116 |
|
117 |
-
return [resize(images[0][0][0].clamp(0, 1).cpu().numpy())]+[resize(img[0][0].clamp(0, 1).cpu().numpy()) for img in images[-5:]]+vel_magnitudes
|
118 |
|
119 |
with gr.Blocks() as demo:
|
120 |
gr.Markdown("# Conditional MNIST Generation: Diffusion vs Flow Matching")
|
@@ -124,24 +157,28 @@ with gr.Blocks() as demo:
|
|
124 |
btn_d = gr.Button("Generate")
|
125 |
with gr.Row():
|
126 |
outs_d = [
|
127 |
-
gr.Image(label="Noise"),
|
128 |
-
gr.Image(label="Diffusion t=400"),
|
129 |
-
gr.Image(label="Diffusion t=300"),
|
130 |
-
gr.Image(label="Diffusion t=200"),
|
131 |
-
gr.Image(label="Diffusion t=100"),
|
132 |
-
gr.Image(label="Diffusion t=
|
133 |
]
|
134 |
with gr.Row():
|
135 |
#400, 300, 200, 100,0
|
136 |
-
|
137 |
-
gr.Image(label="Noise pred t=500"),
|
138 |
-
gr.Image(label="Noise pred t=400"),
|
139 |
-
gr.Image(label="Noise pred t=300"),
|
140 |
-
gr.Image(label="Noise pred t=200"),
|
141 |
-
gr.Image(label="Noise pred t=100"),
|
142 |
-
gr.Image(label="Noise pred t=
|
143 |
]
|
144 |
-
|
|
|
|
|
|
|
|
|
145 |
|
146 |
with gr.Tab("Flow Matching"):
|
147 |
label_f = gr.Slider(0, 9, step=1, label="Digit Label")
|
@@ -153,7 +190,7 @@ with gr.Blocks() as demo:
|
|
153 |
gr.Image(label="Flow step=20"),
|
154 |
gr.Image(label="Flow step=30"),
|
155 |
gr.Image(label="Flow step=40"),
|
156 |
-
gr.Image(label="Flow step=
|
157 |
]
|
158 |
with gr.Row():
|
159 |
#100,200,300,400,499
|
@@ -163,10 +200,13 @@ with gr.Blocks() as demo:
|
|
163 |
gr.Image(label="Velocity step=20"),
|
164 |
gr.Image(label="Velocity step=30"),
|
165 |
gr.Image(label="Velocity step=40"),
|
166 |
-
gr.Image(label="Velocity step=
|
167 |
]
|
|
|
|
|
|
|
|
|
|
|
168 |
|
169 |
-
|
170 |
-
|
171 |
-
demo.launch()
|
172 |
-
#demo.launch(share=False, server_port=9070)
|
|
|
7 |
import matplotlib.pyplot as plt
|
8 |
from src.model import ConditionalUNet
|
9 |
from huggingface_hub import hf_hub_download
|
10 |
+
import time
|
11 |
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
12 |
+
|
13 |
+
device = 'cpu'
|
14 |
img_shape = (1, 28, 28)
|
15 |
|
16 |
|
|
|
35 |
model_flow.eval()
|
36 |
|
37 |
@torch.no_grad()
|
38 |
+
def generate_diffusion_intermediates_streaming(label):
|
39 |
timesteps = 500
|
|
|
40 |
betas = torch.linspace(1e-4, 0.02, timesteps)
|
41 |
alphas = 1.0 - betas
|
42 |
alphas_cumprod = torch.cumprod(alphas, dim=0).to(device)
|
43 |
|
44 |
x = torch.randn(1, *img_shape).to(device)
|
45 |
y = torch.tensor([label], dtype=torch.long, device=device)
|
46 |
+
|
47 |
+
# Inicial
|
48 |
+
img_np = ((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy()
|
49 |
+
|
50 |
+
# Para mantener la posici贸n de cada imagen
|
51 |
+
outputs = [None] * 13
|
52 |
+
yield tuple(outputs)
|
53 |
+
outputs[0] = resize(img_np)
|
54 |
+
yield tuple(outputs)
|
55 |
+
#time.sleep(0.5)
|
56 |
|
57 |
for t in reversed(range(timesteps)):
|
58 |
t_tensor = torch.full((x.size(0),), t, device=device, dtype=torch.float)
|
|
|
62 |
noise = torch.randn(1, *img_shape).to(device)
|
63 |
v = (1 - alphas_cumprod[t - 1]) / (1 - alphas_cumprod[t]) * betas[t]
|
64 |
x += v.sqrt() * noise
|
|
|
65 |
x = x.clamp(-1, 1)
|
|
|
|
|
|
|
|
|
66 |
|
67 |
+
|
68 |
+
if t in [499, 399, 299, 199, 99, 0]:
|
69 |
+
step_idx = {499: 6, 399: 7, 299: 8, 199: 9, 99: 10, 0: 11}[t]
|
70 |
+
v_mag = noise_pred[0, 0].abs().clamp(0, 3).cpu().numpy()
|
71 |
v_mag = (v_mag - v_mag.min()) / (v_mag.max() - v_mag.min() + 1e-5)
|
72 |
+
vel_colored = plt.get_cmap("coolwarm")(v_mag)[:, :, :3]
|
73 |
vel_colored = (vel_colored * 255).astype(np.uint8)
|
74 |
+
outputs[step_idx] = resize(vel_colored)
|
75 |
+
yield tuple(outputs)
|
76 |
+
time.sleep(0.5)
|
77 |
|
78 |
+
if t in [400, 300, 200, 100, 1, 0]:
|
79 |
+
step_idx = {400: 1, 300: 2, 200: 3, 100: 4, 1: 5, 0 :12}[t]
|
80 |
+
if t==0:
|
81 |
+
outputs[step_idx] = resize(((x + 1) / 2.0)[0, 0].cpu().numpy(),(400,400))
|
82 |
+
else:
|
83 |
+
outputs[step_idx] = resize(((x + 1) / 2.0)[0, 0].cpu().numpy())
|
84 |
+
yield tuple(outputs)
|
85 |
|
86 |
|
87 |
def generate_localized_noise(shape, radius=5):
|
|
|
102 |
|
103 |
|
104 |
@torch.no_grad()
|
105 |
+
def generate_flow_intermediates_streaming(label):
|
106 |
x = torch.randn(1, *img_shape).to(device)
|
107 |
#x = generate_localized_noise((1, 1, 28, 28), radius=12).to(device)
|
108 |
y = torch.full((1,), label, dtype=torch.long, device=device)
|
|
|
111 |
|
112 |
images = [(x + 1) / 2.0] # initial noise
|
113 |
vel_magnitudes = []
|
114 |
+
|
115 |
+
# Inicial
|
116 |
+
img_np = ((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy()
|
117 |
+
|
118 |
+
# Para mantener la posici贸n de cada imagen
|
119 |
+
outputs = [None] * 13
|
120 |
+
yield tuple(outputs)
|
121 |
+
outputs[0] = resize(img_np)
|
122 |
+
yield tuple(outputs)
|
123 |
+
#time.sleep(0.5)
|
124 |
+
|
125 |
+
|
126 |
for i in range(steps):
|
127 |
|
128 |
t = torch.full((1,), i * dt, device=device)
|
129 |
v = model_flow(x, t, y)
|
130 |
x = x + v * dt
|
131 |
|
132 |
+
if i in [10,20,30,40,48,49]:
|
133 |
+
#images.append((x + 1) / 2.0)
|
134 |
+
step_idx = {10: 1, 20: 2, 30: 3, 40: 4, 48: 5, 49:12}[i]
|
135 |
+
outputs[step_idx] = resize(((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy())
|
136 |
+
yield tuple(outputs)
|
137 |
+
|
138 |
# Compute velocity magnitude and convert to numpy for visualization
|
139 |
+
if i in [0,11,21,31,41,49]:
|
140 |
v_mag = dt*v[0, 0].abs().clamp(0, 3).cpu().numpy() # Clamp to max value for better contrast
|
141 |
v_mag = (v_mag - v_mag.min()) / (v_mag.max() - v_mag.min() + 1e-5)
|
142 |
vel_colored = plt.get_cmap("coolwarm")(v_mag)[:, :, :3] # (H,W,3)
|
143 |
vel_colored = (vel_colored * 255).astype(np.uint8)
|
144 |
+
step_idx = {0: 6, 11: 7, 21: 8, 31: 9, 41: 10, 49:11}[i]
|
145 |
+
if i==49:
|
146 |
+
outputs[step_idx] = resize(vel_colored, (400, 400))
|
147 |
+
else:
|
148 |
+
outputs[step_idx] = resize(vel_colored)
|
149 |
+
yield tuple(outputs)
|
150 |
|
|
|
151 |
|
152 |
with gr.Blocks() as demo:
|
153 |
gr.Markdown("# Conditional MNIST Generation: Diffusion vs Flow Matching")
|
|
|
157 |
btn_d = gr.Button("Generate")
|
158 |
with gr.Row():
|
159 |
outs_d = [
|
160 |
+
gr.Image(label="Noise",streaming=True),
|
161 |
+
gr.Image(label="Diffusion t=400",streaming=True),
|
162 |
+
gr.Image(label="Diffusion t=300",streaming=True),
|
163 |
+
gr.Image(label="Diffusion t=200",streaming=True),
|
164 |
+
gr.Image(label="Diffusion t=100",streaming=True),
|
165 |
+
gr.Image(label="Diffusion t=1",streaming=True),
|
166 |
]
|
167 |
with gr.Row():
|
168 |
#400, 300, 200, 100,0
|
169 |
+
diff_noise_imgs = [
|
170 |
+
gr.Image(label="Noise pred t=500",streaming=True),
|
171 |
+
gr.Image(label="Noise pred t=400",streaming=True),
|
172 |
+
gr.Image(label="Noise pred t=300",streaming=True),
|
173 |
+
gr.Image(label="Noise pred t=200",streaming=True),
|
174 |
+
gr.Image(label="Noise pred t=100",streaming=True),
|
175 |
+
gr.Image(label="Noise pred t=1",streaming=True),
|
176 |
]
|
177 |
+
with gr.Row():
|
178 |
+
diff_result_imgs = [
|
179 |
+
gr.Image(label="Diffusion t=0",streaming=True),
|
180 |
+
]
|
181 |
+
btn_d.click(fn=generate_diffusion_intermediates_streaming, inputs=label_d, outputs=outs_d+diff_noise_imgs+diff_result_imgs)
|
182 |
|
183 |
with gr.Tab("Flow Matching"):
|
184 |
label_f = gr.Slider(0, 9, step=1, label="Digit Label")
|
|
|
190 |
gr.Image(label="Flow step=20"),
|
191 |
gr.Image(label="Flow step=30"),
|
192 |
gr.Image(label="Flow step=40"),
|
193 |
+
gr.Image(label="Flow step=48"),
|
194 |
]
|
195 |
with gr.Row():
|
196 |
#100,200,300,400,499
|
|
|
200 |
gr.Image(label="Velocity step=20"),
|
201 |
gr.Image(label="Velocity step=30"),
|
202 |
gr.Image(label="Velocity step=40"),
|
203 |
+
gr.Image(label="Velocity step=48")
|
204 |
]
|
205 |
+
with gr.Row():
|
206 |
+
flow_result_imgs = [
|
207 |
+
gr.Image(label="Flow step=49",streaming=True),
|
208 |
+
]
|
209 |
+
btn_f.click(fn=generate_flow_intermediates_streaming, inputs=label_f, outputs=outs_f+flow_vel_imgs+flow_result_imgs)
|
210 |
|
211 |
+
#demo.launch()
|
212 |
+
demo.launch(share=False, server_port=9071)
|
|
|
|