CristianLazoQuispe committed
Commit d36e58a
1 Parent(s): 29ff960

streaming outputs: rewrite both sampling handlers as generators that yield intermediate images while they are computed; pin inference to CPU and move the demo to port 9071

Files changed (1)
  1. app.py +82 -42
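What "streaming outputs" means here: Gradio runs an event handler written as a Python generator as a streaming event, re-rendering the output components on every `yield`. Both handlers below yield a fixed 13-slot tuple (`outputs = [None] * 13`) so each intermediate image keeps its position while later slots fill in. A minimal, self-contained sketch of the pattern, assuming current Gradio generator semantics; the names `stream_frames` and `N_SLOTS` are illustrative, not from app.py:

```python
import numpy as np
import gradio as gr

N_SLOTS = 3  # app.py uses 13 fixed slots

def stream_frames(seed):
    # Generator handler: each `yield` pushes a partial result to the UI.
    rng = np.random.default_rng(int(seed))
    outputs = [None] * N_SLOTS   # fixed slots so earlier images stay in place
    yield tuple(outputs)         # clear all images up front
    for i in range(N_SLOTS):
        outputs[i] = (rng.random((28, 28)) * 255).astype(np.uint8)
        yield tuple(outputs)     # re-render with slot i now filled

with gr.Blocks() as demo:
    seed = gr.Number(value=0, label="Seed")
    btn = gr.Button("Generate")
    imgs = [gr.Image(label=f"Step {i}") for i in range(N_SLOTS)]
    btn.click(fn=stream_frames, inputs=seed, outputs=imgs)

demo.launch()
```

Yielding the full tuple every time, rather than one image, is what lets a single generator drive all of the `gr.Image` components at once.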
app.py CHANGED
@@ -7,8 +7,10 @@ import gradio as gr
 import matplotlib.pyplot as plt
 from src.model import ConditionalUNet
 from huggingface_hub import hf_hub_download
-
+import time
 device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+
+device = 'cpu'
 img_shape = (1, 28, 28)
 
 
@@ -33,17 +35,24 @@ model_flow.load_state_dict(torch.load(model_path, map_location=device))
 model_flow.eval()
 
 @torch.no_grad()
-def generate_diffusion_intermediates(label):
+def generate_diffusion_intermediates_streaming(label):
     timesteps = 500
-    img_shape = (1, 28, 28)
     betas = torch.linspace(1e-4, 0.02, timesteps)
     alphas = 1.0 - betas
     alphas_cumprod = torch.cumprod(alphas, dim=0).to(device)
 
     x = torch.randn(1, *img_shape).to(device)
     y = torch.tensor([label], dtype=torch.long, device=device)
-    noise_magnitudes = []
-    intermediates = [resize(((x + 1) / 2.0)[0][0].clamp(0, 1).cpu().numpy())]
+
+    # Initial noise image
+    img_np = ((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy()
+
+    # To keep each image's position across yields
+    outputs = [None] * 13
+    yield tuple(outputs)
+    outputs[0] = resize(img_np)
+    yield tuple(outputs)
+    #time.sleep(0.5)
 
     for t in reversed(range(timesteps)):
         t_tensor = torch.full((x.size(0),), t, device=device, dtype=torch.float)
@@ -53,22 +62,26 @@ def generate_diffusion_intermediates(label):
             noise = torch.randn(1, *img_shape).to(device)
             v = (1 - alphas_cumprod[t - 1]) / (1 - alphas_cumprod[t]) * betas[t]
             x += v.sqrt() * noise
-
         x = x.clamp(-1, 1)
-        if t in [400, 300, 200, 100, 0]:
-            #print("t:", t)
-            img_np = ((x + 1) / 2)[0, 0].cpu().numpy()
-            intermediates.append(resize(img_np))
 
-        if t in [499, 399, 299, 199, 99, 0]:
-            # Compute velocity magnitude and convert to numpy for visualization
-            v_mag = noise_pred[0, 0].abs().clamp(0, 3).cpu().numpy()  # Clamp to max value for better contrast
+
+        if t in [499, 399, 299, 199, 99, 0]:
+            step_idx = {499: 6, 399: 7, 299: 8, 199: 9, 99: 10, 0: 11}[t]
+            v_mag = noise_pred[0, 0].abs().clamp(0, 3).cpu().numpy()
             v_mag = (v_mag - v_mag.min()) / (v_mag.max() - v_mag.min() + 1e-5)
-            vel_colored = plt.get_cmap("coolwarm")(v_mag)[:, :, :3]  # (H,W,3)
+            vel_colored = plt.get_cmap("coolwarm")(v_mag)[:, :, :3]
             vel_colored = (vel_colored * 255).astype(np.uint8)
-            noise_magnitudes.append(resize(vel_colored, (100, 100)))
+            outputs[step_idx] = resize(vel_colored)
+            yield tuple(outputs)
+            time.sleep(0.5)
 
-    return intermediates + noise_magnitudes
+        if t in [400, 300, 200, 100, 1, 0]:
+            step_idx = {400: 1, 300: 2, 200: 3, 100: 4, 1: 5, 0: 12}[t]
+            if t == 0:
+                outputs[step_idx] = resize(((x + 1) / 2.0)[0, 0].cpu().numpy(), (400, 400))
+            else:
+                outputs[step_idx] = resize(((x + 1) / 2.0)[0, 0].cpu().numpy())
+            yield tuple(outputs)
 
 
 def generate_localized_noise(shape, radius=5):
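For context on the arithmetic in the hunk above: `v` is the DDPM posterior variance, beta_tilde_t = (1 - abar_{t-1}) / (1 - abar_t) * beta_t, and the loop is standard ancestral sampling. A sketch of one reverse step under those definitions (hypothetical helper, not the app's exact code; the app additionally clamps `x` to [-1, 1] each step):

```python
import torch

def ddpm_reverse_step(x, eps_pred, t, betas, alphas, alphas_cumprod):
    # Posterior mean: (x_t - beta_t / sqrt(1 - abar_t) * eps) / sqrt(alpha_t)
    mean = (x - betas[t] / (1 - alphas_cumprod[t]).sqrt() * eps_pred) / alphas[t].sqrt()
    if t == 0:
        return mean  # final step adds no noise
    # Posterior variance: the `v` computed in the hunk above
    var = (1 - alphas_cumprod[t - 1]) / (1 - alphas_cumprod[t]) * betas[t]
    return mean + var.sqrt() * torch.randn_like(x)
```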
@@ -89,7 +102,7 @@ def generate_localized_noise(shape, radius=5):
 
 
 @torch.no_grad()
-def generate_flow_intermediates(label):
+def generate_flow_intermediates_streaming(label):
     x = torch.randn(1, *img_shape).to(device)
     #x = generate_localized_noise((1, 1, 28, 28), radius=12).to(device)
     y = torch.full((1,), label, dtype=torch.long, device=device)
@@ -98,23 +111,43 @@ def generate_flow_intermediates(label):
 
     images = [(x + 1) / 2.0]  # initial noise
     vel_magnitudes = []
+
+    # Initial noise image
+    img_np = ((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy()
+
+    # To keep each image's position across yields
+    outputs = [None] * 13
+    yield tuple(outputs)
+    outputs[0] = resize(img_np)
+    yield tuple(outputs)
+    #time.sleep(0.5)
+
+
     for i in range(steps):
 
         t = torch.full((1,), i * dt, device=device)
         v = model_flow(x, t, y)
         x = x + v * dt
 
-        if i in [10, 20, 30, 40, 49]:
-            images.append((x + 1) / 2.0)
+        if i in [10, 20, 30, 40, 48, 49]:
+            #images.append((x + 1) / 2.0)
+            step_idx = {10: 1, 20: 2, 30: 3, 40: 4, 48: 5, 49: 12}[i]
+            outputs[step_idx] = resize(((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy())
+            yield tuple(outputs)
+
         # Compute velocity magnitude and convert to numpy for visualization
-        if i in [0, 10, 20, 30, 40, 49]:
+        if i in [0, 11, 21, 31, 41, 49]:
             v_mag = dt * v[0, 0].abs().clamp(0, 3).cpu().numpy()  # Clamp to max value for better contrast
             v_mag = (v_mag - v_mag.min()) / (v_mag.max() - v_mag.min() + 1e-5)
             vel_colored = plt.get_cmap("coolwarm")(v_mag)[:, :, :3]  # (H,W,3)
             vel_colored = (vel_colored * 255).astype(np.uint8)
-            vel_magnitudes.append(resize(vel_colored, (100, 100)))
+            step_idx = {0: 6, 11: 7, 21: 8, 31: 9, 41: 10, 49: 11}[i]
+            if i == 49:
+                outputs[step_idx] = resize(vel_colored, (400, 400))
+            else:
+                outputs[step_idx] = resize(vel_colored)
+            yield tuple(outputs)
 
-    return [resize(images[0][0][0].clamp(0, 1).cpu().numpy())] + [resize(img[0][0].clamp(0, 1).cpu().numpy()) for img in images[-5:]] + vel_magnitudes
 
 with gr.Blocks() as demo:
     gr.Markdown("# Conditional MNIST Generation: Diffusion vs Flow Matching")
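The flow-matching loop in the hunk above is first-order Euler integration of the learned velocity field over t in [0, 1]: x_{k+1} = x_k + v_theta(x_k, t_k, y) * dt with t_k = k * dt. A minimal sketch with a stand-in velocity function (not the app's model):

```python
import torch

def euler_integrate(velocity_fn, x0, steps=50):
    dt = 1.0 / steps
    x = x0
    for k in range(steps):
        t = torch.full((x.shape[0],), k * dt)
        x = x + velocity_fn(x, t) * dt  # one Euler step along the flow
    return x

# Toy field that pulls samples toward the origin:
x1 = euler_integrate(lambda x, t: -x, torch.randn(4, 1, 28, 28))
```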
@@ -124,24 +157,28 @@ with gr.Blocks() as demo:
         btn_d = gr.Button("Generate")
         with gr.Row():
             outs_d = [
-                gr.Image(label="Noise"),
-                gr.Image(label="Diffusion t=400"),
-                gr.Image(label="Diffusion t=300"),
-                gr.Image(label="Diffusion t=200"),
-                gr.Image(label="Diffusion t=100"),
-                gr.Image(label="Diffusion t=0"),
+                gr.Image(label="Noise", streaming=True),
+                gr.Image(label="Diffusion t=400", streaming=True),
+                gr.Image(label="Diffusion t=300", streaming=True),
+                gr.Image(label="Diffusion t=200", streaming=True),
+                gr.Image(label="Diffusion t=100", streaming=True),
+                gr.Image(label="Diffusion t=1", streaming=True),
             ]
         with gr.Row():
             #400, 300, 200, 100, 0
-            flow_noise_imgs = [
-                gr.Image(label="Noise pred t=500"),
-                gr.Image(label="Noise pred t=400"),
-                gr.Image(label="Noise pred t=300"),
-                gr.Image(label="Noise pred t=200"),
-                gr.Image(label="Noise pred t=100"),
-                gr.Image(label="Noise pred t=0")
+            diff_noise_imgs = [
+                gr.Image(label="Noise pred t=500", streaming=True),
+                gr.Image(label="Noise pred t=400", streaming=True),
+                gr.Image(label="Noise pred t=300", streaming=True),
+                gr.Image(label="Noise pred t=200", streaming=True),
+                gr.Image(label="Noise pred t=100", streaming=True),
+                gr.Image(label="Noise pred t=1", streaming=True),
             ]
-        btn_d.click(fn=generate_diffusion_intermediates, inputs=label_d, outputs=outs_d+flow_noise_imgs)
+        with gr.Row():
+            diff_result_imgs = [
+                gr.Image(label="Diffusion t=0", streaming=True),
+            ]
+        btn_d.click(fn=generate_diffusion_intermediates_streaming, inputs=label_d, outputs=outs_d+diff_noise_imgs+diff_result_imgs)
 
     with gr.Tab("Flow Matching"):
         label_f = gr.Slider(0, 9, step=1, label="Digit Label")
@@ -153,7 +190,7 @@ with gr.Blocks() as demo:
                 gr.Image(label="Flow step=20"),
                 gr.Image(label="Flow step=30"),
                 gr.Image(label="Flow step=40"),
-                gr.Image(label="Flow step=49"),
+                gr.Image(label="Flow step=48"),
             ]
         with gr.Row():
             #100, 200, 300, 400, 499
@@ -163,10 +200,13 @@ with gr.Blocks() as demo:
                 gr.Image(label="Velocity step=20"),
                 gr.Image(label="Velocity step=30"),
                 gr.Image(label="Velocity step=40"),
-                gr.Image(label="Velocity step=49")
+                gr.Image(label="Velocity step=48")
             ]
-
-        btn_f.click(fn=generate_flow_intermediates, inputs=label_f, outputs=outs_f+flow_vel_imgs)
+        with gr.Row():
+            flow_result_imgs = [
+                gr.Image(label="Flow step=49", streaming=True),
+            ]
+        btn_f.click(fn=generate_flow_intermediates_streaming, inputs=label_f, outputs=outs_f+flow_vel_imgs+flow_result_imgs)
 
-demo.launch()
-#demo.launch(share=False, server_port=9070)
+#demo.launch()
+demo.launch(share=False, server_port=9071)
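Both handlers render per-pixel magnitude maps identically: min-max normalize to [0, 1], apply matplotlib's "coolwarm" colormap, drop the alpha channel, and scale to uint8 RGB. The same recipe as a stand-alone sketch (hypothetical helper name):

```python
import numpy as np
import matplotlib.pyplot as plt

def magnitude_to_rgb(mag: np.ndarray) -> np.ndarray:
    # Min-max normalize; the 1e-5 guards against division by zero
    mag = (mag - mag.min()) / (mag.max() - mag.min() + 1e-5)
    rgb = plt.get_cmap("coolwarm")(mag)[:, :, :3]  # (H, W, 4) -> keep RGB
    return (rgb * 255).astype(np.uint8)
```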
 