johnbridges commited on
Commit
0a090d6
·
verified ·
1 Parent(s): a488238

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -34,7 +34,7 @@ def pick_dtype(device: str) -> torch.dtype:
34
  return torch.bfloat16 if major >= 8 else torch.float16 # Ampere+ -> bf16
35
  if device == "mps":
36
  return torch.float16
37
- return torch.float16 # CPU
38
 
39
  def move_to_device(batch, device: str):
40
  if isinstance(batch, dict):
 
34
  return torch.bfloat16 if major >= 8 else torch.float16 # Ampere+ -> bf16
35
  if device == "mps":
36
  return torch.float16
37
+ return torch.float32 # CPU
38
 
39
  def move_to_device(batch, device: str):
40
  if isinstance(batch, dict):