CristianLazoQuispe committed on
Commit
3dc5131
·
1 Parent(s): 5ccd64d

flow support localized noise

Browse files
Files changed (3) hide show
  1. app.py +37 -97
  2. src/demo.py +78 -0
  3. src/utils.py +20 -1
app.py CHANGED
@@ -1,37 +1,18 @@
1
- import os
2
- import cv2
3
- import sys
4
  import torch
5
- import numpy as np
6
- import gradio as gr
7
- import matplotlib.pyplot as plt
8
- from src.model import ConditionalUNet
9
- from huggingface_hub import hf_hub_download
10
  import time
11
- device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
12
- #device = 'cpu'
13
- img_shape = (1, 28, 28)
14
 
 
 
 
 
15
 
16
- def resize(image,size=(200,200)):
17
- stretch_near = cv2.resize(image, size, interpolation = cv2.INTER_LINEAR)
18
- return stretch_near
19
-
20
 
21
- model_diff = ConditionalUNet().to(device)
22
- model_path = hf_hub_download(repo_id="CristianLazoQuispe/MNIST_Diff_Flow_matching", filename="outputs/diffusion/diffusion_model.pth",
23
- cache_dir="models")
24
- print("Diff Downloaded!")
25
- model_diff.load_state_dict(torch.load(model_path, map_location=device))
26
- model_diff.eval()
27
 
 
28
 
29
- model_flow = ConditionalUNet().to(device)
30
- model_path = hf_hub_download(repo_id="CristianLazoQuispe/MNIST_Diff_Flow_matching", filename="outputs/flow_matching/flow_model.pth",
31
- cache_dir="models")
32
- print("Flow Downloaded!")
33
- model_flow.load_state_dict(torch.load(model_path, map_location=device))
34
- model_flow.eval()
35
 
36
  @torch.no_grad()
37
  def generate_diffusion_intermediates_streaming(label):
@@ -39,6 +20,7 @@ def generate_diffusion_intermediates_streaming(label):
39
  betas = torch.linspace(1e-4, 0.02, timesteps)
40
  alphas = 1.0 - betas
41
  alphas_cumprod = torch.cumprod(alphas, dim=0).to(device)
 
42
 
43
  x = torch.randn(1, *img_shape).to(device)
44
  y = torch.tensor([label], dtype=torch.long, device=device)
@@ -63,63 +45,33 @@ def generate_diffusion_intermediates_streaming(label):
63
  x += v.sqrt() * noise
64
  x = x.clamp(-1, 1)
65
 
 
66
 
67
- if t in [499, 399, 299, 199, 99, 0]:
68
- step_idx = {499: 6, 399: 7, 299: 8, 199: 9, 99: 10, 0: 11}[t]
69
- v_mag = noise_pred[0, 0].abs().clamp(0, 3).cpu().numpy()
70
- v_mag = (v_mag - v_mag.min()) / (v_mag.max() - v_mag.min() + 1e-5)
71
- vel_colored = plt.get_cmap("coolwarm")(v_mag)[:, :, :3]
72
- vel_colored = (vel_colored * 255).astype(np.uint8)
73
- outputs[step_idx] = resize(vel_colored)
74
- yield tuple(outputs)
75
-
76
- outputs[12] = resize(((x + 1) / 2.0)[0, 0].cpu().numpy(),(300,300))
77
-
78
- if t in [400, 300, 200, 100, 1, 0]:
79
- step_idx = {400: 1, 300: 2, 200: 3, 100: 4, 1: 5, 0 :12}[t]
80
- if t==0:
81
- outputs[step_idx] = resize(((x + 1) / 2.0)[0, 0].cpu().numpy(),(300,300))
82
- else:
83
- outputs[step_idx] = resize(((x + 1) / 2.0)[0, 0].cpu().numpy())
84
- yield tuple(outputs)
85
  if t % 10 == 0:
86
  yield tuple(outputs)
87
  time.sleep(0.06)
88
- #time.sleep(0.1)
89
- yield tuple(outputs)
90
 
 
 
91
 
92
- def generate_localized_noise(shape, radius=5):
93
- """Genera una imagen con ruido solo en un círculo en el centro."""
94
- B, C, H, W = shape
95
- assert C == 1, "Solo imágenes en escala de grises."
96
 
97
- # Crear máscara circular
98
- yy, xx = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
99
- center_y, center_x = H // 2, W // 2
100
- mask = ((yy - center_y)**2 + (xx - center_x)**2) <= radius**2
101
- mask = mask.float().unsqueeze(0).unsqueeze(0) # (1, 1, H, W)
102
 
103
- # Aplicar máscara a ruido
104
- noise = torch.randn(B, C, H, W)
105
- localized_noise = noise * mask + -1*(1-mask) # solo hay ruido dentro del círculo
106
- #mask = ((yy - center_y)**2 + (xx - center_x)**2) >= (radius//2)**2
107
- #mask = mask.float().unsqueeze(0).unsqueeze(0) # (1, 1, H, W)
108
- #localized_noise = localized_noise * mask + -1*(1-mask) # solo hay ruido dentro del círculo
109
- return localized_noise
110
 
111
 
112
  @torch.no_grad()
113
- def generate_flow_intermediates_streaming(label):
114
- x = torch.randn(1, *img_shape).to(device)
115
- #x = generate_localized_noise((1, 1, 28, 28), radius=12).to(device)
 
 
 
 
 
116
  y = torch.full((1,), label, dtype=torch.long, device=device)
117
  steps = 50
118
  dt = 1.0 / steps
119
 
120
- images = [(x + 1) / 2.0] # initial noise
121
- vel_magnitudes = []
122
-
123
  # Inicial
124
  img_np = ((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy()
125
 
@@ -131,36 +83,16 @@ def generate_flow_intermediates_streaming(label):
131
  time.sleep(0.2)
132
 
133
 
134
- for i in range(steps):
135
-
136
  t = torch.full((1,), i * dt, device=device)
137
  v = model_flow(x, t, y)
138
  x = x + v * dt
139
-
140
- outputs[12] = resize(((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy(),(300,300))
141
- if i in [10,20,30,40,48,49]: #
142
- step_idx = {10: 1, 20: 2, 30: 3, 40: 4, 48: 5,49:12}[i] #,
143
- if i==49:
144
- outputs[step_idx] = resize(((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy(),(300,300))
145
- else:
146
- outputs[step_idx] = resize(((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy())
147
- yield tuple(outputs)
148
-
149
-
150
- # Compute velocity magnitude and convert to numpy for visualization
151
- if i in [0,11,21,31,41,49]:
152
- v_mag = dt*v[0, 0].abs().clamp(0, 3).cpu().numpy() # Clamp to max value for better contrast
153
- v_mag = (v_mag - v_mag.min()) / (v_mag.max() - v_mag.min() + 1e-5)
154
- vel_colored = plt.get_cmap("coolwarm")(v_mag)[:, :, :3] # (H,W,3)
155
- vel_colored = (vel_colored * 255).astype(np.uint8)
156
- step_idx = {0: 6, 11: 7, 21: 8, 31: 9, 41: 10, 49:11}[i]
157
- outputs[step_idx] = resize(vel_colored)
158
- yield tuple(outputs)
159
  if t % 10 == 0:
160
  yield tuple(outputs)
161
  time.sleep(0.06)
162
-
163
- #time.sleep(0.1)
164
  yield tuple(outputs)
165
 
166
 
@@ -196,7 +128,13 @@ with gr.Blocks() as demo:
196
  btn_d.click(fn=generate_diffusion_intermediates_streaming, inputs=label_d, outputs=outs_d+diff_noise_imgs+diff_result_imgs)
197
 
198
  with gr.Tab("Flow Matching"):
199
- label_f = gr.Slider(0, 9, step=1, label="Digit Label")
 
 
 
 
 
 
200
  btn_f = gr.Button("Generate")
201
  with gr.Row():
202
  outs_f = [
@@ -221,8 +159,10 @@ with gr.Blocks() as demo:
221
  flow_result_imgs = [
222
  gr.Image(label="Flow step=49",streaming=True),
223
  ]
224
- btn_f.click(fn=generate_flow_intermediates_streaming, inputs=label_f, outputs=outs_f+flow_vel_imgs+flow_result_imgs)
225
 
226
- demo.launch()
227
 
228
- #demo.launch(share=False, server_port=9071)
 
 
 
 
 
 
 
1
  import torch
 
 
 
 
 
2
  import time
3
+ import gradio as gr
4
+ from src.utils import generate_centered_gaussian_noise
5
+ from src.demo import resize,plot_flow,load_models,plot_diff
6
 
7
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
8
+ img_shape = (1, 28, 28)
9
+ ENV = "DEPLOY"
10
+ TIME_SLEEP = 0.05
11
 
 
 
 
 
12
 
 
 
 
 
 
 
13
 
14
+ model_diff_standard,model_flow_standard,model_flow_localized = load_models(ENV,device=device)
15
 
 
 
 
 
 
 
16
 
17
  @torch.no_grad()
18
  def generate_diffusion_intermediates_streaming(label):
 
20
  betas = torch.linspace(1e-4, 0.02, timesteps)
21
  alphas = 1.0 - betas
22
  alphas_cumprod = torch.cumprod(alphas, dim=0).to(device)
23
+ model_diff = model_diff_standard
24
 
25
  x = torch.randn(1, *img_shape).to(device)
26
  y = torch.tensor([label], dtype=torch.long, device=device)
 
45
  x += v.sqrt() * noise
46
  x = x.clamp(-1, 1)
47
 
48
+ outputs = plot_diff(outputs,x,t,noise_pred)
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  if t % 10 == 0:
51
  yield tuple(outputs)
52
  time.sleep(0.06)
 
 
53
 
54
+ if ENV=="LOCAL":
55
+ time.sleep(TIME_SLEEP)
56
 
57
+ yield tuple(outputs)
 
 
 
58
 
 
 
 
 
 
59
 
 
 
 
 
 
 
 
60
 
61
 
62
  @torch.no_grad()
63
+ def generate_flow_intermediates_streaming(label,noise_type):
64
+ if noise_type=="Localized":
65
+ x = generate_centered_gaussian_noise((1, *img_shape)).to(device)
66
+ model_flow = model_flow_localized
67
+ else:
68
+ x = torch.randn(1, *img_shape).to(device)
69
+ model_flow = model_flow_standard
70
+
71
  y = torch.full((1,), label, dtype=torch.long, device=device)
72
  steps = 50
73
  dt = 1.0 / steps
74
 
 
 
 
75
  # Inicial
76
  img_np = ((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy()
77
 
 
83
  time.sleep(0.2)
84
 
85
 
86
+ for i in range(steps):
 
87
  t = torch.full((1,), i * dt, device=device)
88
  v = model_flow(x, t, y)
89
  x = x + v * dt
90
+ outputs = plot_flow(outputs,i,x,dt,v)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  if t % 10 == 0:
92
  yield tuple(outputs)
93
  time.sleep(0.06)
94
+ if ENV=="LOCAL":
95
+ time.sleep(TIME_SLEEP)
96
  yield tuple(outputs)
97
 
98
 
 
128
  btn_d.click(fn=generate_diffusion_intermediates_streaming, inputs=label_d, outputs=outs_d+diff_noise_imgs+diff_result_imgs)
129
 
130
  with gr.Tab("Flow Matching"):
131
+ with gr.Row():
132
+ noise_selector_f = gr.Radio(
133
+ ["Standard", "Localized"],
134
+ label="Noise Type:",
135
+ value="Standard" # o "Standard", según quieras el valor por defecto
136
+ )
137
+ label_f = gr.Slider(0, 9, step=1, label="Digit Label")
138
  btn_f = gr.Button("Generate")
139
  with gr.Row():
140
  outs_f = [
 
159
  flow_result_imgs = [
160
  gr.Image(label="Flow step=49",streaming=True),
161
  ]
162
+ btn_f.click(fn=generate_flow_intermediates_streaming, inputs=[label_f,noise_selector_f], outputs=outs_f+flow_vel_imgs+flow_result_imgs)
163
 
 
164
 
165
+ if ENV=="DEPLOY":
166
+ demo.launch()
167
+ else:
168
+ demo.launch(share=True, server_port=9071)
src/demo.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import torch
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ from .model import ConditionalUNet
6
+ from huggingface_hub import hf_hub_download
7
+
8
+
9
def load_models(ENV, device):
    """Build the three ConditionalUNet models used by the demo.

    Args:
        ENV: Environment name. When equal to "DEPLOY" each checkpoint is
            fetched with ``hf_hub_download`` (cached under ``models``);
            for any other value the checkpoint is read from the same
            relative path on the local filesystem.
        device: Device the models are moved to and the weights are mapped to.

    Returns:
        Tuple ``(model_diff_standard, model_flow_standard,
        model_flow_localized)``, each already switched to eval mode.
    """
    repo_id = "CristianLazoQuispe/MNIST_Diff_Flow_matching"
    use_hub = ENV == "DEPLOY"

    def locate(filename):
        # The local relative path is identical to the file name inside the hub repo.
        if use_hub:
            return hf_hub_download(repo_id=repo_id, filename=filename, cache_dir="models")
        return filename

    def build(checkpoint):
        net = ConditionalUNet().to(device)
        net.load_state_dict(torch.load(checkpoint, map_location=device))
        net.eval()
        return net

    diffusion_ckpt = locate("outputs/diffusion/diffusion_model.pth")
    print("Diff Downloaded!")
    model_diff_standard = build(diffusion_ckpt)

    # Resolve both flow checkpoints before building either model,
    # so the progress message is printed once both paths are known.
    flow_ckpt = locate("outputs/flow_matching/flow_model.pth")
    flow_localized_ckpt = locate("outputs/flow_matching/flow_model_localized_noise.pth")
    print("Flow Downloaded!")
    model_flow_standard = build(flow_ckpt)
    model_flow_localized = build(flow_localized_ckpt)

    return model_diff_standard, model_flow_standard, model_flow_localized
34
def resize(image, size=(200, 200)):
    """Return ``image`` rescaled with ``cv2.resize`` using bilinear interpolation.

    ``size`` is forwarded unchanged as the target size argument of
    ``cv2.resize`` and defaults to ``(200, 200)``.
    """
    return cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)
37
+
38
+
39
+
40
def plot_diff(outputs, x, t, noise_pred):
    """Refresh the diffusion demo image slots for step ``t``.

    Args:
        outputs: Mutable sequence of image slots; indices 1-12 are written
            here, so it must hold at least 13 entries.
        x: Current sample tensor; element ``[0, 0]`` is rescaled with
            ``(x + 1) / 2`` and rendered as the image.
        t: Current step, compared against fixed integer milestones.
        noise_pred: Predicted-noise tensor; the absolute value of element
            ``[0, 0]`` is rendered as a "coolwarm" heat map.

    Returns:
        The same ``outputs`` object, updated in place.
    """
    # step -> slot that receives the predicted-noise heat map
    noise_slots = {499: 6, 399: 7, 299: 8, 199: 9, 99: 10, 0: 11}
    # step -> slot that receives a snapshot of the sample
    sample_slots = {400: 1, 300: 2, 200: 3, 100: 4, 1: 5, 0: 12}

    def current_frame():
        return ((x + 1) / 2.0)[0, 0].cpu().numpy()

    if t in tuple(noise_slots):
        magnitude = noise_pred[0, 0].abs().clamp(0, 3).cpu().numpy()
        low, high = magnitude.min(), magnitude.max()
        # Min-max normalise; the epsilon guards against a constant map.
        magnitude = (magnitude - low) / (high - low + 1e-5)
        heat = plt.get_cmap("coolwarm")(magnitude)[:, :, :3]  # drop alpha channel
        outputs[noise_slots[t]] = resize((heat * 255).astype(np.uint8))

    # Slot 12 always mirrors the latest sample at the large preview size.
    outputs[12] = resize(current_frame(), (300, 300))

    if t in tuple(sample_slots):
        # Only the final step (t == 0) is rendered at the large size.
        extra = ((300, 300),) if t == 0 else ()
        outputs[sample_slots[t]] = resize(current_frame(), *extra)

    return outputs
60
+
61
def plot_flow(outputs, i, x, dt, v):
    """Refresh the flow-matching demo image slots for integration step ``i``.

    Args:
        outputs: Mutable sequence of image slots; indices 1-12 are written
            here, so it must hold at least 13 entries.
        i: Index of the current integration step, compared against fixed
            integer milestones.
        x: Current sample tensor; element ``[0, 0]`` is rescaled with
            ``(x + 1) / 2``, clamped to [0, 1] and rendered as the image.
        dt: Step size; scales the velocity magnitude before normalisation.
        v: Velocity tensor; the absolute value of element ``[0, 0]`` is
            rendered as a "coolwarm" heat map.

    Returns:
        The same ``outputs`` object, updated in place.
    """
    # step -> slot that receives a snapshot of the sample
    sample_slots = {10: 1, 20: 2, 30: 3, 40: 4, 48: 5, 49: 12}
    # step -> slot that receives the velocity heat map
    velocity_slots = {0: 6, 11: 7, 21: 8, 31: 9, 41: 10, 49: 11}

    def current_frame():
        return ((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy()

    # Slot 12 always mirrors the latest sample at the large preview size.
    outputs[12] = resize(current_frame(), (300, 300))

    if i in tuple(sample_slots):
        # Only step 49 is rendered at the large size.
        extra = ((300, 300),) if i == 49 else ()
        outputs[sample_slots[i]] = resize(current_frame(), *extra)

    if i in tuple(velocity_slots):
        # Clamp before normalising so outliers do not flatten the contrast.
        speed = dt * v[0, 0].abs().clamp(0, 3).cpu().numpy()
        low, high = speed.min(), speed.max()
        speed = (speed - low) / (high - low + 1e-5)
        heat = plt.get_cmap("coolwarm")(speed)[:, :, :3]  # (H, W, 3), alpha dropped
        outputs[velocity_slots[i]] = resize((heat * 255).astype(np.uint8))

    return outputs
src/utils.py CHANGED
@@ -11,4 +11,23 @@ def set_seed(seed):
11
  torch.manual_seed(seed)
12
  torch.cuda.manual_seed(seed)
13
  torch.cuda.manual_seed_all(seed)
14
- torch.backends.cudnn.deterministic = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  torch.manual_seed(seed)
12
  torch.cuda.manual_seed(seed)
13
  torch.cuda.manual_seed_all(seed)
14
+ torch.backends.cudnn.deterministic = True
15
+
16
def generate_centered_gaussian_noise(shape=(1, 1, 28, 28), sigma=5.0, mu=0):
    """Sample Gaussian noise attenuated by a 2-D Gaussian envelope.

    The envelope peaks around pixel coordinate ``(H / 2, W / 2)`` and is
    rescaled so its maximum is 1. Where the envelope is close to 1 the
    result is ``mu`` plus standard normal noise; where it is close to 0
    the result approaches the constant ``mu``.

    Args:
        shape: ``(B, C, H, W)`` of the tensor to produce; ``C`` must be 1.
        sigma: Standard deviation of the envelope, in pixels.
        mu: Mean of the noise and value of the background.

    Returns:
        Tensor of shape ``(B, C, H, W)``.

    Raises:
        AssertionError: If ``C`` is not 1.
    """
    batch, channels, height, width = shape
    assert channels == 1, "only image gray"

    rows = torch.arange(height).to(torch.float32)
    cols = torch.arange(width).to(torch.float32)

    # Squared distance of every pixel from (height / 2, width / 2),
    # obtained by broadcasting a column of row offsets against a row of
    # column offsets.
    dist_sq = (rows - height / 2).pow(2).unsqueeze(1) + (cols - width / 2).pow(2).unsqueeze(0)

    envelope = torch.exp(-dist_sq / (2 * sigma**2))
    envelope = envelope / envelope.max()  # peak value becomes exactly 1
    envelope = envelope[None, None].expand(batch, channels, height, width)

    sample = mu + torch.randn(batch, channels, height, width)
    return sample * envelope + mu * (1 - envelope)
33
+