Spaces:
Configuration error
Configuration error
import torch | |
from comfy import model_management | |
def string_to_dtype(s="none", mode=None):
    """Translate a user-facing dtype name into a dtype value.

    Args:
        s: Dtype name. Matched case-insensitively after stripping
            surrounding whitespace. ``None`` is accepted and yields ``None``.
        mode: Which component the dtype is for: ``"vae"``,
            ``"text_encoder"`` or ``"unet"``. Only consulted when ``s`` is
            ``"auto"`` or ``"auto (comfy)"``.

    Returns:
        * ``None`` for ``None``, "default", "as-is", "none", "auto (hf)"
          and "auto (hf/bnb)".
        * For "auto" / "auto (comfy)": whatever the matching
          ``model_management.*_dtype()`` helper returns for ``mode``.
        * ``torch.float32`` / ``torch.bfloat16`` / ``torch.float16`` for the
          fp32, bf16 and fp16 aliases.
        * ``torch.float8_e5m2`` / ``torch.float8_e4m3fn`` for names
          containing "fp8" or "float8" plus the matching format tag.
        * The normalized string "bnb8bit" or "bnb4bit", returned as-is for
          the caller to interpret.

    Raises:
        NotImplementedError: Unknown dtype name, unknown ``mode`` for an
            "auto" request, or an 8-bit float name without a known format.
        AssertionError: Name contains "bnb" but is not a supported bnb mode.
    """
    # None must be handled before normalization: calling .lower() on it
    # would raise AttributeError and the None case could never be reached.
    if s is None:
        return None

    s = s.lower().strip()

    if s in ("default", "as-is"):
        return None

    if s in ("auto", "auto (comfy)"):
        if mode == "vae":
            # Must be the dtype helper, not vae_device(): this function
            # returns dtypes, matching the other two modes below.
            return model_management.vae_dtype()
        if mode == "text_encoder":
            return model_management.text_encoder_dtype()
        if mode == "unet":
            return model_management.unet_dtype()
        raise NotImplementedError(f"Unknown dtype mode '{mode}'")

    if s in ("none", "auto (hf)", "auto (hf/bnb)"):
        return None
    if s in ("fp32", "float32", "float"):
        return torch.float32
    if s in ("bf16", "bfloat16"):
        return torch.bfloat16
    if s in ("fp16", "float16", "half"):
        return torch.float16

    # Substring matches come after the exact names above so that e.g.
    # "auto (hf/bnb)" is not mistaken for a bnb mode.
    if "fp8" in s or "float8" in s:
        if "e5m2" in s:
            return torch.float8_e5m2
        if "e4m3" in s:
            return torch.float8_e4m3fn
        raise NotImplementedError(f"Unknown 8bit dtype '{s}'")

    if "bnb" in s:
        # Explicit raise instead of `assert` so the check survives
        # `python -O`; the exception type is kept for existing callers.
        if s not in ("bnb8bit", "bnb4bit"):
            raise AssertionError(f"Unknown bnb mode '{s}'")
        return s

    raise NotImplementedError(f"Unknown dtype '{s}'")