# Shamelessly taken from Forge

import nodes
import folder_paths
import comfy.ops
import comfy.sd

import torch
import bitsandbytes as bnb

from bitsandbytes.nn.modules import Params4bit, QuantState
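

# Thin wrapper around bnb.matmul_4bit: runs the 4-bit matmul with the weight's
# stored quant_state and casts the result back to the input's dtype and device.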
def functional_linear_4bits(x, weight, bias):
    out = bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state)
    out = out.to(x)
    return out
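

# Deep-copies a bitsandbytes QuantState onto the target device. When the state is
# nested (compress_statistics=True), the inner state2 and offset move with it.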
def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantState:
    if state is None:
        return None

    device = device or state.absmax.device

    state2 = (
        QuantState(
            absmax=state.state2.absmax.to(device),
            shape=state.state2.shape,
            code=state.state2.code.to(device),
            blocksize=state.state2.blocksize,
            quant_type=state.state2.quant_type,
            dtype=state.state2.dtype,
        )
        if state.nested
        else None
    )

    return QuantState(
        absmax=state.absmax.to(device),
        shape=state.shape,
        code=state.code.to(device),
        blocksize=state.blocksize,
        quant_type=state.quant_type,
        dtype=state.dtype,
        offset=state.offset.to(device) if state.nested else None,
        state2=state2,
    )
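

# Params4bit subclass that makes .to() safe for model-management moves: the first
# transfer to CUDA quantizes the raw weights, while later moves copy the existing
# quant state along with the data instead of re-quantizing.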
class ForgeParams4bit(Params4bit):
    def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None and device.type == "cuda" and not self.bnb_quantized:
            return self._quantize(device)
        else:
            n = ForgeParams4bit(
                torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking),
                requires_grad=self.requires_grad,
                quant_state=copy_quant_state(self.quant_state, device),
                blocksize=self.blocksize,
                compress_statistics=self.compress_statistics,
                quant_type=self.quant_type,
                quant_storage=self.quant_storage,
                bnb_quantized=self.bnb_quantized,
                module=self.module,
            )
            self.module.quant_state = n.quant_state
            self.data = n.data
            self.quant_state = n.quant_state
            return n
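

# Base module that round-trips bitsandbytes quant state through state_dict: it saves
# the packed quant-state tensors next to the weight and, on load, either restores a
# prequantized weight or wraps a full-precision one so it quantizes on its first
# move to CUDA.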
class ForgeLoader4Bit(torch.nn.Module):
    def __init__(self, *, device, dtype, quant_type, **kwargs):
        super().__init__()
        self.dummy = torch.nn.Parameter(torch.empty(1, device=device, dtype=dtype))
        self.weight = None
        self.quant_state = None
        self.bias = None
        self.quant_type = quant_type

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        quant_state = getattr(self.weight, "quant_state", None)
        if quant_state is not None:
            for k, v in quant_state.as_dict(packed=True).items():
                destination[prefix + "weight." + k] = v if keep_vars else v.detach()

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        quant_state_keys = {k[len(prefix + "weight."):] for k in state_dict.keys() if k.startswith(prefix + "weight.")}

        if any('bitsandbytes' in k for k in quant_state_keys):
            # Prequantized checkpoint: rebuild the 4-bit weight from the stored quant state.
            quant_state_dict = {k: state_dict[prefix + "weight." + k] for k in quant_state_keys}
            self.weight = ForgeParams4bit.from_prequantized(
                data=state_dict[prefix + 'weight'],
                quantized_stats=quant_state_dict,
                requires_grad=False,
                device=self.dummy.device,
                module=self,
            )
            self.quant_state = self.weight.quant_state

            if prefix + 'bias' in state_dict:
                self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))

            del self.dummy
        elif hasattr(self, 'dummy'):
            # Full-precision checkpoint: wrap the weight so it quantizes on its first move to CUDA.
            if prefix + 'weight' in state_dict:
                if current_nf4_version == 'v2':
                    print('ForgeLoader4Bit: v2')
                    self.weight = ForgeParams4bit(
                        state_dict[prefix + 'weight'].to(self.dummy),
                        requires_grad=False,
                        compress_statistics=False,
                        blocksize=64,
                        quant_type=self.quant_type,
                        quant_storage=torch.uint8,
                        module=self,
                    )
                else:
                    self.weight = ForgeParams4bit(
                        state_dict[prefix + 'weight'].to(self.dummy),
                        requires_grad=False,
                        compress_statistics=True,
                        quant_type=self.quant_type,
                        quant_storage=torch.uint8,
                        module=self,
                    )
                self.quant_state = self.weight.quant_state

            if prefix + 'bias' in state_dict:
                self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))

            del self.dummy
        else:
            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)


# Module-level knobs read by the 4-bit Linear layer below. current_bnb_dtype selects
# the bnb quant_type for freshly quantized weights (prequantized checkpoints carry
# their own); current_nf4_version switches the quantization settings used for
# bnb-nf4-v2 checkpoints.
current_device = None
current_dtype = None
current_manual_cast_enabled = False
current_bnb_dtype = None
current_nf4_version = 'v1'
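

# Custom operations table handed to comfy.sd: subclasses comfy.ops.manual_cast so
# every Linear created during model load is a 4-bit ForgeLoader4Bit layer.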
class OPS(comfy.ops.manual_cast):
    class Linear(ForgeLoader4Bit):
        def __init__(self, *args, device=None, dtype=None, **kwargs):
            super().__init__(device=device, dtype=dtype, quant_type=current_bnb_dtype)
            self.parameters_manual_cast = current_manual_cast_enabled

        def forward(self, x):
            self.weight.quant_state = self.quant_state

            if self.bias is not None and self.bias.dtype != x.dtype:
                # Maybe this can also be done for all non-bnb ops, since the cost is very low.
                # It only runs once, and most linear layers do not have a bias.
                self.bias.data = self.bias.data.to(x.dtype)

            if not self.parameters_manual_cast:
                return functional_linear_4bits(x, self.weight, self.bias)
            elif not self.weight.bnb_quantized:
                assert x.device.type == 'cuda', 'BNB must use CUDA as the computation device!'
                layer_original_device = self.weight.device
                self.weight = self.weight._quantize(x.device)
                bias = self.bias.to(x.device) if self.bias is not None else None
                out = functional_linear_4bits(x, self.weight, bias)
                self.weight = self.weight.to(layer_original_device)
                return out
            else:
                # Carried over from Forge: weights_manual_cast and main_stream_worker live in
                # Forge's backend and are not defined in ComfyUI. This branch is unreachable
                # here because current_manual_cast_enabled is never set to True.
                weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True, skip_bias_dtype=True)
                with main_stream_worker(weight, bias, signal):
                    return functional_linear_4bits(x, weight, bias)
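

# ComfyUI loader node: loads an NF4-quantized checkpoint, injecting OPS so all
# linear layers are built as 4-bit bitsandbytes layers.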
class CheckpointLoaderNF4:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"ckpt_name": (folder_paths.get_filename_list("checkpoints"),)}}

    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
    FUNCTION = "load_checkpoint"
    CATEGORY = "loaders"

    def load_checkpoint(self, ckpt_name):
        global current_nf4_version
        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
        if 'bnb-nf4-v2' in ckpt_name:
            current_nf4_version = 'v2'
        out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"), model_options={"custom_operations": OPS})
        return out[:3]


NODE_CLASS_MAPPINGS = {
    "CheckpointLoaderNF4": CheckpointLoaderNF4,
}
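
# Minimal usage sketch, assuming a running ComfyUI and an NF4 checkpoint in the
# checkpoints folder (the filename below is illustrative, not part of this repo):
#
#   loader = CheckpointLoaderNF4()
#   model, clip, vae = loader.load_checkpoint("flux1-dev-bnb-nf4-v2.safetensors")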