Spaces:
Sleeping
Sleeping
timesteps
Browse files
app.py
CHANGED
|
@@ -222,7 +222,7 @@ def sample_ddpm(n_sample, context, save_rate=20):
|
|
| 222 |
return samples, intermediate
|
| 223 |
|
| 224 |
@torch.no_grad()
|
| 225 |
-
def sample_ddpm_context(n_sample, context, save_rate=20):
|
| 226 |
# x_T ~ N(0, 1), sample initial noise
|
| 227 |
samples = torch.randn(n_sample, 3, height, height).to(device)
|
| 228 |
|
|
@@ -257,7 +257,7 @@ def greet(input):
|
|
| 257 |
ctx = torch.from_numpy(mtx_2d).to(device=device).float()
|
| 258 |
|
| 259 |
#samples, intermediate = sample_ddim_context(32, ctx, n=steps)
|
| 260 |
-
samples, intermediate = sample_ddpm_context(32,
|
| 261 |
|
| 262 |
#samples, intermediate = sample_ddim(32, n=steps)
|
| 263 |
#ctx = F.one_hot(torch.randint(0, 5, (32,)), 5).to(device=device).float()
|
|
|
|
| 222 |
return samples, intermediate
|
| 223 |
|
| 224 |
@torch.no_grad()
|
| 225 |
+
def sample_ddpm_context(n_sample,timesteps, context, save_rate=20):
|
| 226 |
# x_T ~ N(0, 1), sample initial noise
|
| 227 |
samples = torch.randn(n_sample, 3, height, height).to(device)
|
| 228 |
|
|
|
|
| 257 |
ctx = torch.from_numpy(mtx_2d).to(device=device).float()
|
| 258 |
|
| 259 |
#samples, intermediate = sample_ddim_context(32, ctx, n=steps)
|
| 260 |
+
samples, intermediate = sample_ddpm_context(32, steps, ctx)
|
| 261 |
|
| 262 |
#samples, intermediate = sample_ddim(32, n=steps)
|
| 263 |
#ctx = F.one_hot(torch.randint(0, 5, (32,)), 5).to(device=device).float()
|