"""Modified from https://github.com/kijai/ComfyUI-MochiWrapper """ import torch import torch.nn as nn def autocast_model_forward(cls, origin_dtype, *inputs, **kwargs): weight_dtype = cls.weight.dtype cls.to(origin_dtype) # Convert all inputs to the original dtype inputs = [input.to(origin_dtype) for input in inputs] out = cls.original_forward(*inputs, **kwargs) cls.to(weight_dtype) return out def replace_parameters_by_name(module, name_keywords, device): from torch import nn for name, param in list(module.named_parameters(recurse=False)): if any(keyword in name for keyword in name_keywords): if isinstance(param, nn.Parameter): tensor = param.data delattr(module, name) setattr(module, name, tensor.to(device=device)) for child_name, child_module in module.named_children(): replace_parameters_by_name(child_module, name_keywords, device) def convert_model_weight_to_float8(model, exclude_module_name=['embed_tokens']): for name, module in model.named_modules(): flag = False for _exclude_module_name in exclude_module_name: if _exclude_module_name in name: flag = True if flag: continue for param_name, param in module.named_parameters(): flag = False for _exclude_module_name in exclude_module_name: if _exclude_module_name in param_name: flag = True if flag: continue param.data = param.data.to(torch.float8_e4m3fn) def convert_weight_dtype_wrapper(module, origin_dtype): for name, module in module.named_modules(): if name == "" or "embed_tokens" in name: continue original_forward = module.forward if hasattr(module, "weight") and module.weight is not None: setattr(module, "original_forward", original_forward) setattr( module, "forward", lambda *inputs, m=module, **kwargs: autocast_model_forward(m, origin_dtype, *inputs, **kwargs) )