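# NOTE: the three lines below are residue copied from the hosting web page
# (uploader caption, upload message and what appears to be a commit id);
# they are kept as comments so this module parses as Python.
# UniversalAlgorithmic's picture
# Upload 13 files
# 97d8aaa verified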
import types
from typing import Optional, List, Union, Callable
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torchvision.models.mobilenetv2 import MobileNetV2
from torchvision.models.resnet import ResNet
from torchvision.models.efficientnet import EfficientNet
from torchvision.models.vision_transformer import VisionTransformer
def compute_policy_loss(loss_sequence, mask_sequence, rewards):
    """Return ``mean((sum_i m_i * l_i) * (sum_j r_j * m_j))``.

    ``m``, ``l`` and ``r`` are taken from ``mask_sequence``, ``loss_sequence``
    and ``rewards`` respectively. Each sum runs over the pairs produced by
    ``zip``, so it stops at the shorter of its two inputs.

    Args:
        loss_sequence: Sequence of loss tensors.
        mask_sequence: Sequence of mask tensors, broadcastable against the losses.
        rewards: Sequence of rewards (scalars or broadcastable tensors).

    Returns:
        Tensor: the scalar mean of the element-wise product of the two sums.
    """
    weighted_loss = 0
    for step_mask, step_loss in zip(mask_sequence, loss_sequence):
        weighted_loss = weighted_loss + step_mask * step_loss

    total_return = 0
    for step_reward, step_mask in zip(rewards, mask_sequence):
        total_return = total_return + step_reward * step_mask

    return torch.mean(weighted_loss * total_return)
class TPBlock(nn.Module):
    """Stack of zero-initialised residual convolution layers.

    Each layer computes ``x + branch(x)`` where ``branch`` is
    ``[permute] conv -> RMSNorm -> ReLU -> conv -> RMSNorm -> ReLU [permute]``.
    Both convolutions start with all-zero weights, so a freshly built block is
    the identity mapping.

    Args:
        depths: Number of residual layers.
        in_planes: Channel count of the input (and of the output).
        out_planes: Channel count between the two convolutions of a layer;
            defaults to ``in_planes``.
        rank: Selects ``(kernel_size, dilation, padding)`` via
            :meth:`generate_hyperparameters`.
        shape_dims: Number of non-batch dimensions of the input:
            2 -> ``Conv1d``, 3 -> ``Conv2d``, 4 -> ``Conv3d``. Other values raise
            ``KeyError``.
        channel_first: If False the input is channel-last; it is permuted to
            channel-first around the convolutions and permuted back afterwards.
        dtype: Parameter dtype.
        device: Parameter device. ``None`` keeps the previous ``"cuda"`` placement
            when CUDA is available and falls back to ``"cpu"`` otherwise.
    """

    class Permute(nn.Module):
        """Permute tensor dimensions into a fixed order."""

        def __init__(self, *dims):
            super().__init__()
            self.dims = dims

        def forward(self, x):
            return x.permute(*self.dims)

    class RMSNorm(nn.Module):
        """RMS normalisation over dim 1 with a learned per-channel scale.

        LlamaRMSNorm is equivalent to T5LayerNorm.
        """

        __constants__ = ["eps"]
        eps: float

        def __init__(self, hidden_size, eps: float = 1e-6, device=None, dtype=None):
            factory_kwargs = {"device": device, "dtype": dtype}
            super().__init__()
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(hidden_size, **factory_kwargs))

        def forward(self, hidden_states):
            input_dtype = hidden_states.dtype
            # Normalise in float32, then cast the normalised values back to the
            # input dtype before applying the scale.
            hidden_states = hidden_states.to(torch.float32)
            variance = hidden_states.pow(2).mean(dim=1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
            # Reshape the [C] scale to [1, C, 1, ...] so it broadcasts over dim 1.
            weight = self.weight.view(1, -1, *[1] * (hidden_states.ndim - 2))
            return weight * hidden_states.to(input_dtype)

        def extra_repr(self):
            return f"{self.weight.shape[0]}, eps={self.eps}"

    # shape_dims -> (conv class, channel-last -> channel-first permutation,
    #                channel-first -> channel-last permutation)
    _CONV_MAP = {
        2: (nn.Conv1d, (0, 2, 1), (0, 2, 1)),
        3: (nn.Conv2d, (0, 3, 1, 2), (0, 2, 3, 1)),
        4: (nn.Conv3d, (0, 4, 1, 2, 3), (0, 2, 3, 4, 1)),
    }

    def __init__(self, depths: int, in_planes: int, out_planes: Optional[int] = None, rank=1, shape_dims=3,
                 channel_first=True, dtype=torch.float32, device=None) -> None:
        super().__init__()
        out_planes = in_planes if out_planes is None else out_planes
        device = self._resolve_device(device)
        self.layers = nn.ModuleList(
            [self._make_layer(in_planes, out_planes, rank, shape_dims, channel_first, dtype, device)
             for _ in range(depths)]
        )

    def forward(self, x: Tensor) -> Tensor:
        for layer in self.layers:
            x = x + layer(x)
        return x

    @staticmethod
    def _resolve_device(device):
        """Return ``device``, or ``"cuda"``/``"cpu"`` (by availability) when it is None."""
        if device is not None:
            return device
        return "cuda" if torch.cuda.is_available() else "cpu"

    def _make_layer(self, in_planes: int, out_planes: Optional[int] = None, rank=1, shape_dims=3,
                    channel_first=True, dtype=torch.float32, device=None) -> nn.Sequential:
        """Build one residual branch with zero-initialised convolutions.

        The module order inside the returned ``nn.Sequential`` is
        ``(pre_permute, conv1, norm1, relu, conv2, norm2, relu, post_permute)``.
        """
        out_planes = in_planes if out_planes is None else out_planes
        device = self._resolve_device(device)
        Conv, pre_dims, post_dims = self._CONV_MAP[shape_dims]
        kernel_size, dilation, padding = self.generate_hyperparameters(rank)
        pre_permute = nn.Identity() if channel_first else self.Permute(*pre_dims)
        post_permute = nn.Identity() if channel_first else self.Permute(*post_dims)
        conv_kwargs = dict(padding=padding, dilation=dilation, bias=False, dtype=dtype, device=device)
        conv1 = Conv(in_planes, out_planes, kernel_size, **conv_kwargs)
        nn.init.zeros_(conv1.weight)
        bn1 = self.RMSNorm(out_planes, dtype=dtype, device=device)
        # A single parameter-free ReLU instance is used at both positions; the
        # in-place op acts on the freshly allocated RMSNorm outputs.
        relu = nn.ReLU(inplace=True)
        conv2 = Conv(out_planes, in_planes, kernel_size, **conv_kwargs)
        nn.init.zeros_(conv2.weight)
        bn2 = self.RMSNorm(in_planes, dtype=dtype, device=device)
        return nn.Sequential(pre_permute, conv1, bn1, relu, conv2, bn2, relu, post_permute)

    @staticmethod
    def generate_hyperparameters(rank: int):
        """Return the ``rank``-th ``(kernel_size, dilation, padding)`` triple.

        Triples are enumerated by increasing padded (effective) kernel size
        ``(kernel_size - 1) * dilation + 1`` and, for equal padded size, by
        increasing kernel size. Kernel sizes are odd, and
        ``padding = dilation * (kernel_size - 1) // 2``, so a stride-1
        convolution preserves the spatial size.

        Args:
            rank: 1-based position in the enumeration. Values ``<= 1`` return the
                first triple ``(1, 1, 0)``.

        Returns:
            Tuple[int, int, int]: ``(kernel_size, dilation, padding)``; e.g.
            rank 1 -> (1, 1, 0), 2 -> (3, 1, 1), 3 -> (3, 2, 2), 4 -> (5, 1, 2).
        """
        pairs = [(1, 1, 0)]  # Start with the smallest possible (1x1 conv)
        padded_kernel_size = 3
        while len(pairs) < rank:
            for kernel_size in range(3, padded_kernel_size + 1, 2):
                if (padded_kernel_size - 1) % (kernel_size - 1) == 0:
                    dilation = (padded_kernel_size - 1) // (kernel_size - 1)
                    padding = dilation * (kernel_size - 1) // 2
                    pairs.append((kernel_size, dilation, padding))
                    if len(pairs) >= rank:
                        break
            # Move to the next odd padded kernel size
            padded_kernel_size += 2
        return pairs[-1]
class ResNetConfig:
    """Closure factories that patch a torchvision ``ResNet`` for TRP training.

    Every ``gen_*`` method is a ``@staticmethod`` returning a closure. The
    ``self`` argument of ``gen_shared_head``/``gen_logits`` is the *model* being
    patched, not a ``ResNetConfig`` instance. Subclasses override
    ``gen_shared_head`` and ``gen_forward`` for other backbones and inherit
    ``gen_logits``, ``gen_mask`` and ``gen_criterion``.
    """

    @staticmethod
    def gen_shared_head(self):
        def func(hidden_states):
            """
            Args:
                hidden_states (Tensor): Hidden States tensor of shape [B, C, H, W].
            Returns:
                logits (Tensor): Logits tensor of shape [B, num_classes].
            """
            x = self.avgpool(hidden_states)
            x = torch.flatten(x, 1)
            return self.fc(x)
        return func

    @staticmethod
    def gen_logits(self, shared_head):
        def func(hidden_states):
            """
            Args:
                hidden_states (Tensor): Backbone output in whatever layout
                    ``shared_head`` and ``self.trp_blocks`` accept.
            Returns:
                logits_sequence (List[Tensor]): ``1 + len(self.trp_blocks)`` logits
                tensors: the head applied to ``hidden_states`` directly, then to
                the output of each TRP block (each block receives the same
                ``hidden_states``, not the previous block's output).
            """
            logits_sequence = [shared_head(hidden_states)]
            for layer in self.trp_blocks:
                logits_sequence.append(shared_head(layer(hidden_states)))
            return logits_sequence
        return func

    @staticmethod
    def gen_mask(label_smoothing=0.0, top_k=1):
        def func(logits_sequence, labels):
            """
            Args:
                logits_sequence (List[Tensor]): List of [B, num_classes] logits.
                labels (Tensor): Class indices of shape [B] or soft/one-hot
                    targets of shape [B, C].
            Returns:
                mask_sequence (List[Tensor]): ``len(logits_sequence) + 1`` float
                masks of shape [B]. Element 0 is all ones; element i + 1 is
                element i times 1.0 where the label is in the top-k of
                ``logits_sequence[i]`` (0.0 otherwise).
            """
            # The mask always needs class indices, so reduce [B, C] targets by
            # their shape. ``label_smoothing`` is kept in the signature for
            # compatibility but no longer decides this (deciding by it crashed on
            # [B] labels with smoothing and mis-broadcast [B, C] labels without).
            if labels.ndim > 1:
                labels = torch.argmax(labels, dim=1)
            mask_sequence = [torch.ones_like(labels, dtype=torch.float32)]
            with torch.no_grad():
                for logits in logits_sequence:
                    topk_indices = torch.topk(logits, top_k, dim=-1).indices
                    hit = torch.eq(topk_indices, labels[:, None]).any(dim=-1).to(torch.float32)
                    mask_sequence.append(mask_sequence[-1] * hit)
            return mask_sequence
        return func

    @staticmethod
    def gen_criterion(label_smoothing=0.0):
        def func(logits_sequence, labels):
            """
            Args:
                logits_sequence (List[Tensor]): List of [B, num_classes] logits.
                labels (Tensor): Class indices of shape [B] or soft/one-hot
                    targets of shape [B, C].
            Returns:
                loss_sequence (List[Tensor]): Per-sample cross-entropy losses of
                shape [B], one per logits tensor.
            """
            # With label smoothing, [B, C] targets are reduced to indices and the
            # smoothing is applied by cross_entropy; [B] labels are used as-is.
            # Without smoothing, [B, C] targets are passed through as soft targets.
            if label_smoothing > 0.0 and labels.ndim > 1:
                labels = torch.argmax(labels, dim=1)
            return [
                F.cross_entropy(logits, labels, reduction="none", label_smoothing=label_smoothing)
                for logits in logits_sequence
            ]
        return func

    @staticmethod
    def gen_forward(rewards, label_smoothing=0.0, top_k=1):
        def func(self, x: Tensor, targets=None) -> Union[Tensor, tuple]:
            """ResNet forward; in training mode returns ``(logits, trp_loss)``.

            Raises:
                ValueError: in training mode when ``targets`` is None.
            """
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.maxpool(x)
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            hidden_states = self.layer4(x)
            x = self.avgpool(hidden_states)
            x = torch.flatten(x, 1)
            logits = self.fc(x)
            if not self.training:
                return logits
            if targets is None:
                raise ValueError("TRP forward requires `targets` in training mode.")
            shared_head = ResNetConfig.gen_shared_head(self)
            compute_logits = ResNetConfig.gen_logits(self, shared_head)
            compute_mask = ResNetConfig.gen_mask(label_smoothing, top_k)
            compute_loss = ResNetConfig.gen_criterion(label_smoothing)
            logits_sequence = compute_logits(hidden_states)
            mask_sequence = compute_mask(logits_sequence, targets)
            loss_sequence = compute_loss(logits_sequence, targets)
            loss = compute_policy_loss(loss_sequence, mask_sequence, rewards)
            return logits, loss
        return func
class MobileNetV2Config(ResNetConfig):
    """TRP patching for torchvision ``MobileNetV2``.

    Overrides the head (global average pool + ``self.classifier``) and the
    forward pass; ``gen_logits``, ``gen_mask`` and ``gen_criterion`` are inherited.
    """

    @staticmethod
    def gen_shared_head(self):
        def func(hidden_states):
            """
            Args:
                hidden_states (Tensor): Hidden States tensor of shape [B, C, H, W].
            Returns:
                logits (Tensor): Logits tensor of shape [B, num_classes].
            """
            x = nn.functional.adaptive_avg_pool2d(hidden_states, (1, 1))
            x = torch.flatten(x, 1)
            return self.classifier(x)
        return func

    @staticmethod
    def gen_forward(rewards, label_smoothing=0.0, top_k=1):
        def func(self, x: Tensor, targets=None) -> Union[Tensor, tuple]:
            """MobileNetV2 forward; in training mode returns ``(logits, trp_loss)``.

            Raises:
                ValueError: in training mode when ``targets`` is None.
            """
            hidden_states = self.features(x)
            # Cannot use "squeeze" as batch-size can be 1
            x = nn.functional.adaptive_avg_pool2d(hidden_states, (1, 1))
            x = torch.flatten(x, 1)
            logits = self.classifier(x)
            if not self.training:
                return logits
            if targets is None:
                raise ValueError("TRP forward requires `targets` in training mode.")
            shared_head = MobileNetV2Config.gen_shared_head(self)
            compute_logits = MobileNetV2Config.gen_logits(self, shared_head)
            compute_mask = MobileNetV2Config.gen_mask(label_smoothing, top_k)
            compute_loss = MobileNetV2Config.gen_criterion(label_smoothing)
            logits_sequence = compute_logits(hidden_states)
            mask_sequence = compute_mask(logits_sequence, targets)
            loss_sequence = compute_loss(logits_sequence, targets)
            loss = compute_policy_loss(loss_sequence, mask_sequence, rewards)
            return logits, loss
        return func
class EfficientNetConfig(ResNetConfig):
    """TRP patching for torchvision ``EfficientNet``.

    Overrides the head (``self.avgpool`` + ``self.classifier``) and the forward
    pass; ``gen_logits``, ``gen_mask`` and ``gen_criterion`` are inherited.
    """

    @staticmethod
    def gen_shared_head(self):
        def func(hidden_states):
            """
            Args:
                hidden_states (Tensor): Hidden States tensor of shape [B, C, H, W].
            Returns:
                logits (Tensor): Logits tensor of shape [B, num_classes].
            """
            x = self.avgpool(hidden_states)
            x = torch.flatten(x, 1)
            return self.classifier(x)
        return func

    @staticmethod
    def gen_forward(rewards, label_smoothing=0.0, top_k=1):
        def func(self, x: Tensor, targets=None) -> Union[Tensor, tuple]:
            """EfficientNet forward; in training mode returns ``(logits, trp_loss)``.

            Raises:
                ValueError: in training mode when ``targets`` is None.
            """
            hidden_states = self.features(x)
            x = self.avgpool(hidden_states)
            x = torch.flatten(x, 1)
            logits = self.classifier(x)
            if not self.training:
                return logits
            if targets is None:
                raise ValueError("TRP forward requires `targets` in training mode.")
            shared_head = EfficientNetConfig.gen_shared_head(self)
            compute_logits = EfficientNetConfig.gen_logits(self, shared_head)
            compute_mask = EfficientNetConfig.gen_mask(label_smoothing, top_k)
            compute_loss = EfficientNetConfig.gen_criterion(label_smoothing)
            logits_sequence = compute_logits(hidden_states)
            mask_sequence = compute_mask(logits_sequence, targets)
            loss_sequence = compute_loss(logits_sequence, targets)
            loss = compute_policy_loss(loss_sequence, mask_sequence, rewards)
            return logits, loss
        return func
class VisionTransformerConfig(ResNetConfig):
    """TRP patching for torchvision ``VisionTransformer``.

    The head reads the token at sequence position 0 (the class token prepended
    in the forward pass) and applies ``self.heads``. ``gen_logits``,
    ``gen_mask`` and ``gen_criterion`` are inherited.
    """

    @staticmethod
    def gen_shared_head(self):
        def func(hidden_states):
            """
            Args:
                hidden_states (Tensor): Encoder output of shape [B, L, hidden_dim];
                    position 0 along L is the class token.
            Returns:
                logits (Tensor): Logits tensor of shape [B, num_classes].
            """
            x = hidden_states[:, 0]
            return self.heads(x)
        return func

    @staticmethod
    def gen_forward(rewards, label_smoothing=0.0, top_k=1):
        def func(self, images: Tensor, targets=None) -> Union[Tensor, tuple]:
            """ViT forward; in training mode returns ``(logits, trp_loss)``.

            Raises:
                ValueError: in training mode when ``targets`` is None.
            """
            x = self._process_input(images)
            n = x.shape[0]
            # Prepend the learned class token to every sequence in the batch.
            batch_class_token = self.class_token.expand(n, -1, -1)
            x = torch.cat([batch_class_token, x], dim=1)
            hidden_states = self.encoder(x)
            x = hidden_states[:, 0]
            logits = self.heads(x)
            if not self.training:
                return logits
            if targets is None:
                raise ValueError("TRP forward requires `targets` in training mode.")
            shared_head = VisionTransformerConfig.gen_shared_head(self)
            compute_logits = VisionTransformerConfig.gen_logits(self, shared_head)
            compute_mask = VisionTransformerConfig.gen_mask(label_smoothing, top_k)
            compute_loss = VisionTransformerConfig.gen_criterion(label_smoothing)
            logits_sequence = compute_logits(hidden_states)
            mask_sequence = compute_mask(logits_sequence, targets)
            loss_sequence = compute_loss(logits_sequence, targets)
            loss = compute_policy_loss(loss_sequence, mask_sequence, rewards)
            return logits, loss
        return func
def apply_trp(model, depths: List[int], in_planes: int, out_planes: int, rewards, **kwargs):
    """Attach TRP blocks to a supported torchvision classifier and patch its forward.

    For each entry ``d`` of ``depths`` (at position ``k``) a ``TPBlock`` with
    ``depths=d`` and ``rank=k`` is created and stored in ``model.trp_blocks``.
    ``model.forward`` is then replaced with the matching config's
    ``gen_forward``, which returns ``(logits, loss)`` in training mode.

    Args:
        model: A torchvision ``ResNet``, ``MobileNetV2``, ``EfficientNet`` or
            ``VisionTransformer`` instance.
        depths: Number of residual layers for each TRP block.
        in_planes: Channel/hidden size of the backbone features.
        out_planes: Inner channel size of each TRP layer.
        rewards: Per-step rewards passed to ``compute_policy_loss``.
        **kwargs: ``label_smoothing`` (default 0.0) and ``top_k`` (default 1).

    Returns:
        The same ``model`` object, modified in place. Unsupported model types are
        returned unchanged and a ``UserWarning`` is emitted.
    """
    import warnings

    label_smoothing = kwargs.get("label_smoothing", 0.0)
    top_k = kwargs.get("top_k", 1)
    # (backbone class, config class, display name, extra TPBlock kwargs)
    dispatch = (
        (ResNet, ResNetConfig, "ResNet", {}),
        (MobileNetV2, MobileNetV2Config, "MobileNetV2", {}),
        (EfficientNet, EfficientNetConfig, "EfficientNet", {}),
        # ViT features are [B, L, hidden]: channel-last sequences -> Conv1d.
        (VisionTransformer, VisionTransformerConfig, "VisionTransformer", {"shape_dims": 2, "channel_first": False}),
    )
    for model_cls, config, display_name, block_kwargs in dispatch:
        if not isinstance(model, model_cls):
            continue
        print(f"✅ Applying TRP to {display_name} for Image Classification...")
        # NOTE(review): ranks start at 0 here; generate_hyperparameters maps
        # rank 0 and rank 1 to the same 1x1 conv. Left unchanged to keep
        # existing architectures/checkpoints compatible.
        model.trp_blocks = torch.nn.ModuleList([
            TPBlock(depths=d, in_planes=in_planes, out_planes=out_planes, rank=k, **block_kwargs)
            for k, d in enumerate(depths)
        ])
        model.forward = types.MethodType(
            config.gen_forward(rewards, label_smoothing=label_smoothing, top_k=top_k), model
        )
        return model
    warnings.warn(
        f"apply_trp: unsupported model type {type(model).__name__}; model returned unchanged."
    )
    return model