CristianLazoQuispe committed
Commit 74ec8db · 1 Parent(s): 83e69a0
Files changed (1)
  1. app.py +77 -34
app.py CHANGED
@@ -1,6 +1,8 @@
-import torch
 import time
+import torch
+import logging
 import gradio as gr
+logging.basicConfig(level=logging.INFO)
 from src.utils import generate_centered_gaussian_noise
 from src.demo import resize,plot_flow,load_models,plot_diff
 
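Note on the import hunk: calling logging.basicConfig(level=logging.INFO) once at import time is what lets the INFO-level timing lines added below actually reach the Space logs; without it the root logger only emits WARNING and above. A minimal sketch of the timing pattern the commit wraps around each model call (the timed helper is illustrative, not part of app.py):

import time
import logging

logging.basicConfig(level=logging.INFO)

def timed(label, fn, *args, **kwargs):
    # Run fn, log its wall-clock duration at INFO level, and return its result.
    start = time.time()
    result = fn(*args, **kwargs)
    logging.info(f"{label} took {time.time() - start:.3f}s")
    return result
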
@@ -9,59 +11,83 @@ img_shape = (1, 28, 28)
 ENV = "DEPLOY"
 TIME_SLEEP = 0.05
 
+timesteps = 500
+betas = torch.linspace(1e-4, 0.02, timesteps)
+alphas = 1.0 - betas
+alphas_cumprod = torch.cumprod(alphas, dim=0).to(device)
 
-
-model_diff_standard,model_flow_standard,model_flow_localized = load_models(ENV,device=device)
+model_diff,model_flow_standard,model_flow_localized = load_models(ENV,device=device)
 
 
 @torch.no_grad()
 def generate_diffusion_intermediates_streaming(label):
-    timesteps = 500
-    betas = torch.linspace(1e-4, 0.02, timesteps)
-    alphas = 1.0 - betas
-    alphas_cumprod = torch.cumprod(alphas, dim=0).to(device)
-    model_diff = model_diff_standard
+    logging.info("🚀 Starting Diffusion Generation")
+    total_start = time.time()
+
+
 
     x = torch.randn(1, *img_shape).to(device)
     y = torch.tensor([label], dtype=torch.long, device=device)
 
-    # Initial
-    img_np = ((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy()
-
-    # To keep each image's position
     outputs = [None] * 13
-    yield tuple(outputs)
-    outputs[0] = resize(img_np)
+    outputs[0] = resize(((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy())
     yield tuple(outputs)
     time.sleep(0.2)
 
     for t in reversed(range(timesteps)):
+        step_start = time.time()
+
         t_tensor = torch.full((x.size(0),), t, device=device, dtype=torch.float)
+
+        # Forward pass
+        model_start = time.time()
         noise_pred = model_diff(x, t_tensor, y)
-        x = (1 / alphas[t].sqrt()) * (x - noise_pred * betas[t] / (1 - alphas_cumprod[t]).sqrt() )
+        model_time = time.time() - model_start
+
+        # Denoising step
+        step_compute_start = time.time()
+        x = (1 / alphas[t].sqrt()) * (x - noise_pred * betas[t] / (1 - alphas_cumprod[t]).sqrt())
         if t > 0:
-            noise = torch.randn(1, *img_shape).to(device)
+            noise = torch.randn_like(x)
             v = (1 - alphas_cumprod[t - 1]) / (1 - alphas_cumprod[t]) * betas[t]
             x += v.sqrt() * noise
         x = x.clamp(-1, 1)
+        step_compute_time = time.time() - step_compute_start
+
+        # Plotting
+        plot_start = time.time()
+        outputs = plot_diff(outputs, x, t, noise_pred)
+        plot_time = time.time() - plot_start
 
-        outputs = plot_diff(outputs,x,t,noise_pred)
-        if t % 20 == 0 or t in [499, 399, 299, 199, 99, 0,400, 300, 200, 100, 1]:
+        # Logging
+        step_time = time.time() - step_start
+        total_time = time.time() - total_start
+        if t % 50 == 0 or t in [400, 300, 200, 100, 0]:
+            logging.info(f"Diff [t={t:03d}] total={total_time:.3f}s | total_step={step_time:.3f}s | model={model_time:.3f}s | step={step_compute_time:.3f}s | plot={plot_time:.3f}s")
+
+        if t % 20 == 0 or t in [499, 399, 299, 199, 99, 0, 400, 300, 200, 100, 1]:
             yield tuple(outputs)
             time.sleep(0.06)
 
-        if ENV=="LOCAL":
+        if ENV == "LOCAL":
            time.sleep(TIME_SLEEP)
 
+    total_time = time.time() - total_start
+    logging.info(f" Finished diffusion in {total_time:.2f}s")
     yield tuple(outputs)
 
 
-
+import logging
+logging.basicConfig(level=logging.INFO)
 
 @torch.no_grad()
-def generate_flow_intermediates_streaming(label,noise_type):
-    if noise_type=="Localized":
-        x = generate_centered_gaussian_noise((1, *img_shape)).to(device)
+def generate_flow_intermediates_streaming(label, noise_type):
+    logging.info("🚀 Starting Flow Matching Generation")
+    total_start = time.time()
+
+    # Select noise and model
+    if noise_type == "Localized":
+        x = generate_centered_gaussian_noise((1, *img_shape)).to(device)
         model_flow = model_flow_localized
     else:
         x = torch.randn(1, *img_shape).to(device)
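For reference, the loop added above performs standard DDPM ancestral sampling with the beta/alpha-bar schedule now defined at module level: the model predicts the noise, the posterior mean is recovered from the schedule, and fresh Gaussian noise scaled by the posterior variance is injected at every step except t = 0. A self-contained sketch of one reverse step, mirroring the update in the diff (illustrative only, not part of the commit):

import torch

@torch.no_grad()
def ddpm_reverse_step(x, t, noise_pred, betas, alphas, alphas_cumprod):
    # Posterior mean: remove the predicted noise, rescaled by the schedule at step t.
    x = (1 / alphas[t].sqrt()) * (x - noise_pred * betas[t] / (1 - alphas_cumprod[t]).sqrt())
    if t > 0:
        # Posterior variance for the stochastic part of the ancestral step.
        var = (1 - alphas_cumprod[t - 1]) / (1 - alphas_cumprod[t]) * betas[t]
        x = x + var.sqrt() * torch.randn_like(x)
    # The app additionally clamps to the [-1, 1] image range after every step.
    return x.clamp(-1, 1)
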
@@ -70,28 +96,46 @@ def generate_flow_intermediates_streaming(label,noise_type):
     y = torch.full((1,), label, dtype=torch.long, device=device)
     steps = 50
     dt = 1.0 / steps
-
-    # Initial
-    img_np = ((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy()
 
-    # To keep each image's position
     outputs = [None] * 13
-    yield tuple(outputs)
-    outputs[0] = resize(img_np)
+    outputs[0] = resize(((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy())
     yield tuple(outputs)
     time.sleep(0.2)
 
+    for i in range(steps):
+        step_start = time.time()
 
-    for i in range(steps):
         t = torch.full((1,), i * dt, device=device)
+
+        # Forward pass
+        model_start = time.time()
         v = model_flow(x, t, y)
+        model_time = time.time() - model_start
+
+        # Flow step
+        flow_step_start = time.time()
         x = x + v * dt
-        outputs = plot_flow(outputs,i,x,dt,v)
+        flow_step_time = time.time() - flow_step_start
+
+        # Plotting
+        plot_start = time.time()
+        outputs = plot_flow(outputs, i, x, dt, v)
+        plot_time = time.time() - plot_start
+
+        # Logging
+        step_time = time.time() - step_start
+        total_time = time.time() - total_start
+        if i % 10 == 0 or i in [0, 25, 49]:
+            logging.info(f"Flow [step={i:02d}] total={total_time:.3f}s | total_step={step_time:.3f}s | model={model_time:.3f}s | step={flow_step_time:.3f}s | plot={plot_time:.3f}s")
+
         if i % 2 == 0:
             yield tuple(outputs)
-            time.sleep(0.15) # sleep to render properly in gradio
-        if ENV=="LOCAL":
+            time.sleep(0.15)
+        if ENV == "LOCAL":
             time.sleep(TIME_SLEEP)
+
+    total_time = time.time() - total_start
+    logging.info(f"Finished flow matching in {total_time:.2f}s")
     yield tuple(outputs)
 
 
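The flow-matching sampler above is a fixed-step Euler integration of the learned velocity field over t in [0, 1): at each of the 50 steps the model predicts a velocity v and the sample moves by v * dt. Stripped of the plotting and timing code, the integrator reduces to the sketch below (illustrative only, not part of the commit):

import torch

@torch.no_grad()
def euler_flow_sampler(model_flow, x, y, steps=50):
    # Integrate dx/dt = v(x, t, y) from t = 0 toward t = 1 with uniform steps.
    dt = 1.0 / steps
    for i in range(steps):
        t = torch.full((x.size(0),), i * dt, device=x.device)
        v = model_flow(x, t, y)   # predicted velocity at the current sample
        x = x + v * dt            # explicit Euler update
    return x
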
@@ -163,6 +207,5 @@ with gr.Blocks() as demo:
 
 if ENV=="DEPLOY":
     demo.launch()
-    #demo.launch(share=True, server_port=9071)
 else:
     demo.launch(share=True, server_port=9071)
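Both generator functions stream their progress because Gradio treats a generator handler specially: every yield tuple(outputs) pushes the current set of 13 images to the output components, which is what animates the sampling in the UI, and the ENV switch at the bottom only changes how the server is launched. A stripped-down sketch of that streaming pattern (the components here are illustrative, not the ones in app.py):

import time
import gradio as gr

def stream_steps(n):
    # A generator handler: each yield updates the output component in place.
    for i in range(int(n)):
        time.sleep(0.1)
        yield f"step {i + 1} / {int(n)}"

with gr.Blocks() as demo:
    n = gr.Number(value=5, label="steps")
    progress = gr.Textbox(label="progress")
    gr.Button("run").click(stream_steps, inputs=n, outputs=progress)

if __name__ == "__main__":
    demo.launch()  # locally: demo.launch(share=True, server_port=9071)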
 
 