Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -68,8 +68,8 @@ def get_processed_features_dino(num_patches, img,use_dummy):
|
|
68 |
desc = aggre_net(features_dino)
|
69 |
norms = torch.linalg.norm(desc, dim=1, keepdim=True)
|
70 |
desc = desc / (norms + 1e-8)
|
71 |
-
desc = desc.cpu()
|
72 |
-
del batch, features_dino
|
73 |
gc.collect()
|
74 |
torch.cuda.empty_cache()
|
75 |
return desc # shape [1, C, num_patches, num_patches]
|
@@ -114,6 +114,7 @@ def update_features(
|
|
114 |
num_patches,
|
115 |
use_dummy
|
116 |
):
|
|
|
117 |
torch.cuda.empty_cache()
|
118 |
"""
|
119 |
Given a PIL image, returns:
|
@@ -138,6 +139,8 @@ def on_select(
|
|
138 |
or_src_img: Image,
|
139 |
sel: gr.SelectData
|
140 |
):
|
|
|
|
|
141 |
# Convert to numpy arrays
|
142 |
src_arr = np.array(or_src_img)
|
143 |
tgt_arr = np.array(or_tgt_img)
|
@@ -194,6 +197,9 @@ extractor_vit.model.eval()
|
|
194 |
|
195 |
upsampler = nn.Upsample(size=(target_res, target_res), mode='bilinear', align_corners=False)
|
196 |
|
|
|
|
|
|
|
197 |
# βββ Build Gradio UI ββββββββββββββββββββββββββββββββββββββββββββββ
|
198 |
with gr.Blocks() as demo:
|
199 |
# Hidden states to hold features
|
|
|
68 |
desc = aggre_net(features_dino)
|
69 |
norms = torch.linalg.norm(desc, dim=1, keepdim=True)
|
70 |
desc = desc / (norms + 1e-8)
|
71 |
+
desc = desc.cpu()
|
72 |
+
# del batch, features_dino
|
73 |
gc.collect()
|
74 |
torch.cuda.empty_cache()
|
75 |
return desc # shape [1, C, num_patches, num_patches]
|
|
|
114 |
num_patches,
|
115 |
use_dummy
|
116 |
):
|
117 |
+
gc.collect()
|
118 |
torch.cuda.empty_cache()
|
119 |
"""
|
120 |
Given a PIL image, returns:
|
|
|
139 |
or_src_img: Image,
|
140 |
sel: gr.SelectData
|
141 |
):
|
142 |
+
gc.collect()
|
143 |
+
torch.cuda.empty_cache()
|
144 |
# Convert to numpy arrays
|
145 |
src_arr = np.array(or_src_img)
|
146 |
tgt_arr = np.array(or_tgt_img)
|
|
|
197 |
|
198 |
upsampler = nn.Upsample(size=(target_res, target_res), mode='bilinear', align_corners=False)
|
199 |
|
200 |
+
gc.collect()
|
201 |
+
torch.cuda.empty_cache()
|
202 |
+
|
203 |
# βββ Build Gradio UI ββββββββββββββββββββββββββββββββββββββββββββββ
|
204 |
with gr.Blocks() as demo:
|
205 |
# Hidden states to hold features
|