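"""Optional quantization backends (bitsandbytes / quanto / optimum-quanto).

Each backend is imported if available; otherwise inert placeholder classes
are defined in its place. SUPPORT_QUANT records whether at least one
backend was found.
"""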

from functools import cache

SUPPORT_QUANT = False
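
# bitsandbytes 8-bit / 4-bit linear layers; when bitsandbytes is missing,
# define inert nn.Linear subclasses so the isinstance checks against
# QuantLinears below never match ordinary layers.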
try:
    from bitsandbytes.nn import LinearNF4, Linear8bitLt, LinearFP4

    SUPPORT_QUANT = True
except Exception:
    import torch.nn as nn

    class LinearNF4(nn.Linear):
        pass

    class Linear8bitLt(nn.Linear):
        pass

    class LinearFP4(nn.Linear):
        pass
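

# quanto quantized layers, with the same placeholder fallback.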
try:
    from quanto.nn import QLinear, QConv2d, QLayerNorm

    SUPPORT_QUANT = True
except Exception:
    import torch.nn as nn

    class QLinear(nn.Linear):
        pass

    class QConv2d(nn.Conv2d):
        pass

    class QLayerNorm(nn.LayerNorm):
        pass
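

# optimum-quanto layers, aliased so they can coexist with quanto's classes.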
try:
    from optimum.quanto.nn import (
        QLinear as QLinearOpt,
        QConv2d as QConv2dOpt,
        QLayerNorm as QLayerNormOpt,
    )

    SUPPORT_QUANT = True
except Exception:
    import torch.nn as nn

    class QLinearOpt(nn.Linear):
        pass

    class QConv2dOpt(nn.Conv2d):
        pass

    class QLayerNormOpt(nn.LayerNorm):
        pass


from ..logging import logger
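

# Every quantized layer type that triggers force-bypass mode. The placeholder
# fallbacks above are never instantiated, so isinstance checks against this
# tuple stay correct even when a backend is missing.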
QuantLinears = (
    Linear8bitLt,
    LinearFP4,
    LinearNF4,
    QLinear,
    QConv2d,
    QLayerNorm,
    QLinearOpt,
    QConv2dOpt,
    QLayerNormOpt,
)
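

# functools.cache makes these one-shot warnings: the first call logs and
# caches the None return value, so later calls are silent.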
@cache
def log_bypass():
    return logger.warning(
        "Using bnb/quanto/optimum-quanto with LyCORIS will enable force-bypass mode."
    )


@cache
def log_suspect():
    return logger.warning(
        "Non-native Linear detected but bypass_mode is not set. "
        "Automatically using force-bypass mode to avoid possible issues. "
        "Please set bypass_mode=False explicitly if there are no quantized layers."
    )