import types
from typing import List, Callable
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torchvision.models.resnet import BasicBlock
def trp_criterion(trp_blocks: nn.ModuleList, shared_head: Callable, criterion: Callable,
                  lambdas: List[float], hidden_state: Tensor, logits: Tensor, targets: Tensor,
                  loss_normalization: bool = False):
    # Base loss and per-sample reward mask from the backbone's own logits.
    losses, rewards = criterion(logits, targets)
    returns = torch.ones_like(rewards, dtype=torch.float32, device=rewards.device)

    if loss_normalization:
        coeff = torch.mean(losses).detach()

    # Propagate the hidden state through the TRP blocks; each replica's
    # prediction is scored by the shared head, its loss is accumulated, and
    # the reward mask is tightened before the next step.
    embeds = [hidden_state]
    predictions = []
    for k, w in enumerate(lambdas):
        embeds.append(trp_blocks[k](embeds[-1]))
        predictions.append(shared_head(embeds[-1]))
        returns = returns + w * rewards
        replica_losses, rewards = criterion(predictions[-1], targets, rewards)
        losses = losses + replica_losses

    loss = torch.mean(losses * returns)

    if loss_normalization:
        # Rescale the combined loss back to the magnitude of the base loss.
        with torch.no_grad():
            coeff = torch.exp(coeff) / torch.exp(loss.detach())
        loss = coeff * loss
    return loss
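# A minimal, self-contained sanity check for trp_criterion; everything here
# (shapes, the stand-in pooling head, the lambda values) is an illustrative
# assumption rather than part of the training recipe. It only verifies that
# the pieces compose and that a scalar loss comes out.
def _demo_trp_criterion():
    torch.manual_seed(0)
    batch, planes, num_classes = 4, 8, 10
    hidden = torch.randn(batch, planes, 7, 7)

    # Stand-in shared head: global-average-pool, then project to class logits.
    fc = nn.Linear(planes, num_classes)

    def shared_head(x):
        return fc(torch.flatten(F.adaptive_avg_pool2d(x, 1), 1))

    blocks = nn.ModuleList([TPBlock(depths=1, inplanes=planes, planes=planes)
                            for _ in range(2)])
    criterion = ResNetConfig.gen_criterion(top_k=1)
    targets = torch.randint(0, num_classes, (batch,))
    logits = shared_head(hidden)
    loss = trp_criterion(blocks, shared_head, criterion, [0.5, 0.25],
                         hidden, logits, targets, loss_normalization=True)
    assert loss.dim() == 0  # a single scalar training loss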
class TPBlock(nn.Module):
    def __init__(self, depths: int, inplanes: int, planes: int):
        super().__init__()
        blocks = [BasicBlock(inplanes=inplanes, planes=planes) for _ in range(depths)]
        self.blocks = nn.Sequential(*blocks)
        # Zero-initialize every conv (and any downsample) weight so each
        # residual branch starts at zero and the block begins as a
        # near-identity mapping over the backbone's features.
        for name, param in self.blocks.named_parameters():
            if 'conv' in name:
                nn.init.zeros_(param)
            elif 'downsample' in name:
                nn.init.zeros_(param)

    def forward(self, x):
        return self.blocks(x)
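# A quick consequence of the zero-init above (stated as an assumption-checked
# demo, not part of the original file): with conv weights zeroed and BatchNorm
# at its default initialization, each BasicBlock computes relu(0 + x), so a
# TPBlock acts as an exact identity on non-negative inputs such as post-ReLU
# feature maps.
def _demo_tpblock_identity():
    block = TPBlock(depths=2, inplanes=8, planes=8).eval()
    x = torch.rand(2, 8, 7, 7)  # non-negative, like post-ReLU activations
    with torch.no_grad():
        assert torch.allclose(block(x), x)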
class ResNetConfig:
    @staticmethod
    def gen_criterion(label_smoothing=0.0, top_k=1):
        def func(input, target, mask=None):
            """
            Args:
                input (Tensor): Logits tensor of shape [B, C].
                target (Tensor): Target labels of shape [B], or soft targets of
                    shape [B, C] when label smoothing is enabled.
                mask (Tensor, optional): Float mask of shape [B] weighting each
                    sample's contribution to the loss.

            Returns:
                loss (Tensor): Scalar tensor representing the masked mean loss.
                mask (Tensor): Float mask of shape [B]; 1.0 where the sample was
                    unmasked and its label lies in the top-k predictions.
            """
            label = torch.argmax(target, dim=1) if label_smoothing > 0.0 else target
            unmasked_loss = F.cross_entropy(input, label, reduction="none", label_smoothing=label_smoothing)
            if mask is None:
                mask = torch.ones_like(unmasked_loss, dtype=torch.float32, device=target.device)
            loss = torch.sum(mask * unmasked_loss) / (torch.sum(mask) + 1e-6)

            # Tighten the mask: keep only samples whose label survives in the
            # top-k predictions, so the next replica is rewarded on them.
            with torch.no_grad():
                topk_values, topk_indices = torch.topk(input, top_k, dim=-1)
                mask = mask * torch.eq(topk_indices, label[:, None]).any(dim=-1).to(input.dtype)
            return loss, mask
        return func
    @staticmethod
    def gen_shared_head(model):
        def func(x):
            """
            Args:
                x (Tensor): Hidden-state tensor of shape [B, C, H, W].

            Returns:
                logits (Tensor): Logits tensor of shape [B, num_classes].
            """
            # Reuse the model's own tail (layer4 -> avgpool -> fc) so the TRP
            # replicas are scored by the same classification head.
            x = model.layer4(x)
            x = model.avgpool(x)
            x = torch.flatten(x, 1)
            logits = model.fc(x)
            return logits
        return func
    @staticmethod
    def gen_forward(lambdas, loss_normalization=True, label_smoothing=0.0, top_k=1):
        def func(self, x: Tensor, targets=None):
            # Standard torchvision ResNet forward pass, keeping the layer3
            # output as the hidden state consumed by the TRP blocks.
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.maxpool(x)

            x = self.layer1(x)
            x = self.layer2(x)
            hidden_states = self.layer3(x)
            x = self.layer4(hidden_states)

            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            logits = self.fc(x)

            if self.training:
                shared_head = ResNetConfig.gen_shared_head(self)
                criterion = ResNetConfig.gen_criterion(label_smoothing=label_smoothing, top_k=top_k)
                loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas,
                                     hidden_states, logits, targets,
                                     loss_normalization=loss_normalization)
                return logits, loss
            return logits
        return func
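# Illustrative use of gen_criterion (the numbers are made up): the returned
# closure yields a scalar loss plus a float mask flagging samples whose label
# falls within the top-k predictions; trp_criterion feeds that mask back in as
# the reward for the next replica.
def _demo_criterion():
    criterion = ResNetConfig.gen_criterion(label_smoothing=0.0, top_k=2)
    logits = torch.tensor([[2.0, 1.0, 0.1], [0.1, 0.2, 3.0]])
    targets = torch.tensor([1, 0])
    loss, mask = criterion(logits, targets)
    # Sample 0's label (1) has the 2nd-highest logit -> kept; sample 1's
    # label (0) falls outside the top-2 -> masked out.
    assert mask.tolist() == [1.0, 0.0]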
def apply_trp(model, depths: List[int], planes: int, lambdas: List[float], **kwargs):
    print("✅ Applying TRP to ResNet for Image Classification...")
    # Attach the TRP blocks and patch the model's forward with the TRP-aware one.
    model.trp_blocks = nn.ModuleList(
        [TPBlock(depths=d, inplanes=planes, planes=planes) for d in depths]
    )
    model.forward = types.MethodType(
        ResNetConfig.gen_forward(lambdas, loss_normalization=True,
                                 label_smoothing=kwargs["label_smoothing"], top_k=1),
        model,
    )
    return model
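# A minimal end-to-end sketch, assuming a torchvision ResNet-50 backbone. The
# depths/lambdas values and planes=1024 (ResNet-50's layer3 output width) are
# illustrative assumptions, not a published recipe.
if __name__ == "__main__":
    from torchvision.models import resnet50

    model = apply_trp(resnet50(num_classes=10), depths=[1, 1], planes=1024,
                      lambdas=[0.5, 0.25], label_smoothing=0.0)
    images = torch.randn(2, 3, 224, 224)
    targets = torch.randint(0, 10, (2,))

    model.train()
    logits, loss = model(images, targets)  # patched forward returns (logits, loss)
    loss.backward()

    model.eval()
    with torch.no_grad():
        logits = model(images)             # inference path returns logits only

    _demo_trp_criterion()
    _demo_tpblock_identity()
    _demo_criterion()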