Commit: reset fix
Changed files:
- models/region_diffusion.py (+4 −2)
- utils/attention_utils.py (+1 −0)
models/region_diffusion.py
CHANGED
|
@@ -285,8 +285,10 @@ class RegionDiffusion(nn.Module):
|
|
| 285 |
We reset attention maps because we append them while getting hooks
|
| 286 |
to visualize attention maps for every step.
|
| 287 |
"""
|
| 288 |
-
for key in self.
|
| 289 |
-
self.
|
|
|
|
|
|
|
| 290 |
|
| 291 |
def register_evaluation_hooks(self):
|
| 292 |
r"""Function for registering hooks during evaluation.
|
|
|
|
| 285 |
We reset attention maps because we append them while getting hooks
|
| 286 |
to visualize attention maps for every step.
|
| 287 |
"""
|
| 288 |
+
for key in self.selfattn_maps:
|
| 289 |
+
self.selfattn_maps[key] = []
|
| 290 |
+
for key in self.crossattn_maps:
|
| 291 |
+
self.crossattn_maps[key] = []
|
| 292 |
|
| 293 |
def register_evaluation_hooks(self):
|
| 294 |
r"""Function for registering hooks during evaluation.
|
utils/attention_utils.py
CHANGED
|
@@ -123,6 +123,7 @@ def plot_attention_maps(atten_map_list, obj_tokens, save_dir, seed, tokens_vis=N
|
|
| 123 |
dtype='uint8').reshape((height, width, 3))
|
| 124 |
|
| 125 |
fig.tight_layout()
|
|
|
|
| 126 |
return img
|
| 127 |
|
| 128 |
|
|
|
|
| 123 |
dtype='uint8').reshape((height, width, 3))
|
| 124 |
|
| 125 |
fig.tight_layout()
|
| 126 |
+
plt.close()
|
| 127 |
return img
|
| 128 |
|
| 129 |
|