import types
from typing import Callable, List

import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torchvision.models.resnet import BasicBlock


def trp_criterion(trp_blocks: nn.ModuleList, shared_head: Callable, criterion: Callable,
                  lambdas: List[float], hidden_state: Tensor, logits: Tensor, targets: Tensor,
                  loss_normalization=False):
    # Base loss on the backbone logits; `rewards` is a per-sample mask of the
    # examples whose label is still within the top-k predictions.
    losses, rewards = criterion(logits, targets)
    returns = torch.ones_like(rewards, dtype=torch.float32, device=rewards.device)
    if loss_normalization:
        coeff = torch.mean(losses).detach()

    # Each TRP block refines the previous embedding; the shared head maps the
    # refined embedding back to logits, which are scored against the same targets.
    embeds = [hidden_state]
    predictions = []
    for k, w in enumerate(lambdas):
        embeds.append(trp_blocks[k](embeds[-1]))
        predictions.append(shared_head(embeds[-1]))
        returns = returns + w * rewards
        replica_losses, rewards = criterion(predictions[-1], targets, rewards)
        losses = losses + replica_losses
    loss = torch.mean(losses * returns)

    if loss_normalization:
        # Scale by exp(base_loss - combined_loss); the coefficient itself
        # carries no gradient.
        with torch.no_grad():
            coeff = torch.exp(coeff) / torch.exp(loss.detach())
        loss = coeff * loss
    return loss


class TPBlock(nn.Module):
    def __init__(self, depths: int, inplanes: int, planes: int):
        super().__init__()
        blocks = [BasicBlock(inplanes=inplanes, planes=planes) for _ in range(depths)]
        self.blocks = nn.Sequential(*blocks)

        # Zero-initialize the residual branches so each block starts out as an
        # identity mapping and only gradually perturbs the backbone features.
        for name, param in self.blocks.named_parameters():
            if 'conv' in name:
                nn.init.zeros_(param)  # conv weights of the residual branch
            elif 'downsample' in name:
                nn.init.zeros_(param)  # downsample conv/BN parameters, if present

    def forward(self, x):
        return self.blocks(x)


class ResNetConfig:
    @staticmethod
    def gen_criterion(label_smoothing=0.0, top_k=1):
        def func(input, target, mask=None):
            """
            Args:
                input (Tensor): Logits tensor of shape [B, C].
                target (Tensor): Target labels of shape [B], or [B, C] when
                    label smoothing is enabled.
                mask (Tensor, optional): Float mask of shape [B] selecting the
                    samples that contribute to the loss.

            Returns:
                loss (Tensor): Scalar loss averaged over the masked samples.
                mask (Tensor): Updated float mask of shape [B]; a sample stays
                    active only while its label is within the top-k predictions.
            """
            label = torch.argmax(target, dim=1) if label_smoothing > 0.0 else target
            unmasked_loss = F.cross_entropy(input, label, reduction="none",
                                            label_smoothing=label_smoothing)
            if mask is None:
                mask = torch.ones_like(unmasked_loss, dtype=torch.float32, device=target.device)
            loss = torch.sum(mask * unmasked_loss) / (torch.sum(mask) + 1e-6)

            with torch.no_grad():
                topk_values, topk_indices = torch.topk(input, top_k, dim=-1)
                mask = mask * torch.eq(topk_indices, label[:, None]).any(dim=-1).to(input.dtype)
            return loss, mask
        return func

    @staticmethod
    def gen_shared_head(model):
        def func(x):
            """
            Args:
                x (Tensor): Hidden-state tensor of shape [B, C, H, W].

            Returns:
                logits (Tensor): Logits tensor of shape [B, C].
            """
            x = model.layer4(x)
            x = model.avgpool(x)
            x = torch.flatten(x, 1)
            logits = model.fc(x)
            return logits
        return func

    @staticmethod
    def gen_forward(lambdas, loss_normalization=True, label_smoothing=0.0, top_k=1):
        def func(self, x: Tensor, targets=None):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.maxpool(x)

            x = self.layer1(x)
            x = self.layer2(x)
            hidden_states = self.layer3(x)
            x = self.layer4(hidden_states)

            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            logits = self.fc(x)

            if self.training:
                shared_head = ResNetConfig.gen_shared_head(self)
                criterion = ResNetConfig.gen_criterion(label_smoothing=label_smoothing, top_k=top_k)
                loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas,
                                     hidden_states, logits, targets,
                                     loss_normalization=loss_normalization)
                return logits, loss
            return logits
        return func


def apply_trp(model, depths: List[int], planes: int, lambdas: List[float], **kwargs):
    print("✅ Applying TRP to ResNet for Image Classification...")
    model.trp_blocks = nn.ModuleList(
        [TPBlock(depths=d, inplanes=planes, planes=planes) for d in depths])
    model.forward = types.MethodType(
        ResNetConfig.gen_forward(lambdas, True, label_smoothing=kwargs["label_smoothing"], top_k=1),
        model)
    return model