|
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


"""Custom DenseNet Backbone"""
|
class DenseBlock(nn.Module):
    """
    DenseNet dense block: each layer receives the concatenation of all
    preceding feature maps and contributes growth_rate new channels
    (BN -> ReLU -> 1x1 bottleneck conv -> BN -> ReLU -> 3x3 conv).
    """
    def __init__(self, in_channels, growth_rate, num_layers):
        super(DenseBlock, self).__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            self.layers.append(self._make_layer(in_channels + i * growth_rate, growth_rate))

    def _make_layer(self, in_channels, growth_rate):
        layer = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, 4 * growth_rate, kernel_size=1, bias=False),
            nn.BatchNorm2d(4 * growth_rate),
            nn.ReLU(inplace=True),
            nn.Conv2d(4 * growth_rate, growth_rate, kernel_size=3, padding=1, bias=False)
        )
        return layer

    def forward(self, x):
        features = [x]
        for layer in self.layers:
            new_feature = layer(torch.cat(features, dim=1))
            features.append(new_feature)
        return torch.cat(features, dim=1)
|
|
|
|
|
class TransitionLayer(nn.Module):
    """
    Transition layer between DenseBlocks: a 1x1 conv to reduce channels,
    followed by 2x2 average pooling to halve spatial resolution
    """
    def __init__(self, in_channels, out_channels):
        super(TransitionLayer, self).__init__()
        self.transition = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.AvgPool2d(kernel_size=2, stride=2)
        )

    def forward(self, x):
        return self.transition(x)
|
|
|
|
|
class DenseNetBackbone(nn.Module):
    """
    DenseNet backbone for CAN. The default configuration mirrors
    DenseNet-121 (blocks of 6/12/24/16 layers, growth rate 32),
    adapted to single-channel (grayscale) input.
    """
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64):
        super(DenseNetBackbone, self).__init__()

        # Stem: 7x7 stride-2 conv + 3x3 stride-2 max pool (overall /4 downsampling)
        self.features = nn.Sequential(
            nn.Conv2d(1, num_init_features, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(num_init_features),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        # Alternate dense blocks and transition layers
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = DenseBlock(num_features, growth_rate, num_layers)
            self.features.add_module(f'denseblock{i+1}', block)
            num_features = num_features + growth_rate * num_layers
            if i != len(block_config) - 1:
                trans = TransitionLayer(num_features, num_features // 2)
                self.features.add_module(f'transition{i+1}', trans)
                num_features = num_features // 2

        self.features.add_module('norm5', nn.BatchNorm2d(num_features))
        self.features.add_module('relu5', nn.ReLU(inplace=True))

        self.out_channels = num_features

    def forward(self, x):
        return self.features(x)
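

# Shape sanity check (an illustrative sketch, not part of the model): with the
# default DenseNet-121-style configuration, the backbone downsamples by 32x
# overall (stem /4, three transitions /8) and emits 1024 channels.
def _demo_densenet_backbone():
    backbone = DenseNetBackbone()
    x = torch.randn(1, 1, 64, 256)  # dummy grayscale formula image
    feats = backbone(x)
    assert feats.shape == (1, backbone.out_channels, 2, 8)
    return feats.shape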
|
|
|
|
|
"""Pretrained DenseNet""" |
|
class DenseNetFeatureExtractor(nn.Module): |
|
def __init__(self, densenet_model, out_channels=684): |
|
super().__init__() |
|
|
|
self.conv0 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False) |
|
|
|
self.conv0.weight.data = densenet_model.features.conv0.weight.data.mean(dim=1, keepdim=True) |
|
self.features = densenet_model.features |
|
self.out_channels = out_channels |
|
|
|
self.final_conv = nn.Conv2d(1024, out_channels, kernel_size=1) |
|
self.final_bn = nn.BatchNorm2d(out_channels) |
|
self.final_relu = nn.ReLU(inplace=True) |
|
|
|
def forward(self, x): |
|
x = self.conv0(x) |
|
x = self.features.norm0(x) |
|
x = self.features.relu0(x) |
|
x = self.features.pool0(x) |
|
x = self.features.denseblock1(x) |
|
x = self.features.transition1(x) |
|
x = self.features.denseblock2(x) |
|
x = self.features.transition2(x) |
|
x = self.features.denseblock3(x) |
|
x = self.features.transition3(x) |
|
x = self.features.denseblock4(x) |
|
x = self.features.norm5(x) |
|
x = self.final_conv(x) |
|
x = self.final_bn(x) |
|
x = self.final_relu(x) |
|
return x |
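

# Construction sketch (assumes torchvision >= 0.13 for the `weights` API):
# wrap a pretrained DenseNet-121 so it accepts grayscale input.
def _demo_pretrained_densenet_extractor():
    densenet = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT)
    extractor = DenseNetFeatureExtractor(densenet, out_channels=684)
    x = torch.randn(1, 1, 64, 256)
    return extractor(x).shape  # (1, 684, 2, 8)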
|
|
|
|
|
"""Custom ResNet Backbone""" |
|
class BasicBlock(nn.Module): |
|
""" |
|
Basic ResNet block |
|
""" |
|
expansion = 1 |
|
|
|
def __init__(self, in_channels, out_channels, stride=1): |
|
super(BasicBlock, self).__init__() |
|
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False) |
|
self.bn1 = nn.BatchNorm2d(out_channels) |
|
self.relu = nn.ReLU(inplace=True) |
|
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False) |
|
self.bn2 = nn.BatchNorm2d(out_channels) |
|
|
|
self.shortcut = nn.Sequential() |
|
if stride != 1 or in_channels != out_channels * self.expansion: |
|
self.shortcut = nn.Sequential( |
|
nn.Conv2d(in_channels, out_channels * self.expansion, kernel_size=1, stride=stride, bias=False), |
|
nn.BatchNorm2d(out_channels * self.expansion) |
|
) |
|
|
|
def forward(self, x): |
|
identity = x |
|
|
|
out = self.conv1(x) |
|
out = self.bn1(out) |
|
out = self.relu(out) |
|
|
|
out = self.conv2(out) |
|
out = self.bn2(out) |
|
|
|
out += self.shortcut(identity) |
|
out = self.relu(out) |
|
|
|
return out |
|
|
|
|
|
class Bottleneck(nn.Module):
    """
    Bottleneck ResNet block
    """
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
        self.relu = nn.ReLU(inplace=True)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels * self.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * self.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * self.expansion)
            )

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += self.shortcut(identity)
        out = self.relu(out)

        return out
|
|
|
|
|
class ResNetBackbone(nn.Module):
    """
    ResNet backbone for the CAN model; a final 1x1 projection matches the
    684-channel feature width used by CAN's DenseNet feature extractor
    """
    def __init__(self, block_type='bottleneck', layers=(3, 4, 6, 3), num_init_features=64):
        super(ResNetBackbone, self).__init__()

        # Stem for grayscale input
        self.conv1 = nn.Conv2d(1, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(num_init_features)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        if block_type == 'basic':
            block = BasicBlock
            expansion = 1
        elif block_type == 'bottleneck':
            block = Bottleneck
            expansion = 4
        else:
            raise ValueError(f"Unknown block type: {block_type}")

        self.layer1 = self._make_layer(block, num_init_features, 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 64 * expansion, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 128 * expansion, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 256 * expansion, 512, layers[3], stride=2)

        # 1x1 projection so the output width matches the DenseNet extractor
        self.final_conv = nn.Conv2d(512 * expansion, 684, kernel_size=1)
        self.final_bn = nn.BatchNorm2d(684)
        self.final_relu = nn.ReLU(inplace=True)

        self.out_channels = 684

        self._initialize_weights()

    def _make_layer(self, block, in_channels, out_channels, num_blocks, stride):
        layers = []
        layers.append(block(in_channels, out_channels, stride))
        for _ in range(1, num_blocks):
            layers.append(block(out_channels * block.expansion, out_channels))
        return nn.Sequential(*layers)

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.final_conv(x)
        x = self.final_bn(x)
        x = self.final_relu(x)

        return x
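

# Shape sanity check (an illustrative sketch): like the DenseNet backbone, the
# ResNet-50-style backbone downsamples by 32x and projects to 684 channels.
def _demo_resnet_backbone():
    backbone = ResNetBackbone(block_type='bottleneck', layers=(3, 4, 6, 3))
    x = torch.randn(1, 1, 64, 256)
    feats = backbone(x)
    assert feats.shape == (1, 684, 2, 8)
    return feats.shape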
|
|
|
|
|
|
|
"""Pretrained ResNet""" |
|
class ResNetFeatureExtractor(nn.Module): |
|
def __init__(self, resnet_model, out_channels=684): |
|
super().__init__() |
|
|
|
self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False) |
|
self.conv1.weight.data = resnet_model.conv1.weight.data.sum(dim=1, keepdim=True) |
|
self.bn1 = resnet_model.bn1 |
|
self.relu = resnet_model.relu |
|
self.maxpool = resnet_model.maxpool |
|
self.layer1 = resnet_model.layer1 |
|
self.layer2 = resnet_model.layer2 |
|
self.layer3 = resnet_model.layer3 |
|
self.layer4 = resnet_model.layer4 |
|
|
|
self.final_conv = nn.Conv2d(2048, out_channels, kernel_size=1) |
|
self.final_bn = nn.BatchNorm2d(out_channels) |
|
self.final_relu = nn.ReLU(inplace=True) |
|
self.out_channels = out_channels |
|
|
|
def forward(self, x): |
|
x = self.conv1(x) |
|
x = self.bn1(x) |
|
x = self.relu(x) |
|
x = self.maxpool(x) |
|
x = self.layer1(x) |
|
x = self.layer2(x) |
|
x = self.layer3(x) |
|
x = self.layer4(x) |
|
x = self.final_conv(x) |
|
x = self.final_bn(x) |
|
x = self.final_relu(x) |
|
return x |
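

# Why summing the RGB kernels works (a minimal numerical check; assumes
# torchvision >= 0.13): a grayscale image through the summed-weight conv
# matches the same image replicated to three channels through the original conv.
def _demo_grayscale_weight_sum():
    resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
    extractor = ResNetFeatureExtractor(resnet)
    x = torch.randn(1, 1, 64, 64)
    with torch.no_grad():
        gray_out = extractor.conv1(x)
        rgb_out = resnet.conv1(x.repeat(1, 3, 1, 1))
    assert torch.allclose(gray_out, rgb_out, atol=1e-4)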
|
|
|
|
|
"""Channel Attention""" |
|
class ChannelAttention(nn.Module): |
|
""" |
|
Channel-wise attention mechanism |
|
""" |
|
def __init__(self, in_channels, ratio=16): |
|
super(ChannelAttention, self).__init__() |
|
self.avg_pool = nn.AdaptiveAvgPool2d(1) |
|
self.max_pool = nn.AdaptiveMaxPool2d(1) |
|
|
|
self.fc = nn.Sequential( |
|
nn.Conv2d(in_channels, in_channels // ratio, kernel_size=1, bias=False), |
|
nn.ReLU(inplace=True), |
|
nn.Conv2d(in_channels // ratio, in_channels, kernel_size=1, bias=False) |
|
) |
|
self.sigmoid = nn.Sigmoid() |
|
|
|
def forward(self, x): |
|
avg_out = self.fc(self.avg_pool(x)) |
|
max_out = self.fc(self.max_pool(x)) |
|
out = avg_out + max_out |
|
return self.sigmoid(out) |
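

# Usage sketch: the module returns a (B, C, 1, 1) gate in [0, 1] that is
# multiplied element-wise (with broadcasting) into the input feature map.
def _demo_channel_attention():
    attn = ChannelAttention(256)
    x = torch.randn(2, 256, 4, 16)
    gate = attn(x)
    assert gate.shape == (2, 256, 1, 1)
    return (x * gate).shape  # broadcasts back to (2, 256, 4, 16)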
|
|
|
|
|
"""Multi-scale Couting Module""" |
|
class MSCM(nn.Module): |
|
""" |
|
Multi-Scale Counting Module |
|
""" |
|
def __init__(self, in_channels, num_classes): |
|
super(MSCM, self).__init__() |
|
|
|
|
|
self.branch1 = nn.Sequential( |
|
nn.Conv2d(in_channels, 256, kernel_size=3, padding=1), |
|
nn.ReLU(inplace=True), |
|
nn.Dropout2d(p=0.2) |
|
) |
|
self.attention1 = ChannelAttention(256) |
|
|
|
|
|
self.branch2 = nn.Sequential( |
|
nn.Conv2d(in_channels, 256, kernel_size=5, padding=2), |
|
nn.ReLU(inplace=True), |
|
nn.Dropout2d(p=0.2) |
|
) |
|
self.attention2 = ChannelAttention(256) |
|
|
|
|
|
self.conv_reduce = nn.Conv2d(512, num_classes, kernel_size=1) |
|
self.sigmoid = nn.Sigmoid() |
|
|
|
def forward(self, x): |
|
|
|
out1 = self.branch1(x) |
|
out1 = out1 * self.attention1(out1) |
|
|
|
|
|
out2 = self.branch2(x) |
|
out2 = out2 * self.attention2(out2) |
|
|
|
|
|
concat_features = torch.cat([out1, out2], dim=1) |
|
|
|
|
|
count_map = self.sigmoid(self.conv_reduce(concat_features)) |
|
|
|
|
|
|
|
count_vector = torch.sum(count_map, dim=(2, 3)) |
|
|
|
return count_map, count_vector |
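

# Shape sketch: for features of (B, C, H', W'), MSCM returns a density map of
# (B, num_classes, H', W') and a count vector of (B, num_classes). The
# vocabulary size 111 here is an arbitrary example value.
def _demo_mscm():
    mscm = MSCM(in_channels=684, num_classes=111)
    feats = torch.randn(2, 684, 2, 8)
    count_map, count_vector = mscm(feats)
    assert count_map.shape == (2, 111, 2, 8)
    assert count_vector.shape == (2, 111)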
|
|
|
|
|
"""Positional Encoding""" |
|
class PositionalEncoding(nn.Module): |
|
""" |
|
Positional encoding for attention decoder |
|
""" |
|
def __init__(self, d_model, max_seq_len=1024): |
|
super(PositionalEncoding, self).__init__() |
|
self.d_model = d_model |
|
|
|
|
|
pe = torch.zeros(max_seq_len, d_model) |
|
position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1) |
|
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) |
|
|
|
pe[:, 0::2] = torch.sin(position * div_term) |
|
pe[:, 1::2] = torch.cos(position * div_term) |
|
self.register_buffer('pe', pe) |
|
|
|
def forward(self, x): |
|
|
|
b, h, w, _ = x.shape |
|
|
|
|
|
if h*w > self.pe.size(0): |
|
|
|
device = self.pe.device |
|
extended_pe = torch.zeros(h*w, self.d_model, device=device) |
|
position = torch.arange(0, h*w, dtype=torch.float, device=device).unsqueeze(1) |
|
div_term = torch.exp(torch.arange(0, self.d_model, 2, device=device).float() * (-math.log(10000.0) / self.d_model)) |
|
|
|
extended_pe[:, 0::2] = torch.sin(position * div_term) |
|
extended_pe[:, 1::2] = torch.cos(position * div_term) |
|
|
|
pos_encoding = extended_pe.view(h, w, -1) |
|
else: |
|
|
|
pos_encoding = self.pe[:h*w].view(h, w, -1) |
|
|
|
pos_encoding = pos_encoding.unsqueeze(0).expand(b, -1, -1, -1) |
|
return pos_encoding |
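

# Worked example: the encoding depends only on the flattened position index,
# so every sample in the batch receives the same table.
def _demo_positional_encoding():
    pe = PositionalEncoding(d_model=512)
    x = torch.randn(2, 2, 8, 512)
    pos = pe(x)
    assert pos.shape == (2, 2, 8, 512)
    assert torch.allclose(pos[0], pos[1])  # identical across the batch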
|
|
|
|
|
"""Counting-combined Attentional Decoder""" |
|
class CCAD(nn.Module): |
|
""" |
|
Counting-Combined Attentional Decoder |
|
""" |
|
def __init__(self, input_channels, hidden_size, embedding_dim, num_classes, use_coverage=True): |
|
super(CCAD, self).__init__() |
|
|
|
self.hidden_size = hidden_size |
|
self.embedding_dim = embedding_dim |
|
self.use_coverage = use_coverage |
|
|
|
|
|
self.feature_proj = nn.Conv2d(input_channels, hidden_size * 2, kernel_size=1) |
|
|
|
|
|
self.pos_encoder = PositionalEncoding(hidden_size * 2) |
|
|
|
|
|
self.embedding = nn.Embedding(num_classes, embedding_dim) |
|
|
|
|
|
self.gru = nn.GRUCell(embedding_dim + hidden_size + num_classes, hidden_size) |
|
|
|
|
|
self.attention_w = nn.Linear(hidden_size * 2, hidden_size) |
|
self.attention_v = nn.Linear(hidden_size, 1) |
|
if use_coverage: |
|
self.coverage_proj = nn.Linear(1, hidden_size) |
|
|
|
|
|
self.out = nn.Linear(hidden_size + hidden_size + num_classes, num_classes) |
|
self.dropout = nn.Dropout(p=0.3) |
|
|
|
def forward(self, feature_map, count_vector, target=None, teacher_forcing_ratio=0.5, max_len=200): |
|
batch_size = feature_map.size(0) |
|
device = feature_map.device |
|
|
|
|
|
projected_features = self.feature_proj(feature_map) |
|
H, W = projected_features.size(2), projected_features.size(3) |
|
|
|
|
|
projected_features = projected_features.permute(0, 2, 3, 1).contiguous() |
|
|
|
|
|
pos_encoding = self.pos_encoder(projected_features) |
|
projected_features = projected_features + pos_encoding |
|
|
|
|
|
projected_features = projected_features.view(batch_size, H*W, -1) |
|
|
|
|
|
h_t = torch.zeros(batch_size, self.hidden_size, device=device) |
|
|
|
|
|
if self.use_coverage: |
|
coverage = torch.zeros(batch_size, H*W, 1, device=device) |
|
|
|
|
|
y_t_1 = torch.ones(batch_size, dtype=torch.long, device=device) |
|
|
|
|
|
if target is not None: |
|
max_len = target.size(1) |
|
|
|
|
|
outputs = torch.zeros(batch_size, max_len, self.embedding.num_embeddings, device=device) |
|
|
|
for t in range(max_len): |
|
|
|
embedded = self.embedding(y_t_1) |
|
|
|
|
|
attention_input = self.attention_w(projected_features) |
|
|
|
|
|
if self.use_coverage: |
|
coverage_input = self.coverage_proj(coverage.float()) |
|
attention_input = attention_input + coverage_input |
|
|
|
|
|
h_expanded = h_t.unsqueeze(1).expand(-1, H*W, -1) |
|
attention_input = torch.tanh(attention_input + h_expanded) |
|
|
|
|
|
e_t = self.attention_v(attention_input).squeeze(-1) |
|
alpha_t = F.softmax(e_t, dim=1) |
|
|
|
|
|
if self.use_coverage: |
|
coverage = coverage + alpha_t.unsqueeze(-1) |
|
|
|
|
|
alpha_t = alpha_t.unsqueeze(1) |
|
context = torch.bmm(alpha_t, projected_features).squeeze(1) |
|
context = context[:, :self.hidden_size] |
|
|
|
|
|
gru_input = torch.cat([embedded, context, count_vector], dim=1) |
|
|
|
|
|
h_t = self.gru(gru_input, h_t) |
|
|
|
|
|
output = self.out(torch.cat([h_t, context, count_vector], dim=1)) |
|
outputs[:, t] = output |
|
|
|
|
|
if target is not None and torch.rand(1).item() < teacher_forcing_ratio: |
|
y_t_1 = target[:, t] |
|
else: |
|
|
|
_, y_t_1 = output.max(1) |
|
|
|
return outputs |
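

# Decoder sketch: with a (B, 684, 2, 8) feature map and a matching count
# vector, the decoder returns per-step logits of shape (B, max_len, num_classes).
def _demo_ccad():
    decoder = CCAD(input_channels=684, hidden_size=256, embedding_dim=256, num_classes=111)
    feats = torch.randn(2, 684, 2, 8)
    counts = torch.rand(2, 111)
    logits = decoder(feats, counts, max_len=10)
    assert logits.shape == (2, 10, 111)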
|
|
|
|
|
"""Full model CAN (Counting-Aware Network)""" |
|
class CAN(nn.Module): |
|
""" |
|
Counting-Aware Network for handwritten mathematical expression recognition |
|
""" |
|
def __init__(self, num_classes, backbone=None, hidden_size=256, embedding_dim=256, use_coverage=True): |
|
super(CAN, self).__init__() |
|
|
|
|
|
if backbone is None: |
|
self.backbone = DenseNetBackbone() |
|
else: |
|
self.backbone = backbone |
|
backbone_channels = self.backbone.out_channels |
|
|
|
|
|
self.mscm = MSCM(backbone_channels, num_classes) |
|
|
|
|
|
self.decoder = CCAD( |
|
input_channels=backbone_channels, |
|
hidden_size=hidden_size, |
|
embedding_dim=embedding_dim, |
|
num_classes=num_classes, |
|
use_coverage=use_coverage |
|
) |
|
|
|
|
|
self.hidden_size = hidden_size |
|
self.embedding_dim = embedding_dim |
|
self.num_classes = num_classes |
|
self.use_coverage = use_coverage |
|
|
|
def init_hidden_state(self, visual_features): |
|
""" |
|
Initialize hidden state and cell state for LSTM |
|
|
|
Args: |
|
visual_features: Visual features from backbone |
|
|
|
Returns: |
|
h, c: Initial hidden and cell states |
|
""" |
|
batch_size = visual_features.size(0) |
|
device = visual_features.device |
|
|
|
|
|
h = torch.zeros(1, batch_size, self.hidden_size, device=device) |
|
c = torch.zeros(1, batch_size, self.hidden_size, device=device) |
|
|
|
return h, c |
|
|
|
def forward(self, x, target=None, teacher_forcing_ratio=0.5): |
|
|
|
features = self.backbone(x) |
|
|
|
|
|
count_map, count_vector = self.mscm(features) |
|
|
|
|
|
outputs = self.decoder(features, count_vector, target, teacher_forcing_ratio) |
|
|
|
return outputs, count_vector |
|
|
|
    def calculate_loss(self, outputs, targets, count_vectors, count_targets, lambda_count=0.01):
        """
        Compute the combined loss function for CAN

        Args:
            outputs: Predicted output sequence from the decoder
            targets: Ground-truth target sequence
            count_vectors: Predicted count vector
            count_targets: Ground-truth count vector
            lambda_count: Weight for the counting loss

        Returns:
            Total loss: L = L_cls + λ * L_counting
        """
        # Classification (cross-entropy) loss over all time steps
        L_cls = F.cross_entropy(outputs.view(-1, outputs.size(-1)), targets.view(-1))

        # Counting loss (smooth L1), normalized by the number of classes
        L_counting = F.smooth_l1_loss(count_vectors / self.num_classes, count_targets / self.num_classes)

        total_loss = L_cls + lambda_count * L_counting

        return total_loss, L_cls, L_counting
|
|
|
    def recognize(self, images, max_length=150, start_token=None, end_token=None, beam_width=5):
        """
        Recognize a handwritten expression using beam search (batch_size=1 only).

        Args:
            images: Input image tensor of shape (1, channels, height, width)
            max_length: Maximum length of the output sequence
            start_token: Start token index (required)
            end_token: End token index
            beam_width: Beam width for beam search

        Returns:
            best_sequence: List of token indices
            attention_weights: List of attention weights for visualization
        """
        if images.size(0) != 1:
            raise ValueError("Beam search is implemented only for batch_size=1")
        if start_token is None:
            raise ValueError("start_token must be provided")

        device = images.device

        visual_features = self.backbone(images)

        _, count_vector = self.mscm(visual_features)

        # Project features and add positional encoding, as in CCAD.forward
        projected_features = self.decoder.feature_proj(visual_features)
        H, W = projected_features.size(2), projected_features.size(3)
        projected_features = projected_features.permute(0, 2, 3, 1).contiguous()
        pos_encoding = self.decoder.pos_encoder(projected_features)
        projected_features = projected_features + pos_encoding
        projected_features = projected_features.view(1, H * W, -1)

        # Initialize beams. Only beam 0 starts "live"; the others get -inf so
        # the first top-k step does not select the same token beam_width times.
        beam_sequences = [torch.tensor([start_token], device=device) for _ in range(beam_width)]
        beam_scores = torch.full((beam_width,), float('-inf'), device=device)
        beam_scores[0] = 0.0
        h_t = torch.zeros(beam_width, self.hidden_size, device=device)
        if self.use_coverage:
            coverage = torch.zeros(beam_width, H * W, device=device)

        all_attention_weights = []

        for step in range(max_length):
            current_tokens = torch.tensor([seq[-1] for seq in beam_sequences], device=device)

            embedded = self.decoder.embedding(current_tokens)

            attention_input = self.decoder.attention_w(projected_features.expand(beam_width, -1, -1))
            if self.use_coverage:
                coverage_input = self.decoder.coverage_proj(coverage.unsqueeze(-1))
                attention_input = attention_input + coverage_input

            h_expanded = h_t.unsqueeze(1).expand(-1, H * W, -1)
            attention_input = torch.tanh(attention_input + h_expanded)

            e_t = self.decoder.attention_v(attention_input).squeeze(-1)
            alpha_t = F.softmax(e_t, dim=1)

            all_attention_weights.append(alpha_t.detach())

            if self.use_coverage:
                coverage = coverage + alpha_t

            context = torch.bmm(alpha_t.unsqueeze(1), projected_features.expand(beam_width, -1, -1)).squeeze(1)
            context = context[:, :self.hidden_size]

            count_vector_expanded = count_vector.expand(beam_width, -1)

            gru_input = torch.cat([embedded, context, count_vector_expanded], dim=1)

            h_t = self.decoder.gru(gru_input, h_t)

            output = self.decoder.out(torch.cat([h_t, context, count_vector_expanded], dim=1))
            scores = F.log_softmax(output, dim=1)

            # Accumulate log-probabilities and take the top beam_width candidates
            new_beam_scores = beam_scores.unsqueeze(1) + scores
            new_beam_scores_flat = new_beam_scores.view(-1)

            topk_scores, topk_indices = new_beam_scores_flat.topk(beam_width)

            # Recover which beam and which token each candidate came from
            beam_indices = torch.div(topk_indices, self.num_classes, rounding_mode='floor')
            token_indices = topk_indices % self.num_classes

            new_beam_sequences = []
            new_h_t = []
            if self.use_coverage:
                new_coverage = []
            for i in range(beam_width):
                prev_beam_idx = beam_indices[i].item()
                token = token_indices[i].item()
                new_seq = torch.cat([beam_sequences[prev_beam_idx], torch.tensor([token], device=device)])
                new_beam_sequences.append(new_seq)
                new_h_t.append(h_t[prev_beam_idx])
                if self.use_coverage:
                    new_coverage.append(coverage[prev_beam_idx])

            beam_sequences = new_beam_sequences
            beam_scores = topk_scores
            h_t = torch.stack(new_h_t)
            if self.use_coverage:
                coverage = torch.stack(new_coverage)

        # Pick the highest-scoring beam and strip start/end tokens
        best_idx = beam_scores.argmax()
        best_sequence = beam_sequences[best_idx].tolist()

        if best_sequence[0] == start_token:
            best_sequence = best_sequence[1:]
        if end_token in best_sequence:
            end_idx = best_sequence.index(end_token)
            best_sequence = best_sequence[:end_idx]

        return best_sequence, all_attention_weights
|
|
|
|
|
def create_can_model(num_classes, hidden_size=256, embedding_dim=256, use_coverage=True, pretrained_backbone=False, backbone_type='densenet'):
    """
    Create a CAN model with either a DenseNet or a ResNet backbone

    Args:
        num_classes: Number of symbol classes
        hidden_size: Decoder hidden size
        embedding_dim: Token embedding dimension
        use_coverage: Whether the decoder uses coverage attention
        pretrained_backbone: Whether to use an ImageNet-pretrained backbone
        backbone_type: Type of backbone to use ('densenet' or 'resnet')

    Returns:
        CAN model
    """
    if backbone_type == 'densenet':
        if pretrained_backbone:
            # `weights=` requires torchvision >= 0.13 (older versions used pretrained=True)
            densenet = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT)
            backbone = DenseNetFeatureExtractor(densenet, out_channels=684)
        else:
            backbone = DenseNetBackbone()
    elif backbone_type == 'resnet':
        if pretrained_backbone:
            resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
            backbone = ResNetFeatureExtractor(resnet, out_channels=684)
        else:
            backbone = ResNetBackbone(block_type='bottleneck', layers=(3, 4, 6, 3))
    else:
        raise ValueError(f"Unknown backbone type: {backbone_type}")

    model = CAN(
        num_classes=num_classes,
        backbone=backbone,
        hidden_size=hidden_size,
        embedding_dim=embedding_dim,
        use_coverage=use_coverage
    )

    return model
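

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch; 111 is an arbitrary example
    # vocabulary size): build the model, run a forward pass, and compute the
    # combined loss on dummy data.
    model = create_can_model(num_classes=111)
    images = torch.randn(2, 1, 64, 256)       # dummy grayscale batch
    targets = torch.randint(0, 111, (2, 20))  # dummy token sequences
    count_targets = torch.rand(2, 111)        # dummy symbol counts

    outputs, count_vectors = model(images, target=targets)
    print('outputs:', tuple(outputs.shape))        # (2, 20, 111)
    print('counts: ', tuple(count_vectors.shape))  # (2, 111)

    total, L_cls, L_cnt = model.calculate_loss(outputs, targets, count_vectors, count_targets)
    print('loss:', total.item())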