Update comfy/float.py
Browse files- comfy/float.py +2 -1
comfy/float.py
CHANGED
@@ -57,12 +57,13 @@ def stochastic_rounding(value, dtype, seed=0):
|
|
57 |
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
|
58 |
generator = torch.Generator()
|
59 |
generator.manual_seed(seed)
|
|
|
60 |
output = torch.empty_like(value, dtype=dtype)
|
61 |
num_slices = max(1, (value.numel() / (4096 * 4096)))
|
62 |
slice_size = max(1, round(value.shape[0] / num_slices))
|
63 |
with torch.no_grad():
|
64 |
for i in range(0, value.shape[0], slice_size):
|
65 |
-
output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype))
|
66 |
return output
|
67 |
|
68 |
return value.to(dtype=dtype)
|
|
|
57 |
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
|
58 |
generator = torch.Generator()
|
59 |
generator.manual_seed(seed)
|
60 |
+
generator.to("cuda")
|
61 |
output = torch.empty_like(value, dtype=dtype)
|
62 |
num_slices = max(1, (value.numel() / (4096 * 4096)))
|
63 |
slice_size = max(1, round(value.shape[0] / num_slices))
|
64 |
with torch.no_grad():
|
65 |
for i in range(0, value.shape[0], slice_size):
|
66 |
+
output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
|
67 |
return output
|
68 |
|
69 |
return value.to(dtype=dtype)
|