# ZIP/models/clip_ebc/convnext.py
from torch import nn, Tensor
import open_clip
from peft import get_peft_model, LoraConfig
from ..utils import ConvRefine, ConvAdapter, ConvUpsample, _get_norm_layer, _get_activation

convnext_names_and_weights = {
    "convnext_base": ["laion400m_s13b_b51k"],  # 107.49M
    "convnext_base_w": ["laion2b_s13b_b82k", "laion2b_s13b_b82k_augreg", "laion_aesthetic_s13b_b82k"],  # 107.75M
    "convnext_base_w_320": ["laion_aesthetic_s13b_b82k", "laion_aesthetic_s13b_b82k_augreg"],  # 107.75M
    "convnext_large_d": ["laion2b_s26b_b102k_augreg"],  # 217.46M
    "convnext_large_d_320": ["laion2b_s29b_b131k_ft", "laion2b_s29b_b131k_ft_soup"],  # 217.46M
    "convnext_xxlarge": ["laion2b_s34b_b82k_augreg", "laion2b_s34b_b82k_augreg_rewind", "laion2b_s34b_b82k_augreg_soup"],  # 896.88M
}

refiner_channels = {
    "convnext_base": 1024,
    "convnext_base_w": 1024,
    "convnext_base_w_320": 1024,
    "convnext_large_d": 1536,
    "convnext_large_d_320": 1536,
    "convnext_xxlarge": 3072,
}

refiner_groups = {
    "convnext_base": 1,
    "convnext_base_w": 1,
    "convnext_base_w_320": 1,
    "convnext_large_d": refiner_channels["convnext_large_d"] // 512,  # 3
    "convnext_large_d_320": refiner_channels["convnext_large_d_320"] // 512,  # 3
    "convnext_xxlarge": refiner_channels["convnext_xxlarge"] // 512,  # 6
}


class ConvNeXt(nn.Module):
    """CLIP ConvNeXt visual backbone with optional per-stage adapters and a
    refiner head chosen by `block_size`."""

    def __init__(
        self,
        model_name: str,
        weight_name: str,
        block_size: int = 16,
        adapter: bool = False,
        adapter_reduction: int = 4,
        norm: str = "none",
        act: str = "none"
    ) -> None:
        super().__init__()
        assert model_name in convnext_names_and_weights, f"model_name should be one of {list(convnext_names_and_weights.keys())}, but got {model_name}."
        assert weight_name in convnext_names_and_weights[model_name], f"weight_name should be one of {convnext_names_and_weights[model_name]}, but got {weight_name}."
        assert block_size in [32, 16, 8], f"block_size should be one of [32, 16, 8], but got {block_size}."
        self.model_name, self.weight_name = model_name, weight_name
        self.block_size = block_size

        # Build the architecture only; pretrained weights are not loaded here.
        # model = open_clip.create_model_from_pretrained(model_name, weight_name, return_transform=False).visual
        model = open_clip.create_model(model_name=model_name, pretrained=False, load_weights=False).visual
        self.adapter = adapter
        if adapter:
            # Adapter fine-tuning: freeze every backbone parameter; only the
            # adapters (added below) and the refiner remain trainable.
            self.adapter_reduction = adapter_reduction
            for param in model.parameters():
                param.requires_grad = False

        self.stem = model.trunk.stem
        self.depth = len(model.trunk.stages)
        for idx, stage in enumerate(model.trunk.stages):
            setattr(self, f"stage{idx}", stage)
            if adapter:
                setattr(self, f"adapter{idx}", ConvAdapter(
                    in_channels=stage.blocks[-1].mlp.fc2.out_features,
                    bottleneck_channels=stage.blocks[-1].mlp.fc2.out_features // adapter_reduction,
                ) if idx < self.depth - 1 else nn.Identity())  # no adapter after the last stage

        # in_features: channels coming out of the trunk; out_features: CLIP embedding dim.
        if self.model_name in ["convnext_base", "convnext_base_w", "convnext_base_w_320", "convnext_xxlarge"]:
            self.in_features, self.out_features = model.head.proj.in_features, model.head.proj.out_features
        else:  # "convnext_large_d", "convnext_large_d_320"
            self.in_features, self.out_features = model.head.mlp.fc1.in_features, model.head.mlp.fc2.out_features
if norm == "bn":
norm_layer = nn.BatchNorm2d
elif norm == "ln":
norm_layer = nn.LayerNorm
else:
norm_layer = _get_norm_layer(model)
if act == "relu":
activation = nn.ReLU(inplace=True)
elif act == "gelu":
activation = nn.GELU()
else:
activation = _get_activation(model)
if block_size == 32:
self.refiner = ConvRefine(
in_channels=self.in_features,
out_channels=self.in_features,
norm_layer=norm_layer,
activation=activation,
groups=refiner_groups[self.model_name],
)
elif block_size == 16:
self.refiner = ConvUpsample(
in_channels=self.in_features,
out_channels=self.in_features,
norm_layer=norm_layer,
activation=activation,
groups=refiner_groups[self.model_name],
)
else: # block_size == 8
self.refiner = nn.Sequential(
ConvUpsample(
in_channels=self.in_features,
out_channels=self.in_features,
norm_layer=norm_layer,
activation=activation,
groups=refiner_groups[self.model_name],
),
ConvUpsample(
in_channels=self.in_features,
out_channels=self.in_features,
norm_layer=norm_layer,
activation=activation,
groups=refiner_groups[self.model_name],
),
)

    def train(self, mode: bool = True):
        if self.adapter and mode:
            # Adapter training: keep the frozen backbone in eval mode; only the
            # adapters and the refiner run in train mode.
            self.training = True
            self.stem.eval()
            for idx in range(self.depth):
                getattr(self, f"stage{idx}").eval()
                getattr(self, f"adapter{idx}").train()
            self.refiner.train()
        else:
            # Full training or evaluation: fall back to the default behaviour.
            super().train(mode)
        return self

    def forward(self, x: Tensor) -> Tensor:
        x = self.stem(x)
        for idx in range(self.depth):
            x = getattr(self, f"stage{idx}")(x)
            if self.adapter:
                x = getattr(self, f"adapter{idx}")(x)
        x = self.refiner(x)
        return x
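
# Rough shape note (an assumption based on the block_size mapping above: ConvRefine keeps
# the spatial resolution, each ConvUpsample doubles it). The trunk has stride 32, so for a
# 224x224 input the refiner sees 7x7 features; block_size=16 yields 14x14 and block_size=8
# yields 28x28, i.e. one feature vector per block_size x block_size patch.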


def _convnext(
    model_name: str,
    weight_name: str,
    block_size: int = 16,
    adapter: bool = False,
    adapter_reduction: int = 4,
    lora: bool = False,
    lora_rank: int = 16,
    lora_alpha: float = 32.0,
    lora_dropout: float = 0.1,
    norm: str = "none",
    act: str = "none"
) -> nn.Module:
    """Build a ConvNeXt backbone; when `lora` is True the result is a PEFT-wrapped
    model, hence the broader nn.Module return type."""
    assert not (lora and adapter), "LoRA and adapters cannot be used together."
    model = ConvNeXt(
        model_name=model_name,
        weight_name=weight_name,
        block_size=block_size,
        adapter=adapter,
        adapter_reduction=adapter_reduction,
        norm=norm,
        act=act
    )
    if lora:
        # Apply LoRA to every linear/conv layer of the backbone; the refiner is
        # excluded here and trained in full instead.
        target_modules = []
        for name, module in model.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)) and "refiner" not in name:
                target_modules.append(name)
        lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            bias="none",
            target_modules=target_modules,
        )
        model = get_peft_model(model, lora_config)
        # get_peft_model freezes the base model, so re-enable gradients for the refiner.
        for name, module in model.named_modules():
            if "refiner" in name:
                module.requires_grad_(True)
    return model
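

# Usage sketch (not part of the module; the import path below assumes the package layout
# implied by the relative imports above, and pretrained weights are loaded elsewhere):
#
#     import torch
#     from models.clip_ebc.convnext import _convnext
#
#     backbone = _convnext(
#         model_name="convnext_base",
#         weight_name="laion400m_s13b_b51k",
#         block_size=16,
#         lora=True,  # LoRA on the CLIP trunk, full training of the refiner
#     )
#     feats = backbone(torch.randn(1, 3, 224, 224))
#     # Expected: a (1, 1024, 14, 14) feature map, i.e. one vector per 16x16 block,
#     # assuming in_features == 1024 for convnext_base (see refiner_channels) and
#     # that ConvUpsample doubles the spatial resolution.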