Spaces:
Running
on
Zero
Running
on
Zero
fix f16
Browse files
app.py
CHANGED
@@ -215,6 +215,11 @@ def run_classifier(image: Image.Image, threshold):
|
|
215 |
img = image.convert('RGBA')
|
216 |
tensor = transform(img).unsqueeze(0)
|
217 |
|
|
|
|
|
|
|
|
|
|
|
218 |
with torch.no_grad():
|
219 |
probits = model(tensor)[0] # type: torch.Tensor
|
220 |
values, indices = probits.cpu().topk(250)
|
@@ -238,6 +243,13 @@ def cam_inference(img, threshold, alpha, evt: gr.SelectData):
|
|
238 |
target_tag_index = tags[evt.value]
|
239 |
tensor = transform(img).unsqueeze(0)
|
240 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
241 |
gradients = {}
|
242 |
activations = {}
|
243 |
|
@@ -339,7 +351,10 @@ def process_images(images, threshold):
|
|
339 |
all_results = []
|
340 |
with torch.no_grad():
|
341 |
for batch, filenames in dataloader:
|
342 |
-
|
|
|
|
|
|
|
343 |
probabilities = model(batch)
|
344 |
for i, prob in enumerate(probabilities):
|
345 |
indices = torch.where(prob > threshold)[0]
|
|
|
215 |
img = image.convert('RGBA')
|
216 |
tensor = transform(img).unsqueeze(0)
|
217 |
|
218 |
+
if torch.cuda.is_available():
|
219 |
+
tensor = tensor.to(device, dtype=torch.float16)
|
220 |
+
else:
|
221 |
+
tensor = tensor.to(device)
|
222 |
+
|
223 |
with torch.no_grad():
|
224 |
probits = model(tensor)[0] # type: torch.Tensor
|
225 |
values, indices = probits.cpu().topk(250)
|
|
|
243 |
target_tag_index = tags[evt.value]
|
244 |
tensor = transform(img).unsqueeze(0)
|
245 |
|
246 |
+
if torch.cuda.is_available():
|
247 |
+
tensor = tensor.to(device, dtype=torch.float16)
|
248 |
+
else:
|
249 |
+
tensor = tensor.to(device)
|
250 |
+
|
251 |
+
tensor.requires_grad_()
|
252 |
+
|
253 |
gradients = {}
|
254 |
activations = {}
|
255 |
|
|
|
351 |
all_results = []
|
352 |
with torch.no_grad():
|
353 |
for batch, filenames in dataloader:
|
354 |
+
if torch.cuda.is_available():
|
355 |
+
batch = batch.to(device, dtype=torch.float16)
|
356 |
+
else:
|
357 |
+
batch = batch.to(device)
|
358 |
probabilities = model(batch)
|
359 |
for i, prob in enumerate(probabilities):
|
360 |
indices = torch.where(prob > threshold)[0]
|