|
import types |
|
from typing import List, Callable |
|
|
|
import torch |
|
from torch import nn, Tensor |
|
from torch.nn import functional as F |
|
from torchvision.models.resnet import BasicBlock |
|
|
|
|
|
def trp_criterion(trp_blocks: nn.ModuleList, shared_head: Callable, criterion: Callable, lambdas: List[float], hidden_state: Tensor, logits: Tensor, targets: Tensor, loss_normalization=False):
    """Combine the main-branch loss with chained replica-branch losses.

    Starting from ``hidden_state``, replica ``k`` applies ``trp_blocks[k]`` to
    the embedding produced by replica ``k - 1`` (replicas are chained) and
    scores it with ``shared_head``. Each stage's criterion mask, scaled by
    ``lambdas[k]``, is added to a per-sample weight vector (``returns``) that
    finally multiplies the summed losses.

    Args:
        trp_blocks: Replica blocks; indexed up to ``len(lambdas) - 1``.
        shared_head: Callable mapping an embedding to logits.
        criterion: ``criterion(logits, targets[, mask]) -> (loss, mask)``.
        lambdas: One weight per replica stage.
        hidden_state: Input embedding for the first replica block.
        logits: Main-branch logits.
        targets: Targets forwarded unchanged to ``criterion``.
        loss_normalization: If True, rescale the final loss by the detached
            coefficient ``exp(base_loss - total_loss)``.

    Returns:
        Tensor: The scalar training loss.
    """
    losses, rewards = criterion(logits, targets)
    returns = torch.ones_like(rewards, dtype=torch.float32, device=rewards.device)

    if loss_normalization:
        coeff = torch.mean(losses).detach()

    embed = hidden_state
    for k, w in enumerate(lambdas):
        embed = trp_blocks[k](embed)
        prediction = shared_head(embed)
        # Weight with the mask produced by the *previous* stage before it is
        # refreshed by this stage's criterion call.
        returns = returns + w * rewards
        replica_losses, rewards = criterion(prediction, targets, rewards)
        losses = losses + replica_losses

    loss = torch.mean(losses * returns)

    if loss_normalization:
        with torch.no_grad():
            # exp(a) / exp(b) evaluated as exp(a - b): identical value, but it
            # does not overflow to inf / inf = NaN once a loss exceeds ~88
            # (the float32 exp overflow threshold).
            coeff = torch.exp(coeff - loss.detach())
        loss = coeff * loss

    return loss
|
|
|
|
|
class TPBlock(nn.Module):
    """A stack of ``depths`` torchvision ``BasicBlock``s used as one TRP replica stage.

    ``BasicBlock`` is constructed with its default stride and no downsample
    branch, so ``inplanes`` must equal ``planes`` for the residual addition
    to be shape-compatible.

    Initialization: each block's final BatchNorm scale (``bn2.weight``) is
    zeroed, so the residual branch contributes only ``bn2.bias`` (initially
    zero). The stack therefore starts as ``relu(x)``, i.e. the identity on
    non-negative inputs. Gradients still reach every convolution, because
    ``bn2.weight`` itself receives gradient on the first step.

    Zeroing the convolution weights instead would leave the conv/bn1
    parameters with permanently zero gradient: conv1 emits 0, so conv2's
    input is 0, and conv2's zero weight blocks the gradient back to conv1.
    """

    def __init__(self, depths: int, inplanes: int, planes: int):
        super(TPBlock, self).__init__()

        blocks = [BasicBlock(inplanes=inplanes, planes=planes) for _ in range(depths)]
        self.blocks = nn.Sequential(*blocks)

        for block in self.blocks:
            # Same scheme as torchvision's ``zero_init_residual`` for BasicBlock.
            nn.init.zeros_(block.bn2.weight)

    def forward(self, x):
        """Apply the residual blocks in sequence."""
        return self.blocks(x)
|
|
|
|
|
class ResNetConfig:
    """Factories for the TRP criterion, shared head and patched forward of a torchvision-style ResNet."""

    @staticmethod
    def gen_criterion(label_smoothing=0.0, top_k=1):
        """Build a masked cross-entropy criterion.

        Args:
            label_smoothing (float): Forwarded to ``F.cross_entropy``.
            top_k (int): A sample keeps its mask entry only if its label is
                among the ``top_k`` highest-scoring classes.

        Returns:
            Callable ``func(input, target, mask=None) -> (loss, mask)``.
        """

        def func(input, target, mask=None):
            """
            Args:
                input (Tensor): Logits of shape [B, C].
                target (Tensor): Class indices of shape [B], or one-hot / soft
                    targets of shape [B, C] (reduced to hard labels by argmax).
                mask (Tensor, optional): Per-sample weights of shape [B];
                    all ones when omitted.

            Returns:
                loss (Tensor): Scalar mask-weighted mean of per-sample losses.
                mask (Tensor): The input mask multiplied by a 0/1 indicator of
                    whether the label is in the top-k predictions, shape [B].
            """
            # Decide by the target's rank, not by label_smoothing: hard [B]
            # labels with label smoothing used to crash in argmax(dim=1).
            label = torch.argmax(target, dim=1) if target.dim() > 1 else target

            unmasked_loss = F.cross_entropy(input, label, reduction="none", label_smoothing=label_smoothing)
            if mask is None:
                mask = torch.ones_like(unmasked_loss, dtype=torch.float32, device=target.device)
            # The epsilon guards against an all-zero mask.
            loss = torch.sum(mask * unmasked_loss) / (torch.sum(mask) + 1e-6)

            with torch.no_grad():
                _, topk_indices = torch.topk(input, top_k, dim=-1)
                hit = torch.eq(topk_indices, label[:, None]).any(dim=-1).to(input.dtype)
                mask = mask * hit

            return loss, mask

        return func

    @staticmethod
    def gen_shared_head(self):
        """Build a head that maps layer3 features to logits via the model's layer4/avgpool/fc.

        Despite being a staticmethod, ``self`` is the ResNet model instance,
        passed explicitly by the caller.
        """

        def func(x):
            """
            Args:
                x (Tensor): Hidden-state tensor of shape [B, C, H, W].

            Returns:
                logits (Tensor): Logits tensor of shape [B, num_classes].
            """
            x = self.layer4(x)
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            return self.fc(x)

        return func

    @staticmethod
    def gen_forward(lambdas, loss_normalization=True, label_smoothing=0.0, top_k=1):
        """Build a replacement ``forward(self, x, targets=None)`` for the model.

        In training mode with ``targets`` given, the forward returns
        ``(logits, loss)``, where ``loss`` comes from ``trp_criterion`` over
        ``self.trp_blocks``. Otherwise it returns ``logits`` only; this
        includes training mode without targets, which used to crash.
        """
        # The criterion does not depend on the model; build it once.
        criterion = ResNetConfig.gen_criterion(label_smoothing=label_smoothing, top_k=top_k)

        def func(self, x: Tensor, targets=None) -> Tensor:
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.maxpool(x)

            x = self.layer1(x)
            x = self.layer2(x)
            hidden_states = self.layer3(x)

            # The main branch and the replica branches share layer4/avgpool/fc.
            shared_head = ResNetConfig.gen_shared_head(self)
            logits = shared_head(hidden_states)

            if self.training and targets is not None:
                loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, hidden_states, logits, targets, loss_normalization=loss_normalization)
                return logits, loss

            return logits

        return func
|
|
|
|
|
def apply_trp(model, depths: List[int], planes: int, lambdas: List[float], **kwargs):
    """Attach TRP replica blocks to ``model`` and patch its ``forward``.

    Args:
        model: A torchvision-style ResNet; it gains a ``trp_blocks`` attribute
            and an instance-level ``forward``.
        depths: Number of ``BasicBlock``s in each replica stage.
        planes: Channel count of the layer3 features the replicas consume.
        lambdas: Per-stage weights; must not outnumber ``depths``.
        **kwargs: Optional ``label_smoothing`` (default 0.0), ``top_k``
            (default 1) and ``loss_normalization`` (default True).

    Returns:
        The same ``model`` object, modified in place.

    Raises:
        ValueError: If ``len(lambdas) > len(depths)``. Without this check the
            mismatch would only surface as an IndexError during training.
    """
    if len(lambdas) > len(depths):
        raise ValueError(
            f"got {len(lambdas)} lambdas but only {len(depths)} TRP blocks (depths); "
            "each lambda needs a corresponding block"
        )

    print("✅ Applying TRP to ResNet for Image Classification...")

    model.trp_blocks = torch.nn.ModuleList([TPBlock(depths=d, inplanes=planes, planes=planes) for d in depths])
    forward = ResNetConfig.gen_forward(
        lambdas,
        kwargs.get("loss_normalization", True),
        label_smoothing=kwargs.get("label_smoothing", 0.0),
        top_k=kwargs.get("top_k", 1),
    )
    model.forward = types.MethodType(forward, model)

    return model
|
|