cdnuts commited on
Commit
6fb195f
·
verified ·
1 Parent(s): 22d958d
Files changed (1) hide show
  1. app.py +16 -1
app.py CHANGED
@@ -215,6 +215,11 @@ def run_classifier(image: Image.Image, threshold):
215
  img = image.convert('RGBA')
216
  tensor = transform(img).unsqueeze(0)
217
 
 
 
 
 
 
218
  with torch.no_grad():
219
  probits = model(tensor)[0] # type: torch.Tensor
220
  values, indices = probits.cpu().topk(250)
@@ -238,6 +243,13 @@ def cam_inference(img, threshold, alpha, evt: gr.SelectData):
238
  target_tag_index = tags[evt.value]
239
  tensor = transform(img).unsqueeze(0)
240
 
 
 
 
 
 
 
 
241
  gradients = {}
242
  activations = {}
243
 
@@ -339,7 +351,10 @@ def process_images(images, threshold):
339
  all_results = []
340
  with torch.no_grad():
341
  for batch, filenames in dataloader:
342
- batch = batch.to(device)
 
 
 
343
  probabilities = model(batch)
344
  for i, prob in enumerate(probabilities):
345
  indices = torch.where(prob > threshold)[0]
 
215
  img = image.convert('RGBA')
216
  tensor = transform(img).unsqueeze(0)
217
 
218
+ if torch.cuda.is_available():
219
+ tensor = tensor.to(device, dtype=torch.float16)
220
+ else:
221
+ tensor = tensor.to(device)
222
+
223
  with torch.no_grad():
224
  probits = model(tensor)[0] # type: torch.Tensor
225
  values, indices = probits.cpu().topk(250)
 
243
  target_tag_index = tags[evt.value]
244
  tensor = transform(img).unsqueeze(0)
245
 
246
+ if torch.cuda.is_available():
247
+ tensor = tensor.to(device, dtype=torch.float16)
248
+ else:
249
+ tensor = tensor.to(device)
250
+
251
+ tensor.requires_grad_()
252
+
253
  gradients = {}
254
  activations = {}
255
 
 
351
  all_results = []
352
  with torch.no_grad():
353
  for batch, filenames in dataloader:
354
+ if torch.cuda.is_available():
355
+ batch = batch.to(device, dtype=torch.float16)
356
+ else:
357
+ batch = batch.to(device)
358
  probabilities = model(batch)
359
  for i, prob in enumerate(probabilities):
360
  indices = torch.where(prob > threshold)[0]