Commit
·
3dc5131
1
Parent(s):
5ccd64d
flow support localized noise
Browse files- app.py +37 -97
- src/demo.py +78 -0
- src/utils.py +20 -1
app.py
CHANGED
@@ -1,37 +1,18 @@
|
|
1 |
-
import os
|
2 |
-
import cv2
|
3 |
-
import sys
|
4 |
import torch
|
5 |
-
import numpy as np
|
6 |
-
import gradio as gr
|
7 |
-
import matplotlib.pyplot as plt
|
8 |
-
from src.model import ConditionalUNet
|
9 |
-
from huggingface_hub import hf_hub_download
|
10 |
import time
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
|
|
|
|
|
|
|
|
|
15 |
|
16 |
-
def resize(image,size=(200,200)):
|
17 |
-
stretch_near = cv2.resize(image, size, interpolation = cv2.INTER_LINEAR)
|
18 |
-
return stretch_near
|
19 |
-
|
20 |
|
21 |
-
model_diff = ConditionalUNet().to(device)
|
22 |
-
model_path = hf_hub_download(repo_id="CristianLazoQuispe/MNIST_Diff_Flow_matching", filename="outputs/diffusion/diffusion_model.pth",
|
23 |
-
cache_dir="models")
|
24 |
-
print("Diff Downloaded!")
|
25 |
-
model_diff.load_state_dict(torch.load(model_path, map_location=device))
|
26 |
-
model_diff.eval()
|
27 |
|
|
|
28 |
|
29 |
-
model_flow = ConditionalUNet().to(device)
|
30 |
-
model_path = hf_hub_download(repo_id="CristianLazoQuispe/MNIST_Diff_Flow_matching", filename="outputs/flow_matching/flow_model.pth",
|
31 |
-
cache_dir="models")
|
32 |
-
print("Flow Downloaded!")
|
33 |
-
model_flow.load_state_dict(torch.load(model_path, map_location=device))
|
34 |
-
model_flow.eval()
|
35 |
|
36 |
@torch.no_grad()
|
37 |
def generate_diffusion_intermediates_streaming(label):
|
@@ -39,6 +20,7 @@ def generate_diffusion_intermediates_streaming(label):
|
|
39 |
betas = torch.linspace(1e-4, 0.02, timesteps)
|
40 |
alphas = 1.0 - betas
|
41 |
alphas_cumprod = torch.cumprod(alphas, dim=0).to(device)
|
|
|
42 |
|
43 |
x = torch.randn(1, *img_shape).to(device)
|
44 |
y = torch.tensor([label], dtype=torch.long, device=device)
|
@@ -63,63 +45,33 @@ def generate_diffusion_intermediates_streaming(label):
|
|
63 |
x += v.sqrt() * noise
|
64 |
x = x.clamp(-1, 1)
|
65 |
|
|
|
66 |
|
67 |
-
if t in [499, 399, 299, 199, 99, 0]:
|
68 |
-
step_idx = {499: 6, 399: 7, 299: 8, 199: 9, 99: 10, 0: 11}[t]
|
69 |
-
v_mag = noise_pred[0, 0].abs().clamp(0, 3).cpu().numpy()
|
70 |
-
v_mag = (v_mag - v_mag.min()) / (v_mag.max() - v_mag.min() + 1e-5)
|
71 |
-
vel_colored = plt.get_cmap("coolwarm")(v_mag)[:, :, :3]
|
72 |
-
vel_colored = (vel_colored * 255).astype(np.uint8)
|
73 |
-
outputs[step_idx] = resize(vel_colored)
|
74 |
-
yield tuple(outputs)
|
75 |
-
|
76 |
-
outputs[12] = resize(((x + 1) / 2.0)[0, 0].cpu().numpy(),(300,300))
|
77 |
-
|
78 |
-
if t in [400, 300, 200, 100, 1, 0]:
|
79 |
-
step_idx = {400: 1, 300: 2, 200: 3, 100: 4, 1: 5, 0 :12}[t]
|
80 |
-
if t==0:
|
81 |
-
outputs[step_idx] = resize(((x + 1) / 2.0)[0, 0].cpu().numpy(),(300,300))
|
82 |
-
else:
|
83 |
-
outputs[step_idx] = resize(((x + 1) / 2.0)[0, 0].cpu().numpy())
|
84 |
-
yield tuple(outputs)
|
85 |
if t % 10 == 0:
|
86 |
yield tuple(outputs)
|
87 |
time.sleep(0.06)
|
88 |
-
#time.sleep(0.1)
|
89 |
-
yield tuple(outputs)
|
90 |
|
|
|
|
|
91 |
|
92 |
-
|
93 |
-
"""Genera una imagen con ruido solo en un círculo en el centro."""
|
94 |
-
B, C, H, W = shape
|
95 |
-
assert C == 1, "Solo imágenes en escala de grises."
|
96 |
|
97 |
-
# Crear máscara circular
|
98 |
-
yy, xx = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
|
99 |
-
center_y, center_x = H // 2, W // 2
|
100 |
-
mask = ((yy - center_y)**2 + (xx - center_x)**2) <= radius**2
|
101 |
-
mask = mask.float().unsqueeze(0).unsqueeze(0) # (1, 1, H, W)
|
102 |
|
103 |
-
# Aplicar máscara a ruido
|
104 |
-
noise = torch.randn(B, C, H, W)
|
105 |
-
localized_noise = noise * mask + -1*(1-mask) # solo hay ruido dentro del círculo
|
106 |
-
#mask = ((yy - center_y)**2 + (xx - center_x)**2) >= (radius//2)**2
|
107 |
-
#mask = mask.float().unsqueeze(0).unsqueeze(0) # (1, 1, H, W)
|
108 |
-
#localized_noise = localized_noise * mask + -1*(1-mask) # solo hay ruido dentro del círculo
|
109 |
-
return localized_noise
|
110 |
|
111 |
|
112 |
@torch.no_grad()
|
113 |
-
def generate_flow_intermediates_streaming(label):
|
114 |
-
|
115 |
-
|
|
|
|
|
|
|
|
|
|
|
116 |
y = torch.full((1,), label, dtype=torch.long, device=device)
|
117 |
steps = 50
|
118 |
dt = 1.0 / steps
|
119 |
|
120 |
-
images = [(x + 1) / 2.0] # initial noise
|
121 |
-
vel_magnitudes = []
|
122 |
-
|
123 |
# Inicial
|
124 |
img_np = ((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy()
|
125 |
|
@@ -131,36 +83,16 @@ def generate_flow_intermediates_streaming(label):
|
|
131 |
time.sleep(0.2)
|
132 |
|
133 |
|
134 |
-
for i in range(steps):
|
135 |
-
|
136 |
t = torch.full((1,), i * dt, device=device)
|
137 |
v = model_flow(x, t, y)
|
138 |
x = x + v * dt
|
139 |
-
|
140 |
-
outputs[12] = resize(((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy(),(300,300))
|
141 |
-
if i in [10,20,30,40,48,49]: #
|
142 |
-
step_idx = {10: 1, 20: 2, 30: 3, 40: 4, 48: 5,49:12}[i] #,
|
143 |
-
if i==49:
|
144 |
-
outputs[step_idx] = resize(((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy(),(300,300))
|
145 |
-
else:
|
146 |
-
outputs[step_idx] = resize(((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy())
|
147 |
-
yield tuple(outputs)
|
148 |
-
|
149 |
-
|
150 |
-
# Compute velocity magnitude and convert to numpy for visualization
|
151 |
-
if i in [0,11,21,31,41,49]:
|
152 |
-
v_mag = dt*v[0, 0].abs().clamp(0, 3).cpu().numpy() # Clamp to max value for better contrast
|
153 |
-
v_mag = (v_mag - v_mag.min()) / (v_mag.max() - v_mag.min() + 1e-5)
|
154 |
-
vel_colored = plt.get_cmap("coolwarm")(v_mag)[:, :, :3] # (H,W,3)
|
155 |
-
vel_colored = (vel_colored * 255).astype(np.uint8)
|
156 |
-
step_idx = {0: 6, 11: 7, 21: 8, 31: 9, 41: 10, 49:11}[i]
|
157 |
-
outputs[step_idx] = resize(vel_colored)
|
158 |
-
yield tuple(outputs)
|
159 |
if t % 10 == 0:
|
160 |
yield tuple(outputs)
|
161 |
time.sleep(0.06)
|
162 |
-
|
163 |
-
|
164 |
yield tuple(outputs)
|
165 |
|
166 |
|
@@ -196,7 +128,13 @@ with gr.Blocks() as demo:
|
|
196 |
btn_d.click(fn=generate_diffusion_intermediates_streaming, inputs=label_d, outputs=outs_d+diff_noise_imgs+diff_result_imgs)
|
197 |
|
198 |
with gr.Tab("Flow Matching"):
|
199 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
200 |
btn_f = gr.Button("Generate")
|
201 |
with gr.Row():
|
202 |
outs_f = [
|
@@ -221,8 +159,10 @@ with gr.Blocks() as demo:
|
|
221 |
flow_result_imgs = [
|
222 |
gr.Image(label="Flow step=49",streaming=True),
|
223 |
]
|
224 |
-
btn_f.click(fn=generate_flow_intermediates_streaming, inputs=label_f, outputs=outs_f+flow_vel_imgs+flow_result_imgs)
|
225 |
|
226 |
-
demo.launch()
|
227 |
|
228 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import torch
|
|
|
|
|
|
|
|
|
|
|
2 |
import time
|
3 |
+
import gradio as gr
|
4 |
+
from src.utils import generate_centered_gaussian_noise
|
5 |
+
from src.demo import resize,plot_flow,load_models,plot_diff
|
6 |
|
7 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
8 |
+
img_shape = (1, 28, 28)
|
9 |
+
ENV = "DEPLOY"
|
10 |
+
TIME_SLEEP = 0.05
|
11 |
|
|
|
|
|
|
|
|
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
+
model_diff_standard,model_flow_standard,model_flow_localized = load_models(ENV,device=device)
|
15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
@torch.no_grad()
|
18 |
def generate_diffusion_intermediates_streaming(label):
|
|
|
20 |
betas = torch.linspace(1e-4, 0.02, timesteps)
|
21 |
alphas = 1.0 - betas
|
22 |
alphas_cumprod = torch.cumprod(alphas, dim=0).to(device)
|
23 |
+
model_diff = model_diff_standard
|
24 |
|
25 |
x = torch.randn(1, *img_shape).to(device)
|
26 |
y = torch.tensor([label], dtype=torch.long, device=device)
|
|
|
45 |
x += v.sqrt() * noise
|
46 |
x = x.clamp(-1, 1)
|
47 |
|
48 |
+
outputs = plot_diff(outputs,x,t,noise_pred)
|
49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
if t % 10 == 0:
|
51 |
yield tuple(outputs)
|
52 |
time.sleep(0.06)
|
|
|
|
|
53 |
|
54 |
+
if ENV=="LOCAL":
|
55 |
+
time.sleep(TIME_SLEEP)
|
56 |
|
57 |
+
yield tuple(outputs)
|
|
|
|
|
|
|
58 |
|
|
|
|
|
|
|
|
|
|
|
59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
|
62 |
@torch.no_grad()
|
63 |
+
def generate_flow_intermediates_streaming(label,noise_type):
|
64 |
+
if noise_type=="Localized":
|
65 |
+
x = generate_centered_gaussian_noise((1, *img_shape)).to(device)
|
66 |
+
model_flow = model_flow_localized
|
67 |
+
else:
|
68 |
+
x = torch.randn(1, *img_shape).to(device)
|
69 |
+
model_flow = model_flow_standard
|
70 |
+
|
71 |
y = torch.full((1,), label, dtype=torch.long, device=device)
|
72 |
steps = 50
|
73 |
dt = 1.0 / steps
|
74 |
|
|
|
|
|
|
|
75 |
# Inicial
|
76 |
img_np = ((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy()
|
77 |
|
|
|
83 |
time.sleep(0.2)
|
84 |
|
85 |
|
86 |
+
for i in range(steps):
|
|
|
87 |
t = torch.full((1,), i * dt, device=device)
|
88 |
v = model_flow(x, t, y)
|
89 |
x = x + v * dt
|
90 |
+
outputs = plot_flow(outputs,i,x,dt,v)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
if t % 10 == 0:
|
92 |
yield tuple(outputs)
|
93 |
time.sleep(0.06)
|
94 |
+
if ENV=="LOCAL":
|
95 |
+
time.sleep(TIME_SLEEP)
|
96 |
yield tuple(outputs)
|
97 |
|
98 |
|
|
|
128 |
btn_d.click(fn=generate_diffusion_intermediates_streaming, inputs=label_d, outputs=outs_d+diff_noise_imgs+diff_result_imgs)
|
129 |
|
130 |
with gr.Tab("Flow Matching"):
|
131 |
+
with gr.Row():
|
132 |
+
noise_selector_f = gr.Radio(
|
133 |
+
["Standard", "Localized"],
|
134 |
+
label="Noise Type:",
|
135 |
+
value="Standard" # o "Standard", según quieras el valor por defecto
|
136 |
+
)
|
137 |
+
label_f = gr.Slider(0, 9, step=1, label="Digit Label")
|
138 |
btn_f = gr.Button("Generate")
|
139 |
with gr.Row():
|
140 |
outs_f = [
|
|
|
159 |
flow_result_imgs = [
|
160 |
gr.Image(label="Flow step=49",streaming=True),
|
161 |
]
|
162 |
+
btn_f.click(fn=generate_flow_intermediates_streaming, inputs=[label_f,noise_selector_f], outputs=outs_f+flow_vel_imgs+flow_result_imgs)
|
163 |
|
|
|
164 |
|
165 |
+
if ENV=="DEPLOY":
|
166 |
+
demo.launch()
|
167 |
+
else:
|
168 |
+
demo.launch(share=True, server_port=9071)
|
src/demo.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
from .model import ConditionalUNet
|
6 |
+
from huggingface_hub import hf_hub_download
|
7 |
+
|
8 |
+
|
9 |
+
# Hugging Face Hub repository that hosts every checkpoint used by the demo.
_HF_REPO_ID = "CristianLazoQuispe/MNIST_Diff_Flow_matching"


def _resolve_checkpoint(ENV, filename):
    """Return a path to ``filename``, fetching it from the Hub when ENV is "DEPLOY"."""
    if ENV == "DEPLOY":
        return hf_hub_download(repo_id=_HF_REPO_ID, filename=filename, cache_dir="models")
    # Any other ENV value uses the same relative path on the local filesystem.
    return filename


def _load_unet(checkpoint_path, device):
    """Build a ConditionalUNet on ``device``, load its weights and switch it to eval mode."""
    model = ConditionalUNet().to(device)
    # NOTE(review): torch.load unpickles the checkpoint; consider passing
    # weights_only=True if the minimum supported torch version allows it.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()
    return model


def load_models(ENV, device):
    """Load the three demo networks and return them in eval mode.

    Args:
        ENV: Deployment switch. The exact string "DEPLOY" downloads the
            checkpoints from the Hugging Face Hub into the ``models`` cache
            directory; any other value reads them from relative paths under
            ``outputs/``.
        device: Device the models are moved to; also used as ``map_location``
            when reading the checkpoints.

    Returns:
        A tuple ``(model_diff_standard, model_flow_standard,
        model_flow_localized)`` of ConditionalUNet instances.
    """
    diff_path = _resolve_checkpoint(ENV, "outputs/diffusion/diffusion_model.pth")
    # The message is printed in both modes, i.e. also when nothing was downloaded.
    print("Diff Downloaded!")
    model_diff_standard = _load_unet(diff_path, device)

    # Resolve both flow checkpoints before loading either one, so a failed
    # download surfaces before any flow model is constructed.
    flow_standard_path = _resolve_checkpoint(ENV, "outputs/flow_matching/flow_model.pth")
    flow_localized_path = _resolve_checkpoint(
        ENV, "outputs/flow_matching/flow_model_localized_noise.pth"
    )
    print("Flow Downloaded!")
    model_flow_standard = _load_unet(flow_standard_path, device)
    model_flow_localized = _load_unet(flow_localized_path, device)

    return model_diff_standard, model_flow_standard, model_flow_localized
|
34 |
+
def resize(image, size=(200, 200)):
    """Rescale ``image`` to ``size`` with bilinear interpolation.

    Args:
        image: Array accepted by ``cv2.resize``.
        size: Target size, forwarded unchanged as the ``dsize`` argument of
            ``cv2.resize``. Defaults to ``(200, 200)``.

    Returns:
        The resized array produced by ``cv2.resize``.
    """
    return cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)
|
37 |
+
|
38 |
+
|
39 |
+
|
40 |
+
# Denoising step -> slot in ``outputs`` that shows the noise-prediction heat map.
_DIFF_NOISE_SLOTS = {499: 6, 399: 7, 299: 8, 199: 9, 99: 10, 0: 11}
# Denoising step -> slot in ``outputs`` that shows the intermediate sample.
# Step 0 needs no entry: its sample goes to slot 12, which is refreshed on every call.
_DIFF_SAMPLE_SLOTS = {400: 1, 300: 2, 200: 3, 100: 4, 1: 5}


def plot_diff(outputs, x, t, noise_pred):
    """Update the diffusion demo panels for denoising step ``t``.

    Slot 12 of ``outputs`` always receives the current sample at 300x300.
    At the steps listed in ``_DIFF_SAMPLE_SLOTS`` the sample is also stored in
    its panel slot at the default ``resize`` size, and at the steps listed in
    ``_DIFF_NOISE_SLOTS`` a "coolwarm" heat map of ``|noise_pred|`` is stored.

    Args:
        outputs: Mutable sequence of panel images; must be indexable up to 12.
            It is modified in place.
        x: Current sample; only ``x[0, 0]`` is displayed, after mapping
            ``(x + 1) / 2``.  # assumes values lie in [-1, 1] - confirm with caller
        t: Current denoising step, compared against the integer keys of the
            slot tables.
        noise_pred: Model noise prediction; only ``noise_pred[0, 0]`` is shown.

    Returns:
        The same ``outputs`` object that was passed in.
    """
    noise_slot = _DIFF_NOISE_SLOTS.get(t)
    if noise_slot is not None:
        magnitude = noise_pred[0, 0].abs().clamp(0, 3).cpu().numpy()
        lowest, highest = magnitude.min(), magnitude.max()
        # Min-max normalise; the epsilon keeps a constant map from dividing by zero.
        magnitude = (magnitude - lowest) / (highest - lowest + 1e-5)
        rgb = plt.get_cmap("coolwarm")(magnitude)[:, :, :3]  # drop the alpha channel
        outputs[noise_slot] = resize((rgb * 255).astype(np.uint8))

    # Convert the sample once and reuse it for every panel that displays it.
    frame = ((x + 1) / 2.0)[0, 0].cpu().numpy()
    outputs[12] = resize(frame, (300, 300))

    sample_slot = _DIFF_SAMPLE_SLOTS.get(t)
    if sample_slot is not None:
        outputs[sample_slot] = resize(frame)

    return outputs
|
60 |
+
|
61 |
+
# Integration step -> slot in ``outputs`` that shows the intermediate sample.
# Step 49 needs no entry: its sample goes to slot 12, which is refreshed on every call.
_FLOW_SAMPLE_SLOTS = {10: 1, 20: 2, 30: 3, 40: 4, 48: 5}
# Integration step -> slot in ``outputs`` that shows the velocity heat map.
_FLOW_VELOCITY_SLOTS = {0: 6, 11: 7, 21: 8, 31: 9, 41: 10, 49: 11}


def plot_flow(outputs, i, x, dt, v):
    """Update the flow-matching demo panels for integration step ``i``.

    Slot 12 of ``outputs`` always receives the current sample at 300x300.
    At the steps listed in ``_FLOW_SAMPLE_SLOTS`` the sample is also stored in
    its panel slot at the default ``resize`` size, and at the steps listed in
    ``_FLOW_VELOCITY_SLOTS`` a "coolwarm" heat map of ``dt * |v|`` is stored.

    Args:
        outputs: Mutable sequence of panel images; must be indexable up to 12.
            It is modified in place.
        i: Current integration step, compared against the integer keys of the
            slot tables.
        x: Current sample; only ``x[0, 0]`` is displayed, after mapping
            ``(x + 1) / 2`` and clamping to [0, 1].
        dt: Step size used to scale the velocity magnitude before normalising.
        v: Predicted velocity; only ``v[0, 0]`` is shown.

    Returns:
        The same ``outputs`` object that was passed in.
    """
    # Convert the sample once and reuse it for every panel that displays it.
    frame = ((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy()
    outputs[12] = resize(frame, (300, 300))

    sample_slot = _FLOW_SAMPLE_SLOTS.get(i)
    if sample_slot is not None:
        outputs[sample_slot] = resize(frame)

    velocity_slot = _FLOW_VELOCITY_SLOTS.get(i)
    if velocity_slot is not None:
        # Clamp before scaling so a few extreme values do not flatten the contrast.
        magnitude = dt * v[0, 0].abs().clamp(0, 3).cpu().numpy()
        lowest, highest = magnitude.min(), magnitude.max()
        # Min-max normalise; the epsilon keeps a constant map from dividing by zero.
        magnitude = (magnitude - lowest) / (highest - lowest + 1e-5)
        rgb = plt.get_cmap("coolwarm")(magnitude)[:, :, :3]  # (H, W, 3), alpha dropped
        outputs[velocity_slot] = resize((rgb * 255).astype(np.uint8))

    return outputs
|
src/utils.py
CHANGED
@@ -11,4 +11,23 @@ def set_seed(seed):
|
|
11 |
torch.manual_seed(seed)
|
12 |
torch.cuda.manual_seed(seed)
|
13 |
torch.cuda.manual_seed_all(seed)
|
14 |
-
torch.backends.cudnn.deterministic = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
torch.manual_seed(seed)
|
12 |
torch.cuda.manual_seed(seed)
|
13 |
torch.cuda.manual_seed_all(seed)
|
14 |
+
torch.backends.cudnn.deterministic = True
|
15 |
+
|
16 |
+
def generate_centered_gaussian_noise(shape=(1, 1, 28, 28), sigma=5.0, mu=0):
    """Sample Gaussian noise whose amplitude fades away from the image centre.

    Standard normal noise shifted by ``mu`` is blended with the constant
    ``mu`` using a 2-D Gaussian envelope, which algebraically equals
    ``mu + randn * envelope``. The envelope is centred at ``(H / 2, W / 2)``
    and rescaled so that its largest value is 1.

    Args:
        shape: ``(B, C, H, W)`` of the tensor to generate; ``C`` must be 1.
        sigma: Standard deviation of the envelope, in pixels.
        mu: Value the output approaches where the envelope is near zero.

    Returns:
        A tensor of shape ``(B, C, H, W)`` created on the default device.

    Raises:
        AssertionError: If ``C`` is not 1.
    """
    batch, channels, height, width = shape
    assert channels == 1, "only image gray"

    rows = torch.arange(height, dtype=torch.float32)
    cols = torch.arange(width, dtype=torch.float32)
    grid_y, grid_x = torch.meshgrid(rows, cols, indexing='ij')

    # NOTE(review): for even sizes H / 2 sits half a pixel past the midpoint
    # of the index range 0..H-1; kept as-is on purpose.
    squared_distance = (grid_y - height / 2) ** 2 + (grid_x - width / 2) ** 2
    envelope = torch.exp(-squared_distance / (2 * sigma**2))
    envelope = envelope / envelope.max()  # peak rescaled to exactly 1
    envelope = envelope[None, None].expand(batch, channels, height, width)

    shifted_noise = mu + torch.randn(batch, channels, height, width)
    return shifted_noise * envelope + mu * (1 - envelope)
|
33 |
+
|