Spaces:
Runtime error
Runtime error
Mehdi Cherti
commited on
Commit
·
6c1d070
1
Parent(s):
023c7dd
support higher res/lower res sampling than training time
Browse files
score_sde/models/ncsnpp_generator_adagn.py
CHANGED
|
@@ -379,7 +379,8 @@ class NCSNpp(nn.Module):
|
|
| 379 |
#print(hs[-1].shape, temb.shape, zemb.shape, type(modules[m_idx]))
|
| 380 |
h = modules[m_idx](hs[-1], temb, zemb)
|
| 381 |
m_idx += 1
|
| 382 |
-
if
|
|
|
|
| 383 |
if type(modules[m_idx]) in (layers.CondAttnBlock, CrossAndGlobalAttnBlock):
|
| 384 |
h = modules[m_idx](h, cond, cond_mask)
|
| 385 |
else:
|
|
@@ -415,6 +416,7 @@ class NCSNpp(nn.Module):
|
|
| 415 |
h = hs[-1]
|
| 416 |
h = modules[m_idx](h, temb, zemb)
|
| 417 |
m_idx += 1
|
|
|
|
| 418 |
if type(modules[m_idx]) in (layers.CondAttnBlock, CrossAndGlobalAttnBlock):
|
| 419 |
h = modules[m_idx](h, cond, cond_mask)
|
| 420 |
else:
|
|
@@ -431,7 +433,8 @@ class NCSNpp(nn.Module):
|
|
| 431 |
h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb, zemb)
|
| 432 |
m_idx += 1
|
| 433 |
|
| 434 |
-
if h.shape[-1] in self.attn_resolutions:
|
|
|
|
| 435 |
if type(modules[m_idx]) in (layers.CondAttnBlock, CrossAndGlobalAttnBlock):
|
| 436 |
h = modules[m_idx](h, cond, cond_mask)
|
| 437 |
else:
|
|
|
|
| 379 |
#print(hs[-1].shape, temb.shape, zemb.shape, type(modules[m_idx]))
|
| 380 |
h = modules[m_idx](hs[-1], temb, zemb)
|
| 381 |
m_idx += 1
|
| 382 |
+
if type(modules[m_idx]) in (layers.CondAttnBlock, CrossAndGlobalAttnBlock, layers.AttnBlock):
|
| 383 |
+
#if h.shape[-1] in self.attn_resolutions:
|
| 384 |
if type(modules[m_idx]) in (layers.CondAttnBlock, CrossAndGlobalAttnBlock):
|
| 385 |
h = modules[m_idx](h, cond, cond_mask)
|
| 386 |
else:
|
|
|
|
| 416 |
h = hs[-1]
|
| 417 |
h = modules[m_idx](h, temb, zemb)
|
| 418 |
m_idx += 1
|
| 419 |
+
|
| 420 |
if type(modules[m_idx]) in (layers.CondAttnBlock, CrossAndGlobalAttnBlock):
|
| 421 |
h = modules[m_idx](h, cond, cond_mask)
|
| 422 |
else:
|
|
|
|
| 433 |
h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb, zemb)
|
| 434 |
m_idx += 1
|
| 435 |
|
| 436 |
+
#if h.shape[-1] in self.attn_resolutions:
|
| 437 |
+
if type(modules[m_idx]) in (layers.CondAttnBlock, CrossAndGlobalAttnBlock, layers.AttnBlock):
|
| 438 |
if type(modules[m_idx]) in (layers.CondAttnBlock, CrossAndGlobalAttnBlock):
|
| 439 |
h = modules[m_idx](h, cond, cond_mask)
|
| 440 |
else:
|