multimodalart HF Staff commited on
Commit
cec9c22
·
verified ·
1 Parent(s): 41f3da5

Update comfy/float.py

Browse files
Files changed (1) hide show
  1. comfy/float.py +2 -1
comfy/float.py CHANGED
@@ -57,12 +57,13 @@ def stochastic_rounding(value, dtype, seed=0):
57
  if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
58
  generator = torch.Generator()
59
  generator.manual_seed(seed)
 
60
  output = torch.empty_like(value, dtype=dtype)
61
  num_slices = max(1, (value.numel() / (4096 * 4096)))
62
  slice_size = max(1, round(value.shape[0] / num_slices))
63
  with torch.no_grad():
64
  for i in range(0, value.shape[0], slice_size):
65
- output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype))
66
  return output
67
 
68
  return value.to(dtype=dtype)
 
57
  if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
58
  generator = torch.Generator()
59
  generator.manual_seed(seed)
60
+ generator.to("cuda")
61
  output = torch.empty_like(value, dtype=dtype)
62
  num_slices = max(1, (value.numel() / (4096 * 4096)))
63
  slice_size = max(1, round(value.shape[0] / num_slices))
64
  with torch.no_grad():
65
  for i in range(0, value.shape[0], slice_size):
66
+ output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
67
  return output
68
 
69
  return value.to(dtype=dtype)