# Uploaded by UniversalAlgorithmic ("Upload 13 files", revision 97d8aaa, verified).
import types
from typing import Optional, List, Union, Callable
from collections import OrderedDict
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torchvision.models.mobilenetv2 import MobileNetV2
from torchvision.models.resnet import ResNet
from torchvision.models.efficientnet import EfficientNet
from torchvision.models.vision_transformer import VisionTransformer
from torchvision.models.segmentation.fcn import FCN
from torchvision.models.segmentation.deeplabv3 import DeepLabV3
def compute_policy_loss(loss_sequence, mask_sequence, rewards):
    """Reduce masked per-step losses and masked rewards to a single scalar.

    Args:
        loss_sequence: Iterable of loss tensors, one per step.
        mask_sequence: Iterable of mask tensors; the i-th mask gates both the
            i-th loss and the i-th reward.
        rewards: Iterable of per-step reward values.

    Returns:
        ``torch.mean`` of (sum of masked losses) * (sum of masked rewards).

    Note:
        ``zip`` stops at the shortest input, so surplus trailing entries of any
        argument are ignored.
    """
    masked_loss_total = 0
    for step_mask, step_loss in zip(mask_sequence, loss_sequence):
        masked_loss_total = masked_loss_total + step_mask * step_loss

    masked_reward_total = 0
    for step_reward, step_mask in zip(rewards, mask_sequence):
        masked_reward_total = masked_reward_total + step_reward * step_mask

    return torch.mean(masked_loss_total * masked_reward_total)
class TPBlock(nn.Module):
    """Residual stack of zero-initialised ``conv -> RMSNorm -> ReLU`` layers.

    Every layer is ``[permute] -> conv -> RMSNorm -> ReLU -> conv -> RMSNorm ->
    ReLU -> [permute]`` and is applied as ``x = x + layer(x)``. Both
    convolutions start with all-zero weights, so a freshly constructed block
    returns its input unchanged.

    Args:
        depths: Number of residual layers to stack.
        in_planes: Channel count of the block input (and output).
        out_planes: Channel count between the two convolutions of a layer.
            Defaults to ``in_planes``.
        rank: Index forwarded to :meth:`generate_hyperparameters` to pick the
            kernel size, dilation and padding shared by all layers.
        shape_dims: 2, 3 or 4, selecting ``Conv1d``, ``Conv2d`` or ``Conv3d``.
        channel_first: When ``False`` the channel axis is assumed to be last
            and is permuted to position 1 before the convolutions and back
            afterwards.
        dtype: Parameter dtype.
        device: Device the parameters are created on. ``None`` (default) picks
            ``"cuda"`` when CUDA is available and ``"cpu"`` otherwise, so the
            block can also be built on CPU-only machines.
    """

    class Permute(nn.Module):
        """Module wrapper around ``Tensor.permute`` with a fixed axis order."""

        def __init__(self, *dims):
            super().__init__()
            self.dims = dims

        def forward(self, x):
            return x.permute(*self.dims)

    class RMSNorm(nn.Module):
        """RMS normalisation over dimension 1 with a learnable per-channel scale."""

        __constants__ = ["eps"]
        eps: float

        def __init__(self, hidden_size, eps: float = 1e-6, device=None, dtype=None):
            factory_kwargs = {"device": device, "dtype": dtype}
            super().__init__()
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(hidden_size, **factory_kwargs))

        def forward(self, hidden_states):
            input_dtype = hidden_states.dtype
            # Statistics are computed in float32 and cast back afterwards.
            hidden_states = hidden_states.to(torch.float32)
            variance = hidden_states.pow(2).mean(dim=1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
            # Reshape the scale to (1, C, 1, ...) so it broadcasts over dim 1.
            weight = self.weight.view(1, -1, *[1] * (hidden_states.ndim - 2))
            return weight * hidden_states.to(input_dtype)

        def extra_repr(self):
            return f"{self.weight.shape[0]}, eps={self.eps}"

    # shape_dims -> (conv class, permutation to channel-first, permutation back)
    _CONV_LAYOUTS = {
        2: (nn.Conv1d, (0, 2, 1), (0, 2, 1)),
        3: (nn.Conv2d, (0, 3, 1, 2), (0, 2, 3, 1)),
        4: (nn.Conv3d, (0, 4, 1, 2, 3), (0, 2, 3, 4, 1)),
    }

    def __init__(
        self,
        depths: int,
        in_planes: int,
        out_planes: Optional[int] = None,
        rank=1,
        shape_dims=3,
        channel_first=True,
        dtype=torch.float32,
        device=None,
    ) -> None:
        super().__init__()
        out_planes = in_planes if out_planes is None else out_planes
        self.layers = torch.nn.ModuleList(
            [
                self._make_layer(in_planes, out_planes, rank, shape_dims, channel_first, dtype, device)
                for _ in range(depths)
            ]
        )

    def forward(self, x: Tensor) -> Tensor:
        """Apply every layer as a residual update and return the result."""
        for layer in self.layers:
            x = x + layer(x)
        return x

    def _make_layer(
        self,
        in_planes: int,
        out_planes: Optional[int] = None,
        rank=1,
        shape_dims=3,
        channel_first=True,
        dtype=torch.float32,
        device=None,
    ) -> nn.Sequential:
        """Build one ``conv -> norm -> relu -> conv -> norm -> relu`` layer.

        Raises:
            KeyError: If ``shape_dims`` is not 2, 3 or 4.
        """
        out_planes = in_planes if out_planes is None else out_planes
        if device is None:
            # Previously hard-coded to "cuda"; fall back to CPU when unavailable.
            device = "cuda" if torch.cuda.is_available() else "cpu"

        Conv, pre_dims, post_dims = self._CONV_LAYOUTS[shape_dims]
        kernel_size, dilation, padding = self.generate_hyperparameters(rank)

        pre_permute = nn.Identity() if channel_first else TPBlock.Permute(*pre_dims)
        post_permute = nn.Identity() if channel_first else TPBlock.Permute(*post_dims)

        conv1 = Conv(in_planes, out_planes, kernel_size, padding=padding, dilation=dilation, bias=False, dtype=dtype, device=device)
        nn.init.zeros_(conv1.weight)
        bn1 = TPBlock.RMSNorm(out_planes, dtype=dtype, device=device)
        # One in-place ReLU instance is reused at both activation positions.
        relu = nn.ReLU(inplace=True)
        conv2 = Conv(out_planes, in_planes, kernel_size, padding=padding, dilation=dilation, bias=False, dtype=dtype, device=device)
        nn.init.zeros_(conv2.weight)
        bn2 = TPBlock.RMSNorm(in_planes, dtype=dtype, device=device)

        return torch.nn.Sequential(pre_permute, conv1, bn1, relu, conv2, bn2, relu, post_permute)

    @staticmethod
    def generate_hyperparameters(rank: int):
        """Return the ``rank``-th (kernel_size, dilation, padding) triple.

        Triples are enumerated starting from ``(1, 1, 0)``, then by increasing
        padded kernel size ``(kernel_size - 1) * dilation + 1`` (odd values
        3, 5, 7, ...), and within one padded size by increasing odd kernel
        size whose ``kernel_size - 1`` divides ``padded_kernel_size - 1``.

        Args:
            rank: 1-based position in that enumeration. Values below 1 yield
                the first triple ``(1, 1, 0)``.

        Returns:
            Tuple[int, int, int]: ``(kernel_size, dilation, padding)`` with
            ``padding = dilation * (kernel_size - 1) // 2``.
        """
        pairs = [(1, 1, 0)]  # Start with the smallest possible kernel.
        padded_kernel_size = 3
        while len(pairs) < rank:
            for kernel_size in range(3, padded_kernel_size + 1, 2):
                if (padded_kernel_size - 1) % (kernel_size - 1) == 0:
                    dilation = (padded_kernel_size - 1) // (kernel_size - 1)
                    padding = dilation * (kernel_size - 1) // 2
                    pairs.append((kernel_size, dilation, padding))
                    if len(pairs) >= rank:
                        break
            # Move to the next odd padded kernel size.
            padded_kernel_size += 2
        return pairs[-1]
# ResNet for Image Classification
class ResNetConfig:
    """Factories for the closures that make up the TRP-patched ResNet forward."""

    @staticmethod
    def gen_shared_head(self):
        """Return a closure mapping hidden states to logits via ``avgpool`` and ``fc``."""

        def head(hidden_states):
            """
            Args:
                hidden_states (Tensor): Feature map fed to ``self.avgpool``
                    (documented upstream as [B, C, H, W]).
            Returns:
                logits (Tensor): Output of ``self.fc`` on the flattened pooled features.
            """
            pooled = torch.flatten(self.avgpool(hidden_states), 1)
            return self.fc(pooled)

        return head

    @staticmethod
    def gen_logits(self, shared_head):
        """Return a closure producing one logits tensor per TRP branch."""

        def collect(hidden_states):
            """
            Args:
                hidden_states (Tensor): Hidden states fed to the head and to every TRP block.
            Returns:
                logits_sequence (List[Tensor]): Logits of the unmodified hidden
                    states first, then one entry per block in ``self.trp_blocks``,
                    each block being applied to the original hidden states.
            """
            logits_sequence = [shared_head(hidden_states)]
            logits_sequence.extend(shared_head(block(hidden_states)) for block in self.trp_blocks)
            return logits_sequence

        return collect

    @staticmethod
    def gen_mask(label_smoothing=0.0, top_k=1):
        """Return a closure computing cumulative top-k correctness masks."""

        def build_masks(logits_sequence, labels):
            """
            Args:
                logits_sequence (List[Tensor]): List of logits tensors.
                labels (Tensor): Target labels; reduced with ``argmax(dim=1)``
                    when ``label_smoothing > 0``.
            Returns:
                mask_sequence (List[Tensor]): float32 masks, one more than the
                    number of logits. The first is all ones; each later one is
                    the previous mask times the top-k hit indicator of the
                    corresponding logits.
            """
            if label_smoothing > 0.0:
                labels = torch.argmax(labels, dim=1)
            running = torch.ones_like(labels, dtype=torch.float32, device=labels.device)
            mask_sequence = [running]
            for logits in logits_sequence:
                with torch.no_grad():
                    _, topk_indices = torch.topk(logits, top_k, dim=-1)
                    hit = torch.eq(topk_indices, labels[:, None]).any(dim=-1).to(torch.float32)
                    running = running * hit
                    mask_sequence.append(running)
            return mask_sequence

        return build_masks

    @staticmethod
    def gen_criterion(label_smoothing=0.0):
        """Return a closure computing unreduced cross-entropy for every logits tensor."""

        def build_losses(logits_sequence, labels):
            """
            Args:
                logits_sequence (List[Tensor]): List of logits tensors.
                labels (Tensor): Target labels; reduced with ``argmax(dim=1)``
                    when ``label_smoothing > 0``.
            Returns:
                loss_sequence (List[Tensor]): One ``reduction="none"``
                    cross-entropy tensor per entry of ``logits_sequence``.
            """
            if label_smoothing > 0.0:
                labels = torch.argmax(labels, dim=1)
            return [
                F.cross_entropy(logits, labels, reduction="none", label_smoothing=label_smoothing)
                for logits in logits_sequence
            ]

        return build_losses

    @staticmethod
    def gen_forward(rewards, label_smoothing=0.0, top_k=1):
        """Return a replacement ``forward``; in training mode it returns ``(logits, loss)``."""

        def forward(self, x: Tensor, targets=None) -> Tensor:
            stem = self.maxpool(self.relu(self.bn1(self.conv1(x))))
            hidden_states = self.layer4(self.layer3(self.layer2(self.layer1(stem))))

            logits = self.fc(torch.flatten(self.avgpool(hidden_states), 1))
            if not self.training:
                return logits

            shared_head = ResNetConfig.gen_shared_head(self)
            compute_logits = ResNetConfig.gen_logits(self, shared_head)
            compute_mask = ResNetConfig.gen_mask(label_smoothing, top_k)
            compute_loss = ResNetConfig.gen_criterion(label_smoothing)

            logits_sequence = compute_logits(hidden_states)
            mask_sequence = compute_mask(logits_sequence, targets)
            loss_sequence = compute_loss(logits_sequence, targets)
            loss = compute_policy_loss(loss_sequence, mask_sequence, rewards)
            return logits, loss

        return forward
# MobileNetV2 for Image Classification
class MobileNetV2Config(ResNetConfig):
    """TRP closures for MobileNetV2; mask, criterion and logits builders are inherited."""

    @staticmethod
    def gen_shared_head(self):
        """Return a closure mapping hidden states to logits via global pooling and ``classifier``."""

        def head(hidden_states):
            """
            Args:
                hidden_states (Tensor): Feature map to pool (documented
                    upstream as [B, C, H, W]).
            Returns:
                logits (Tensor): Output of ``self.classifier`` on the flattened pooled features.
            """
            pooled = F.adaptive_avg_pool2d(hidden_states, (1, 1))
            return self.classifier(torch.flatten(pooled, 1))

        return head

    @staticmethod
    def gen_forward(rewards, label_smoothing=0.0, top_k=1):
        """Return a replacement ``forward``; in training mode it returns ``(logits, loss)``."""

        def forward(self, x: Tensor, targets=None) -> Tensor:
            hidden_states = self.features(x)
            # Pool + flatten instead of "squeeze", because the batch size can be 1.
            pooled = F.adaptive_avg_pool2d(hidden_states, (1, 1))
            logits = self.classifier(torch.flatten(pooled, 1))
            if not self.training:
                return logits

            shared_head = MobileNetV2Config.gen_shared_head(self)
            compute_logits = MobileNetV2Config.gen_logits(self, shared_head)
            compute_mask = MobileNetV2Config.gen_mask(label_smoothing, top_k)
            compute_loss = MobileNetV2Config.gen_criterion(label_smoothing)

            logits_sequence = compute_logits(hidden_states)
            mask_sequence = compute_mask(logits_sequence, targets)
            loss_sequence = compute_loss(logits_sequence, targets)
            loss = compute_policy_loss(loss_sequence, mask_sequence, rewards)
            return logits, loss

        return forward
# EfficientNet for Image Classification
class EfficientNetConfig(ResNetConfig):
    """TRP closures for EfficientNet; mask, criterion and logits builders are inherited."""

    @staticmethod
    def gen_shared_head(self):
        """Return a closure mapping hidden states to logits via ``avgpool`` and ``classifier``."""

        def head(hidden_states):
            """
            Args:
                hidden_states (Tensor): Feature map fed to ``self.avgpool``
                    (documented upstream as [B, C, H, W]).
            Returns:
                logits (Tensor): Output of ``self.classifier`` on the flattened pooled features.
            """
            pooled = torch.flatten(self.avgpool(hidden_states), 1)
            return self.classifier(pooled)

        return head

    @staticmethod
    def gen_forward(rewards, label_smoothing=0.0, top_k=1):
        """Return a replacement ``forward``; in training mode it returns ``(logits, loss)``."""

        def forward(self, x: Tensor, targets=None) -> Tensor:
            hidden_states = self.features(x)
            logits = self.classifier(torch.flatten(self.avgpool(hidden_states), 1))
            if not self.training:
                return logits

            shared_head = EfficientNetConfig.gen_shared_head(self)
            compute_logits = EfficientNetConfig.gen_logits(self, shared_head)
            compute_mask = EfficientNetConfig.gen_mask(label_smoothing, top_k)
            compute_loss = EfficientNetConfig.gen_criterion(label_smoothing)

            logits_sequence = compute_logits(hidden_states)
            mask_sequence = compute_mask(logits_sequence, targets)
            loss_sequence = compute_loss(logits_sequence, targets)
            loss = compute_policy_loss(loss_sequence, mask_sequence, rewards)
            return logits, loss

        return forward
# VisionTransformer for Image Classification
class VisionTransformerConfig(ResNetConfig):
    """TRP closures for VisionTransformer; mask, criterion and logits builders are inherited."""

    @staticmethod
    def gen_shared_head(self):
        """Return a closure classifying the token at sequence position 0 with ``self.heads``."""

        def head(hidden_states):
            """
            Args:
                hidden_states (Tensor): Encoder output; index 0 along
                    dimension 1 is the slot where the forward pass places the
                    class token.
            Returns:
                logits (Tensor): Output of ``self.heads`` on that token.
            """
            class_token = hidden_states[:, 0]
            return self.heads(class_token)

        return head

    @staticmethod
    def gen_forward(rewards, label_smoothing=0.0, top_k=1):
        """Return a replacement ``forward``; in training mode it returns ``(logits, loss)``."""

        def forward(self, images: Tensor, targets=None):
            tokens = self._process_input(images)
            batch_size = tokens.shape[0]
            # Prepend one copy of the class token to every sequence in the batch.
            batch_class_token = self.class_token.expand(batch_size, -1, -1)
            tokens = torch.cat([batch_class_token, tokens], dim=1)

            hidden_states = self.encoder(tokens)
            logits = self.heads(hidden_states[:, 0])
            if not self.training:
                return logits

            shared_head = VisionTransformerConfig.gen_shared_head(self)
            compute_logits = VisionTransformerConfig.gen_logits(self, shared_head)
            compute_mask = VisionTransformerConfig.gen_mask(label_smoothing, top_k)
            compute_loss = VisionTransformerConfig.gen_criterion(label_smoothing)

            logits_sequence = compute_logits(hidden_states)
            mask_sequence = compute_mask(logits_sequence, targets)
            loss_sequence = compute_loss(logits_sequence, targets)
            loss = compute_policy_loss(loss_sequence, mask_sequence, rewards)
            return logits, loss

        return forward
# FCN for Semantic Segmentation
class FCNConfig(ResNetConfig):
    """Factories for the closures that make up the TRP-patched segmentation forward.

    Two parallel branches are supported: the main ("out") branch using
    ``self.classifier`` / ``self.out_trp_blocks`` and the auxiliary ("aux")
    branch using ``self.aux_classifier`` / ``self.aux_trp_blocks``.
    """

    @staticmethod
    def _gen_upsampling_head(classifier, input_shape):
        """Return a closure applying ``classifier`` and resizing bilinearly to ``input_shape``."""

        def func(features):
            x = classifier(features)
            return F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)

        return func

    @staticmethod
    def gen_out_shared_head(self, input_shape):
        """Return the main-branch head.

        The returned closure maps backbone features to the output of
        ``self.classifier``, bilinearly resized to ``input_shape``.
        """
        return FCNConfig._gen_upsampling_head(self.classifier, input_shape)

    @staticmethod
    def gen_aux_shared_head(self, input_shape):
        """Return the auxiliary-branch head.

        The returned closure maps backbone features to the output of
        ``self.aux_classifier``, bilinearly resized to ``input_shape``.
        """
        return FCNConfig._gen_upsampling_head(self.aux_classifier, input_shape)

    @staticmethod
    def gen_out_logits(self, shared_head):
        """Return a closure producing one logits tensor per block of ``self.out_trp_blocks``."""

        def func(hidden_states):
            """
            Args:
                hidden_states (Tensor): Features fed to the head and to every TRP block.
            Returns:
                logits_sequence (List[Tensor]): Logits of the unmodified
                    features first, then one entry per block, each block being
                    applied to the original features.
            """
            logits_sequence = [shared_head(hidden_states)]
            for layer in self.out_trp_blocks:
                logits_sequence.append(shared_head(layer(hidden_states)))
            return logits_sequence

        return func

    @staticmethod
    def gen_aux_logits(self, shared_head):
        """Return a closure producing one logits tensor per block of ``self.aux_trp_blocks``."""

        def func(hidden_states):
            """
            Args:
                hidden_states (Tensor): Features fed to the head and to every TRP block.
            Returns:
                logits_sequence (List[Tensor]): Logits of the unmodified
                    features first, then one entry per block, each block being
                    applied to the original features.
            """
            logits_sequence = [shared_head(hidden_states)]
            for layer in self.aux_trp_blocks:
                logits_sequence.append(shared_head(layer(hidden_states)))
            return logits_sequence

        return func

    @staticmethod
    def gen_mask(label_smoothing=0.0, top_k=1):
        """Return a closure computing cumulative per-pixel top-k correctness masks."""

        def func(logits_sequence, labels):
            """
            Args:
                logits_sequence (List[Tensor]): Logits tensors with the class
                    axis at dimension 1 (documented upstream as [B, C, H, W]).
                labels (Tensor): Target labels (documented upstream as
                    [B, H, W]); reduced with ``argmax(dim=1)`` when
                    ``label_smoothing > 0``.
            Returns:
                mask_sequence (List[Tensor]): float32 masks, one more than the
                    number of logits. The first is all ones; each later one is
                    the previous mask times the top-k hit indicator of the
                    corresponding logits.
            """
            labels = torch.argmax(labels, dim=1) if label_smoothing > 0.0 else labels
            mask_sequence = [torch.ones_like(labels, dtype=torch.float32, device=labels.device)]
            for logits in logits_sequence:
                with torch.no_grad():
                    _, topk_indices = torch.topk(logits, top_k, dim=1)
                    mask = torch.eq(topk_indices, labels[:, None, :, :]).any(dim=1).to(torch.float32)
                    mask_sequence.append(mask_sequence[-1] * mask)
            return mask_sequence

        return func

    @staticmethod
    def gen_criterion(label_smoothing=0.0):
        """Return a closure computing unreduced cross-entropy, ignoring label 255."""

        def func(logits_sequence, labels):
            """
            Args:
                logits_sequence (List[Tensor]): List of logits tensors.
                labels (Tensor): Target labels; reduced with ``argmax(dim=1)``
                    when ``label_smoothing > 0``.
            Returns:
                loss_sequence (List[Tensor]): One ``reduction="none"``
                    cross-entropy tensor (``ignore_index=255``) per entry of
                    ``logits_sequence``.
            """
            labels = torch.argmax(labels, dim=1) if label_smoothing > 0.0 else labels
            return [
                F.cross_entropy(logits, labels, ignore_index=255, reduction="none", label_smoothing=label_smoothing)
                for logits in logits_sequence
            ]

        return func

    @staticmethod
    def gen_forward(rewards, label_smoothing=0.0, top_k=1):
        """Return a replacement ``forward`` for a segmentation model.

        In eval mode the closure returns the ``OrderedDict`` of resized
        predictions ("out", plus "aux" when an auxiliary classifier exists).
        In training mode it returns ``(result, loss)`` where ``loss`` is the
        main-branch policy loss plus 0.5 times the auxiliary-branch policy
        loss; the auxiliary term is skipped when ``self.aux_classifier`` is
        ``None``.
        """

        def func(self, images: Tensor, targets=None):
            input_shape = images.shape[-2:]
            # contract: features is a dict of tensors
            features = self.backbone(images)
            has_aux = self.aux_classifier is not None

            result = OrderedDict()
            x = self.classifier(features["out"])
            result["out"] = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
            if has_aux:
                x = self.aux_classifier(features["aux"])
                result["aux"] = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)

            if not self.training:
                return result

            torch._assert(targets is not None, "targets should not be none when in training mode")
            compute_mask = FCNConfig.gen_mask(label_smoothing, top_k)
            compute_loss = FCNConfig.gen_criterion(label_smoothing)

            out_shared_head = FCNConfig.gen_out_shared_head(self, input_shape)
            compute_out_logits = FCNConfig.gen_out_logits(self, out_shared_head)
            out_logits_sequence = compute_out_logits(features["out"])
            out_mask_sequence = compute_mask(out_logits_sequence, targets)
            out_loss_sequence = compute_loss(out_logits_sequence, targets)
            loss = compute_policy_loss(out_loss_sequence, out_mask_sequence, rewards)

            # Mirror the inference path: only touch the auxiliary branch when
            # the model actually has an auxiliary classifier.
            if has_aux:
                aux_shared_head = FCNConfig.gen_aux_shared_head(self, input_shape)
                compute_aux_logits = FCNConfig.gen_aux_logits(self, aux_shared_head)
                aux_logits_sequence = compute_aux_logits(features["aux"])
                aux_mask_sequence = compute_mask(aux_logits_sequence, targets)
                aux_loss_sequence = compute_loss(aux_logits_sequence, targets)
                aux_loss = compute_policy_loss(aux_loss_sequence, aux_mask_sequence, rewards)
                loss = loss + 0.5 * aux_loss

            return result, loss

        return func
# DeepLabV3Config for Semantic Segmentation
class DeepLabV3Config(FCNConfig):
    """TRP closures for DeepLabV3; every generator is inherited from FCNConfig unchanged."""
def apply_trp(model, depths: List[int], in_planes: int, out_planes: int, rewards, **kwargs):
    """Attach TRP blocks to ``model`` and replace its ``forward`` in place.

    Args:
        model: A torchvision ResNet, MobileNetV2, EfficientNet,
            VisionTransformer, FCN or DeepLabV3 instance. Any other type is
            returned untouched.
        depths: One entry per TRP block; entry ``d`` at index ``k`` builds
            ``TPBlock(depths=d, rank=k, ...)``.
        in_planes: Input channel count of the blocks for classification
            models. Segmentation models ignore it and use 2048 (main branch)
            and 1024 (auxiliary branch).
        out_planes: Intermediate channel count of every block.
        rewards: Per-step rewards forwarded to the generated forward function.
        **kwargs: ``label_smoothing`` (float, default 0.0) is used for the
            classification models; segmentation models always use 0.0.

    Returns:
        The same ``model`` object, patched when its type is supported.
    """
    label_smoothing = kwargs.get("label_smoothing", 0.0)

    def make_blocks(planes, **block_kwargs):
        # One TPBlock per entry of `depths`; the index doubles as the rank.
        return torch.nn.ModuleList(
            [
                TPBlock(depths=d, in_planes=planes, out_planes=out_planes, rank=k, **block_kwargs)
                for k, d in enumerate(depths)
            ]
        )

    def patch_forward(config, smoothing):
        # Bind the generated function as a method of this particular instance.
        model.forward = types.MethodType(
            config.gen_forward(rewards, label_smoothing=smoothing, top_k=1), model
        )

    if isinstance(model, ResNet):
        print("✅ Applying TRP to ResNet for Image Classification...")
        model.trp_blocks = make_blocks(in_planes)
        patch_forward(ResNetConfig, label_smoothing)
    elif isinstance(model, MobileNetV2):
        print("✅ Applying TRP to MobileNetV2 for Image Classification...")
        model.trp_blocks = make_blocks(in_planes)
        patch_forward(MobileNetV2Config, label_smoothing)
    elif isinstance(model, EfficientNet):
        print("✅ Applying TRP to EfficientNet for Image Classification...")
        model.trp_blocks = make_blocks(in_planes)
        patch_forward(EfficientNetConfig, label_smoothing)
    elif isinstance(model, VisionTransformer):
        print("✅ Applying TRP to VisionTransformer for Image Classification...")
        # Token sequences are channel-last, so the blocks use 1-D convolutions.
        model.trp_blocks = make_blocks(in_planes, shape_dims=2, channel_first=False)
        patch_forward(VisionTransformerConfig, label_smoothing)
    elif isinstance(model, FCN):
        print("✅ Applying TRP to FCN for Semantic Segmentation...")
        model.out_trp_blocks = make_blocks(2048)
        model.aux_trp_blocks = make_blocks(1024)
        patch_forward(FCNConfig, 0.0)
    elif isinstance(model, DeepLabV3):
        print("✅ Applying TRP to DeepLabV3 for Semantic Segmentation...")
        # Each block receives its own depth `d`; passing the whole `depths`
        # list here used to fail inside TPBlock's range(depths).
        model.out_trp_blocks = make_blocks(2048)
        model.aux_trp_blocks = make_blocks(1024)
        patch_forward(DeepLabV3Config, 0.0)
    return model