yyfz233 commited on
Commit
0a1ff26
·
1 Parent(s): 00143c1

Change to bfloat16

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -301,7 +301,7 @@ def run_model(target_dir, model) -> dict:
301
 
302
  # 3. Infer
303
  print("Running model inference...")
304
- dtype = torch.float16
305
  with torch.no_grad():
306
  with torch.amp.autocast('cuda', dtype=dtype):
307
  predictions = model(imgs[None]) # Add batch dimension
 
301
 
302
  # 3. Infer
303
  print("Running model inference...")
304
+ dtype = torch.bfloat16
305
  with torch.no_grad():
306
  with torch.amp.autocast('cuda', dtype=dtype):
307
  predictions = model(imgs[None]) # Add batch dimension