Commit
·
74ec8db
1
Parent(s):
83e69a0
logging
Browse files
app.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1 |
-
import torch
|
2 |
import time
|
|
|
|
|
3 |
import gradio as gr
|
|
|
4 |
from src.utils import generate_centered_gaussian_noise
|
5 |
from src.demo import resize,plot_flow,load_models,plot_diff
|
6 |
|
@@ -9,59 +11,83 @@ img_shape = (1, 28, 28)
|
|
9 |
ENV = "DEPLOY"
|
10 |
TIME_SLEEP = 0.05
|
11 |
|
|
|
|
|
|
|
|
|
12 |
|
13 |
-
|
14 |
-
model_diff_standard,model_flow_standard,model_flow_localized = load_models(ENV,device=device)
|
15 |
|
16 |
|
17 |
@torch.no_grad()
|
18 |
def generate_diffusion_intermediates_streaming(label):
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
model_diff = model_diff_standard
|
24 |
|
25 |
x = torch.randn(1, *img_shape).to(device)
|
26 |
y = torch.tensor([label], dtype=torch.long, device=device)
|
27 |
|
28 |
-
# Inicial
|
29 |
-
img_np = ((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy()
|
30 |
-
|
31 |
-
# Para mantener la posición de cada imagen
|
32 |
outputs = [None] * 13
|
33 |
-
|
34 |
-
outputs[0] = resize(img_np)
|
35 |
yield tuple(outputs)
|
36 |
time.sleep(0.2)
|
37 |
|
38 |
for t in reversed(range(timesteps)):
|
|
|
|
|
39 |
t_tensor = torch.full((x.size(0),), t, device=device, dtype=torch.float)
|
|
|
|
|
|
|
40 |
noise_pred = model_diff(x, t_tensor, y)
|
41 |
-
|
|
|
|
|
|
|
|
|
42 |
if t > 0:
|
43 |
-
noise = torch.
|
44 |
v = (1 - alphas_cumprod[t - 1]) / (1 - alphas_cumprod[t]) * betas[t]
|
45 |
x += v.sqrt() * noise
|
46 |
x = x.clamp(-1, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
50 |
yield tuple(outputs)
|
51 |
time.sleep(0.06)
|
52 |
|
53 |
-
if ENV=="LOCAL":
|
54 |
time.sleep(TIME_SLEEP)
|
55 |
|
|
|
|
|
56 |
yield tuple(outputs)
|
57 |
|
58 |
|
59 |
-
|
|
|
60 |
|
61 |
@torch.no_grad()
|
62 |
-
def generate_flow_intermediates_streaming(label,noise_type):
|
63 |
-
|
64 |
-
|
|
|
|
|
|
|
|
|
65 |
model_flow = model_flow_localized
|
66 |
else:
|
67 |
x = torch.randn(1, *img_shape).to(device)
|
@@ -70,28 +96,46 @@ def generate_flow_intermediates_streaming(label,noise_type):
|
|
70 |
y = torch.full((1,), label, dtype=torch.long, device=device)
|
71 |
steps = 50
|
72 |
dt = 1.0 / steps
|
73 |
-
|
74 |
-
# Inicial
|
75 |
-
img_np = ((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy()
|
76 |
|
77 |
-
# Para mantener la posición de cada imagen
|
78 |
outputs = [None] * 13
|
79 |
-
|
80 |
-
outputs[0] = resize(img_np)
|
81 |
yield tuple(outputs)
|
82 |
time.sleep(0.2)
|
83 |
|
|
|
|
|
84 |
|
85 |
-
for i in range(steps):
|
86 |
t = torch.full((1,), i * dt, device=device)
|
|
|
|
|
|
|
87 |
v = model_flow(x, t, y)
|
|
|
|
|
|
|
|
|
88 |
x = x + v * dt
|
89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
if i % 2 == 0:
|
91 |
yield tuple(outputs)
|
92 |
-
time.sleep(0.15)
|
93 |
-
if ENV=="LOCAL":
|
94 |
time.sleep(TIME_SLEEP)
|
|
|
|
|
|
|
95 |
yield tuple(outputs)
|
96 |
|
97 |
|
@@ -163,6 +207,5 @@ with gr.Blocks() as demo:
|
|
163 |
|
164 |
if ENV=="DEPLOY":
|
165 |
demo.launch()
|
166 |
-
#demo.launch(share=True, server_port=9071)
|
167 |
else:
|
168 |
demo.launch(share=True, server_port=9071)
|
|
|
|
|
1 |
import time
|
2 |
+
import torch
|
3 |
+
import logging
|
4 |
import gradio as gr
|
5 |
+
logging.basicConfig(level=logging.INFO)
|
6 |
from src.utils import generate_centered_gaussian_noise
|
7 |
from src.demo import resize,plot_flow,load_models,plot_diff
|
8 |
|
|
|
11 |
ENV = "DEPLOY"
|
12 |
TIME_SLEEP = 0.05
|
13 |
|
14 |
+
timesteps = 500
|
15 |
+
betas = torch.linspace(1e-4, 0.02, timesteps)
|
16 |
+
alphas = 1.0 - betas
|
17 |
+
alphas_cumprod = torch.cumprod(alphas, dim=0).to(device)
|
18 |
|
19 |
+
model_diff,model_flow_standard,model_flow_localized = load_models(ENV,device=device)
|
|
|
20 |
|
21 |
|
22 |
@torch.no_grad()
|
23 |
def generate_diffusion_intermediates_streaming(label):
|
24 |
+
logging.info("🚀 Starting Diffusion Generation")
|
25 |
+
total_start = time.time()
|
26 |
+
|
27 |
+
|
|
|
28 |
|
29 |
x = torch.randn(1, *img_shape).to(device)
|
30 |
y = torch.tensor([label], dtype=torch.long, device=device)
|
31 |
|
|
|
|
|
|
|
|
|
32 |
outputs = [None] * 13
|
33 |
+
outputs[0] = resize(((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy())
|
|
|
34 |
yield tuple(outputs)
|
35 |
time.sleep(0.2)
|
36 |
|
37 |
for t in reversed(range(timesteps)):
|
38 |
+
step_start = time.time()
|
39 |
+
|
40 |
t_tensor = torch.full((x.size(0),), t, device=device, dtype=torch.float)
|
41 |
+
|
42 |
+
# Forward pass
|
43 |
+
model_start = time.time()
|
44 |
noise_pred = model_diff(x, t_tensor, y)
|
45 |
+
model_time = time.time() - model_start
|
46 |
+
|
47 |
+
# Denoising step
|
48 |
+
step_compute_start = time.time()
|
49 |
+
x = (1 / alphas[t].sqrt()) * (x - noise_pred * betas[t] / (1 - alphas_cumprod[t]).sqrt())
|
50 |
if t > 0:
|
51 |
+
noise = torch.randn_like(x)
|
52 |
v = (1 - alphas_cumprod[t - 1]) / (1 - alphas_cumprod[t]) * betas[t]
|
53 |
x += v.sqrt() * noise
|
54 |
x = x.clamp(-1, 1)
|
55 |
+
step_compute_time = time.time() - step_compute_start
|
56 |
+
|
57 |
+
# Plotting
|
58 |
+
plot_start = time.time()
|
59 |
+
outputs = plot_diff(outputs, x, t, noise_pred)
|
60 |
+
plot_time = time.time() - plot_start
|
61 |
|
62 |
+
# Logging
|
63 |
+
step_time = time.time() - step_start
|
64 |
+
total_time = time.time() - total_start
|
65 |
+
if t % 50 == 0 or t in [400, 300, 200, 100, 0]:
|
66 |
+
logging.info(f"Diff [t={t:03d}] total={total_time:.3f}s | total_step={step_time:.3f}s | model={model_time:.3f}s | step={step_compute_time:.3f}s | plot={plot_time:.3f}s")
|
67 |
+
|
68 |
+
if t % 20 == 0 or t in [499, 399, 299, 199, 99, 0, 400, 300, 200, 100, 1]:
|
69 |
yield tuple(outputs)
|
70 |
time.sleep(0.06)
|
71 |
|
72 |
+
if ENV == "LOCAL":
|
73 |
time.sleep(TIME_SLEEP)
|
74 |
|
75 |
+
total_time = time.time() - total_start
|
76 |
+
logging.info(f" Finished diffusion in {total_time:.2f}s")
|
77 |
yield tuple(outputs)
|
78 |
|
79 |
|
80 |
+
import logging
|
81 |
+
logging.basicConfig(level=logging.INFO)
|
82 |
|
83 |
@torch.no_grad()
|
84 |
+
def generate_flow_intermediates_streaming(label, noise_type):
|
85 |
+
logging.info("🚀 Starting Flow Matching Generation")
|
86 |
+
total_start = time.time()
|
87 |
+
|
88 |
+
# Select noise and model
|
89 |
+
if noise_type == "Localized":
|
90 |
+
x = generate_centered_gaussian_noise((1, *img_shape)).to(device)
|
91 |
model_flow = model_flow_localized
|
92 |
else:
|
93 |
x = torch.randn(1, *img_shape).to(device)
|
|
|
96 |
y = torch.full((1,), label, dtype=torch.long, device=device)
|
97 |
steps = 50
|
98 |
dt = 1.0 / steps
|
|
|
|
|
|
|
99 |
|
|
|
100 |
outputs = [None] * 13
|
101 |
+
outputs[0] = resize(((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy())
|
|
|
102 |
yield tuple(outputs)
|
103 |
time.sleep(0.2)
|
104 |
|
105 |
+
for i in range(steps):
|
106 |
+
step_start = time.time()
|
107 |
|
|
|
108 |
t = torch.full((1,), i * dt, device=device)
|
109 |
+
|
110 |
+
# Forward pass
|
111 |
+
model_start = time.time()
|
112 |
v = model_flow(x, t, y)
|
113 |
+
model_time = time.time() - model_start
|
114 |
+
|
115 |
+
# Flow step
|
116 |
+
flow_step_start = time.time()
|
117 |
x = x + v * dt
|
118 |
+
flow_step_time = time.time() - flow_step_start
|
119 |
+
|
120 |
+
# Plotting
|
121 |
+
plot_start = time.time()
|
122 |
+
outputs = plot_flow(outputs, i, x, dt, v)
|
123 |
+
plot_time = time.time() - plot_start
|
124 |
+
|
125 |
+
# Logging
|
126 |
+
step_time = time.time() - step_start
|
127 |
+
total_time = time.time() - total_start
|
128 |
+
if i % 10 == 0 or i in [0, 25, 49]:
|
129 |
+
logging.info(f"Flow [step={i:02d}] total={total_time:.3f}s | total_step={step_time:.3f}s | model={model_time:.3f}s | step={flow_step_time:.3f}s | plot={plot_time:.3f}s")
|
130 |
+
|
131 |
if i % 2 == 0:
|
132 |
yield tuple(outputs)
|
133 |
+
time.sleep(0.15)
|
134 |
+
if ENV == "LOCAL":
|
135 |
time.sleep(TIME_SLEEP)
|
136 |
+
|
137 |
+
total_time = time.time() - total_start
|
138 |
+
logging.info(f"Finished flow matching in {total_time:.2f}s")
|
139 |
yield tuple(outputs)
|
140 |
|
141 |
|
|
|
207 |
|
208 |
if ENV=="DEPLOY":
|
209 |
demo.launch()
|
|
|
210 |
else:
|
211 |
demo.launch(share=True, server_port=9071)
|