import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


"""Custom DenseNet Backbone"""


class DenseBlock(nn.Module):
    """
    Basic DenseNet block
    """
    def __init__(self, in_channels, growth_rate, num_layers):
        super(DenseBlock, self).__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            self.layers.append(self._make_layer(in_channels + i * growth_rate, growth_rate))

    def _make_layer(self, in_channels, growth_rate):
        layer = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, 4 * growth_rate, kernel_size=1, bias=False),
            nn.BatchNorm2d(4 * growth_rate),
            nn.ReLU(inplace=True),
            nn.Conv2d(4 * growth_rate, growth_rate, kernel_size=3, padding=1, bias=False)
        )
        return layer

    def forward(self, x):
        features = [x]
        for layer in self.layers:
            new_feature = layer(torch.cat(features, dim=1))
            features.append(new_feature)
        return torch.cat(features, dim=1)


class TransitionLayer(nn.Module):
    """
    Transition layer between DenseBlocks
    """
    def __init__(self, in_channels, out_channels):
        super(TransitionLayer, self).__init__()
        self.transition = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.AvgPool2d(kernel_size=2, stride=2)
        )

    def forward(self, x):
        return self.transition(x)


class DenseNetBackbone(nn.Module):
    """
    DenseNet backbone for CAN
    """
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64):
        super(DenseNetBackbone, self).__init__()

        # Initial layer
        self.features = nn.Sequential(
            nn.Conv2d(1, num_init_features, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(num_init_features),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        # DenseBlocks
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = DenseBlock(num_features, growth_rate, num_layers)
            self.features.add_module(f'denseblock{i+1}', block)
            num_features = num_features + growth_rate * num_layers
            if i != len(block_config) - 1:
                trans = TransitionLayer(num_features, num_features // 2)
                self.features.add_module(f'transition{i+1}', trans)
                num_features = num_features // 2

        # Final processing
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))
        self.features.add_module('relu5', nn.ReLU(inplace=True))

        self.out_channels = num_features  # 1024 with the default configuration

    def forward(self, x):
        return self.features(x)


"""Pretrained DenseNet"""


class DenseNetFeatureExtractor(nn.Module):
    def __init__(self, densenet_model, out_channels=684):
        super().__init__()
        # Change input conv to 1 channel
        self.conv0 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Copy pretrained weights (average over RGB channels)
        self.conv0.weight.data = densenet_model.features.conv0.weight.data.mean(dim=1, keepdim=True)

        # Keep the pretrained feature layers; forward() bypasses features.conv0
        # and uses the single-channel conv0 defined above instead
        self.features = densenet_model.features
        self.out_channels = out_channels

        # 1x1 conv to map the 1024 DenseNet-121 channels to the expected output channels
        self.final_conv = nn.Conv2d(1024, out_channels, kernel_size=1)
        self.final_bn = nn.BatchNorm2d(out_channels)
        self.final_relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv0(x)
        x = self.features.norm0(x)
        x = self.features.relu0(x)
        x = self.features.pool0(x)
        x = self.features.denseblock1(x)
        x = self.features.transition1(x)
        x = self.features.denseblock2(x)
        x = self.features.transition2(x)
        x = self.features.denseblock3(x)
        x = self.features.transition3(x)
        x = self.features.denseblock4(x)
        x = self.features.norm5(x)
        x = self.final_conv(x)
        x = self.final_bn(x)
        x = self.final_relu(x)
        return x
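# Shape sanity check for the two DenseNet front-ends (illustrative sketch, not
# part of the original module; the 128x384 grayscale input size is borrowed
# from the example at the bottom of this file). Both variants downsample the
# input by a factor of 32, so a 1 x 128 x 384 image yields a 4 x 12 feature map.
#
# custom = DenseNetBackbone()
# x = torch.randn(2, 1, 128, 384)
# print(custom(x).shape)      # torch.Size([2, 1024, 4, 12])
#
# pretrained = DenseNetFeatureExtractor(models.densenet121(pretrained=True))
# print(pretrained(x).shape)  # torch.Size([2, 684, 4, 12])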
"""Custom ResNet Backbone"""


class BasicBlock(nn.Module):
    """
    Basic ResNet block
    """
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels * self.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * self.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * self.expansion)
            )

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += self.shortcut(identity)
        out = self.relu(out)
        return out


class Bottleneck(nn.Module):
    """
    Bottleneck ResNet block
    """
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
        self.relu = nn.ReLU(inplace=True)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels * self.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * self.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * self.expansion)
            )

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        out += self.shortcut(identity)
        out = self.relu(out)
        return out


class ResNetBackbone(nn.Module):
    """
    ResNet backbone for CAN model, designed to output similar dimensions as DenseNet
    """
    def __init__(self, block_type='bottleneck', layers=[3, 4, 6, 3], num_init_features=64):
        super(ResNetBackbone, self).__init__()

        # Initial layer
        self.conv1 = nn.Conv2d(1, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(num_init_features)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Define block type
        if block_type == 'basic':
            block = BasicBlock
            expansion = 1
        elif block_type == 'bottleneck':
            block = Bottleneck
            expansion = 4
        else:
            raise ValueError(f"Unknown block type: {block_type}")

        # Create layers
        self.layer1 = self._make_layer(block, num_init_features, 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 64 * expansion, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 128 * expansion, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 256 * expansion, 512, layers[3], stride=2)

        # Final processing to match DenseNet output channels
        self.final_conv = nn.Conv2d(512 * expansion, 684, kernel_size=1)
        self.final_bn = nn.BatchNorm2d(684)
        self.final_relu = nn.ReLU(inplace=True)

        self.out_channels = 684  # Match DenseNet output channels

        # Initialize weights
        self._initialize_weights()
    def _make_layer(self, block, in_channels, out_channels, num_blocks, stride):
        layers = []
        layers.append(block(in_channels, out_channels, stride))
        for _ in range(1, num_blocks):
            layers.append(block(out_channels * block.expansion, out_channels))
        return nn.Sequential(*layers)

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.final_conv(x)
        x = self.final_bn(x)
        x = self.final_relu(x)
        return x


"""Pretrained ResNet"""


class ResNetFeatureExtractor(nn.Module):
    def __init__(self, resnet_model, out_channels=684):
        super().__init__()
        # Change input conv to 1 channel
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Collapse the pretrained RGB filters into a single channel by summing them
        self.conv1.weight.data = resnet_model.conv1.weight.data.sum(dim=1, keepdim=True)

        self.bn1 = resnet_model.bn1
        self.relu = resnet_model.relu
        self.maxpool = resnet_model.maxpool
        self.layer1 = resnet_model.layer1
        self.layer2 = resnet_model.layer2
        self.layer3 = resnet_model.layer3
        self.layer4 = resnet_model.layer4

        # Add a 1x1 conv to match DenseNet output channels if needed
        self.final_conv = nn.Conv2d(2048, out_channels, kernel_size=1)
        self.final_bn = nn.BatchNorm2d(out_channels)
        self.final_relu = nn.ReLU(inplace=True)

        self.out_channels = out_channels

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.final_conv(x)
        x = self.final_bn(x)
        x = self.final_relu(x)
        return x


"""Channel Attention"""


class ChannelAttention(nn.Module):
    """
    Channel-wise attention mechanism
    """
    def __init__(self, in_channels, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // ratio, kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // ratio, in_channels, kernel_size=1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)


"""Multi-scale Counting Module"""


class MSCM(nn.Module):
    """
    Multi-Scale Counting Module
    """
    def __init__(self, in_channels, num_classes):
        super(MSCM, self).__init__()

        # Branch 1: 3x3 kernel
        self.branch1 = nn.Sequential(
            nn.Conv2d(in_channels, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=0.2)
        )
        self.attention1 = ChannelAttention(256)

        # Branch 2: 5x5 kernel
        self.branch2 = nn.Sequential(
            nn.Conv2d(in_channels, 256, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=0.2)
        )
        self.attention2 = ChannelAttention(256)

        # 1x1 Conv layer to reduce channels and create counting map
        self.conv_reduce = nn.Conv2d(512, num_classes, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Process branch 1
        out1 = self.branch1(x)
        out1 = out1 * self.attention1(out1)

        # Process branch 2
        out2 = self.branch2(x)
        out2 = out2 * self.attention2(out2)

        # Concatenate features from both branches
        concat_features = torch.cat([out1, out2], dim=1)  # Shape: B x 512 x H x W

        # Create counting map
        count_map = self.sigmoid(self.conv_reduce(concat_features))  # Shape: B x C x H x W

        # Apply sum-pooling to create 1D counting vector
        # Sum over the entire feature map along height and width
        count_vector = torch.sum(count_map, dim=(2, 3))  # Shape: B x C

        return count_map, count_vector
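# Illustrative sketch (not part of the original module): the counting branch is
# supervised against per-class symbol counts. One simple way to build such a
# target from a padded label batch is a per-row bincount; the padding index
# (assumed to be 0 here) is zeroed out so padding is not counted. How special
# tokens are handled depends on the vocabulary actually used.
#
# def build_count_targets(targets, num_classes, pad_idx=0):
#     # targets: B x max_len tensor of token indices
#     count_targets = torch.zeros(targets.size(0), num_classes, device=targets.device)
#     for b in range(targets.size(0)):
#         counts = torch.bincount(targets[b], minlength=num_classes).float()
#         counts[pad_idx] = 0.0
#         count_targets[b] = counts
#     return count_targets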
"""Positional Encoding"""


class PositionalEncoding(nn.Module):
    """
    Positional encoding for attention decoder
    """
    def __init__(self, d_model, max_seq_len=1024):
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model

        # Create positional encoding matrix
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x shape: B x H x W x d_model
        b, h, w, _ = x.shape

        # Ensure we have enough positional encodings for the feature map size
        if h * w > self.pe.size(0):  # type: ignore
            # Dynamically extend positional encodings if needed
            device = self.pe.device  # type: ignore
            extended_pe = torch.zeros(h * w, self.d_model, device=device)
            position = torch.arange(0, h * w, dtype=torch.float, device=device).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, self.d_model, 2, device=device).float() * (-math.log(10000.0) / self.d_model))
            extended_pe[:, 0::2] = torch.sin(position * div_term)
            extended_pe[:, 1::2] = torch.cos(position * div_term)
            pos_encoding = extended_pe.view(h, w, -1)
        else:
            # Use pre-computed positional encodings
            pos_encoding = self.pe[:h * w].view(h, w, -1)  # type: ignore

        pos_encoding = pos_encoding.unsqueeze(0).expand(b, -1, -1, -1)  # B x H x W x d_model
        return pos_encoding


"""Counting-combined Attentional Decoder"""


class CCAD(nn.Module):
    """
    Counting-Combined Attentional Decoder
    """
    def __init__(self, input_channels, hidden_size, embedding_dim, num_classes, use_coverage=True):
        super(CCAD, self).__init__()
        self.hidden_size = hidden_size
        self.embedding_dim = embedding_dim
        self.use_coverage = use_coverage

        # Input layer to reduce feature map
        self.feature_proj = nn.Conv2d(input_channels, hidden_size * 2, kernel_size=1)

        # Positional encoding
        self.pos_encoder = PositionalEncoding(hidden_size * 2)

        # Embedding layer for output symbols
        self.embedding = nn.Embedding(num_classes, embedding_dim)

        # GRU cell
        self.gru = nn.GRUCell(embedding_dim + hidden_size + num_classes, hidden_size)

        # Attention
        self.attention_w = nn.Linear(hidden_size * 2, hidden_size)
        self.attention_v = nn.Linear(hidden_size, 1)

        if use_coverage:
            self.coverage_proj = nn.Linear(1, hidden_size)

        # Output layer
        self.out = nn.Linear(hidden_size + hidden_size + num_classes, num_classes)
        self.dropout = nn.Dropout(p=0.3)

    def forward(self, feature_map, count_vector, target=None, teacher_forcing_ratio=0.5, max_len=200):
        batch_size = feature_map.size(0)
        device = feature_map.device

        # Transform feature map
        projected_features = self.feature_proj(feature_map)  # B x 2*hidden_size x H x W
        H, W = projected_features.size(2), projected_features.size(3)

        # Reshape feature map to B x H*W x 2*hidden_size
        projected_features = projected_features.permute(0, 2, 3, 1).contiguous()  # B x H x W x 2*hidden_size

        # Add positional encoding
        pos_encoding = self.pos_encoder(projected_features)  # B x H x W x 2*hidden_size
        projected_features = projected_features + pos_encoding

        # Reshape for attention processing
        projected_features = projected_features.view(batch_size, H * W, -1)  # B x H*W x 2*hidden_size

        # Initialize hidden state
        h_t = torch.zeros(batch_size, self.hidden_size, device=device)
        # Initialize coverage attention if used
        if self.use_coverage:
            coverage = torch.zeros(batch_size, H * W, 1, device=device)

        # First token (assumes the start-of-sequence symbol has index 1)
        y_t_1 = torch.ones(batch_size, dtype=torch.long, device=device)

        # Prepare target sequence if provided
        if target is not None:
            max_len = target.size(1)

        # Array to store predictions
        outputs = torch.zeros(batch_size, max_len, self.embedding.num_embeddings, device=device)

        for t in range(max_len):
            # Apply embedding to the previous symbol
            embedded = self.embedding(y_t_1)  # B x embedding_dim

            # Compute attention
            attention_input = self.attention_w(projected_features)  # B x H*W x hidden_size

            # Add coverage attention if used
            if self.use_coverage:
                coverage_input = self.coverage_proj(coverage.float())  # type: ignore
                attention_input = attention_input + coverage_input

            # Add hidden state to attention
            h_expanded = h_t.unsqueeze(1).expand(-1, H * W, -1)  # B x H*W x hidden_size
            attention_input = torch.tanh(attention_input + h_expanded)

            # Compute attention weights
            e_t = self.attention_v(attention_input).squeeze(-1)  # B x H*W
            alpha_t = F.softmax(e_t, dim=1)  # B x H*W

            # Update coverage if used
            if self.use_coverage:
                coverage = coverage + alpha_t.unsqueeze(-1)  # type: ignore

            # Compute context vector
            alpha_t = alpha_t.unsqueeze(1)  # B x 1 x H*W
            context = torch.bmm(alpha_t, projected_features).squeeze(1)  # B x 2*hidden_size
            context = context[:, :self.hidden_size]  # Take the first half as context vector

            # Combine embedding, context vector, and count vector
            gru_input = torch.cat([embedded, context, count_vector], dim=1)

            # Update hidden state
            h_t = self.gru(gru_input, h_t)

            # Predict output symbol
            output = self.out(torch.cat([h_t, context, count_vector], dim=1))
            outputs[:, t] = output

            # Decide the next input symbol
            if target is not None and torch.rand(1).item() < teacher_forcing_ratio:
                y_t_1 = target[:, t]
            else:
                # Greedy decoding
                _, y_t_1 = output.max(1)

        return outputs


"""Full model CAN (Counting-Aware Network)"""


class CAN(nn.Module):
    """
    Counting-Aware Network for handwritten mathematical expression recognition
    """
    def __init__(self, num_classes, backbone=None, hidden_size=256, embedding_dim=256, use_coverage=True):
        super(CAN, self).__init__()

        # Backbone
        if backbone is None:
            self.backbone = DenseNetBackbone()
        else:
            self.backbone = backbone

        backbone_channels = self.backbone.out_channels

        # Multi-Scale Counting Module
        self.mscm = MSCM(backbone_channels, num_classes)

        # Counting-Combined Attentional Decoder
        self.decoder = CCAD(
            input_channels=backbone_channels,
            hidden_size=hidden_size,
            embedding_dim=embedding_dim,
            num_classes=num_classes,
            use_coverage=use_coverage
        )

        # Save parameters for later use
        self.hidden_size = hidden_size
        self.embedding_dim = embedding_dim
        self.num_classes = num_classes
        self.use_coverage = use_coverage

    def init_hidden_state(self, visual_features):
        """
        Initialize hidden and cell states (LSTM-style helper; the GRU-based
        CCAD decoder in this file initializes its own hidden state and does
        not call this method)

        Args:
            visual_features: Visual features from backbone

        Returns:
            h, c: Initial hidden and cell states
        """
        batch_size = visual_features.size(0)
        device = visual_features.device

        # Initialize hidden state with zeros
        h = torch.zeros(1, batch_size, self.hidden_size, device=device)
        c = torch.zeros(1, batch_size, self.hidden_size, device=device)
        return h, c

    def forward(self, x, target=None, teacher_forcing_ratio=0.5):
        # Extract features from backbone
        features = self.backbone(x)

        # Compute count map and count vector from MSCM
        count_map, count_vector = self.mscm(features)

        # Decode with CCAD
        outputs = self.decoder(features, count_vector, target, teacher_forcing_ratio)

        return outputs, count_vector
    def calculate_loss(self, outputs, targets, count_vectors, count_targets, lambda_count=0.01):
        """
        Compute the combined loss function for CAN

        Args:
            outputs: Predicted output sequence from decoder
            targets: Actual target sequence
            count_vectors: Predicted count vector
            count_targets: Actual target count vector
            lambda_count: Weight for counting loss

        Returns:
            Total loss: L = L_cls + λ * L_counting
        """
        # Loss for decoder (cross entropy)
        L_cls = F.cross_entropy(outputs.view(-1, outputs.size(-1)), targets.view(-1))

        # Loss for counting (smooth L1 on count vectors, normalized by the number of classes)
        L_counting = F.smooth_l1_loss(count_vectors / self.num_classes, count_targets / self.num_classes)

        # Total loss
        total_loss = L_cls + lambda_count * L_counting
        return total_loss, L_cls, L_counting

    def recognize(self, images, max_length=150, start_token=None, end_token=None, beam_width=5):
        """
        Recognize the handwritten expression using beam search (batch_size=1 only).

        Args:
            images: Input image tensor, shape (1, channels, height, width)
            max_length: Maximum length of the output sequence
            start_token: Start token index
            end_token: End token index
            beam_width: Beam width for beam search

        Returns:
            best_sequence: List of token indices
            attention_weights: List of attention weights for visualization
        """
        if images.size(0) != 1:
            raise ValueError("Beam search is implemented only for batch_size=1")

        device = images.device

        # Encode the image
        visual_features = self.backbone(images)

        # Get count vector
        _, count_vector = self.mscm(visual_features)

        # Prepare feature map for decoder
        projected_features = self.decoder.feature_proj(visual_features)  # (1, 2*hidden_size, H, W)
        H, W = projected_features.size(2), projected_features.size(3)
        projected_features = projected_features.permute(0, 2, 3, 1).contiguous()  # (1, H, W, 2*hidden_size)
        pos_encoding = self.decoder.pos_encoder(projected_features)  # (1, H, W, 2*hidden_size)
        projected_features = projected_features + pos_encoding  # (1, H, W, 2*hidden_size)
        projected_features = projected_features.view(1, H * W, -1)  # (1, H*W, 2*hidden_size)

        # Initialize beams
        beam_sequences = [torch.tensor([start_token], device=device)] * beam_width  # List of (seq_len) tensors
        beam_scores = torch.zeros(beam_width, device=device)  # (beam_width)
        h_t = torch.zeros(beam_width, self.hidden_size, device=device)  # (beam_width, hidden_size)

        if self.use_coverage:
            coverage = torch.zeros(beam_width, H * W, device=device)  # (beam_width, H*W)

        all_attention_weights = []

        for step in range(max_length):
            # Get current tokens for all beams
            current_tokens = torch.tensor([seq[-1] for seq in beam_sequences], device=device)  # (beam_width)

            # Apply embedding
            embedded = self.decoder.embedding(current_tokens)  # (beam_width, embedding_dim)

            # Compute attention for each beam
            attention_input = self.decoder.attention_w(projected_features.expand(beam_width, -1, -1))  # (beam_width, H*W, hidden_size)

            if self.use_coverage:
                coverage_input = self.decoder.coverage_proj(coverage.unsqueeze(-1))  # (beam_width, H*W, hidden_size)  # type: ignore
                attention_input = attention_input + coverage_input

            h_expanded = h_t.unsqueeze(1).expand(-1, H * W, -1)  # (beam_width, H*W, hidden_size)
            attention_input = torch.tanh(attention_input + h_expanded)

            e_t = self.decoder.attention_v(attention_input).squeeze(-1)  # (beam_width, H*W)
            alpha_t = F.softmax(e_t, dim=1)  # (beam_width, H*W)
            all_attention_weights.append(alpha_t.detach())

            if self.use_coverage:
                coverage = coverage + alpha_t  # type: ignore

            context = torch.bmm(alpha_t.unsqueeze(1), projected_features.expand(beam_width, -1, -1)).squeeze(1)  # (beam_width, 2*hidden_size)
            context = context[:, :self.hidden_size]  # (beam_width, hidden_size)

            # Expand count_vector to (beam_width, num_classes)
            count_vector_expanded = count_vector.expand(beam_width, -1)  # (beam_width, num_classes)

            gru_input = torch.cat([embedded, context, count_vector_expanded], dim=1)  # (beam_width, embedding_dim + hidden_size + num_classes)
            h_t = self.decoder.gru(gru_input, h_t)  # (beam_width, hidden_size)

            output = self.decoder.out(torch.cat([h_t, context, count_vector_expanded], dim=1))  # (beam_width, num_classes)
            scores = F.log_softmax(output, dim=1)  # (beam_width, num_classes)

            # Compute new scores for all beam-token combinations
            new_beam_scores = beam_scores.unsqueeze(1) + scores  # (beam_width, num_classes)
            new_beam_scores_flat = new_beam_scores.view(-1)  # (beam_width * num_classes)

            # Select top beam_width scores and indices
            topk_scores, topk_indices = new_beam_scores_flat.topk(beam_width)

            # Determine which beam and token each top score corresponds to
            beam_indices = topk_indices // self.num_classes  # (beam_width)
            token_indices = topk_indices % self.num_classes  # (beam_width)

            # Create new beam sequences and states
            new_beam_sequences = []
            new_h_t = []
            if self.use_coverage:
                new_coverage = []

            for i in range(beam_width):
                prev_beam_idx = beam_indices[i].item()
                token = token_indices[i].item()

                new_seq = torch.cat([beam_sequences[prev_beam_idx], torch.tensor([token], device=device)])  # type: ignore
                new_beam_sequences.append(new_seq)
                new_h_t.append(h_t[prev_beam_idx])
                if self.use_coverage:
                    new_coverage.append(coverage[prev_beam_idx])  # type: ignore

            # Update beams
            beam_sequences = new_beam_sequences
            beam_scores = topk_scores
            h_t = torch.stack(new_h_t)
            if self.use_coverage:
                coverage = torch.stack(new_coverage)  # type: ignore

        # Select the sequence with the highest score
        best_idx = beam_scores.argmax()
        best_sequence = beam_sequences[best_idx].tolist()

        # Remove the leading start token and truncate at the end token
        if best_sequence[0] == start_token:
            best_sequence = best_sequence[1:]
        if end_token in best_sequence:
            end_idx = best_sequence.index(end_token)
            best_sequence = best_sequence[:end_idx]

        return best_sequence, all_attention_weights


def create_can_model(num_classes, hidden_size=256, embedding_dim=256, use_coverage=True,
                     pretrained_backbone=False, backbone_type='densenet'):
    """
    Create CAN model with either DenseNet or ResNet backbone

    Args:
        num_classes: Number of symbol classes
        hidden_size: Hidden size of the decoder GRU
        embedding_dim: Dimension of the symbol embeddings
        use_coverage: Whether to use coverage attention in the decoder
        pretrained_backbone: Whether to use a pretrained backbone
        backbone_type: Type of backbone to use ('densenet' or 'resnet')

    Returns:
        CAN model
    """
    # Create backbone
    if backbone_type == 'densenet':
        if pretrained_backbone:
            densenet = models.densenet121(pretrained=True)
            backbone = DenseNetFeatureExtractor(densenet, out_channels=684)
        else:
            backbone = DenseNetBackbone()
    elif backbone_type == 'resnet':
        if pretrained_backbone:
            resnet = models.resnet50(pretrained=True)
            backbone = ResNetFeatureExtractor(resnet, out_channels=684)
        else:
            backbone = ResNetBackbone(block_type='bottleneck', layers=[3, 4, 6, 3])
    else:
        raise ValueError(f"Unknown backbone type: {backbone_type}")

    # Create model
    model = CAN(
        num_classes=num_classes,
        backbone=backbone,
        hidden_size=hidden_size,
        embedding_dim=embedding_dim,
        use_coverage=use_coverage
    )
    return model
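# Note (sketch, not part of the original module): recent torchvision releases
# (>= 0.13) deprecate the `pretrained=True` flag in favour of the `weights`
# argument. If the deprecation warning matters for your setup, the two calls
# above can be replaced along these lines:
#
# densenet = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT)
# resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)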
# Example usage
# if __name__ == "__main__":
#     # Create CAN model with 101 symbol classes (example)
#     num_classes = 101  # Number of symbol classes + special tokens such as the start/end symbols
#     model = create_can_model(num_classes)
#
#     # Create dummy input data
#     batch_size = 4
#     input_image = torch.randn(batch_size, 1, 128, 384)  # B x C x H x W
#     target = torch.randint(0, num_classes, (batch_size, 50))  # B x max_len
#
#     # Forward pass
#     outputs, count_vectors = model(input_image, target)
#
#     # Print output shapes
#     print(f"Outputs shape: {outputs.shape}")  # B x max_len x num_classes
#     print(f"Count vectors shape: {count_vectors.shape}")  # B x num_classes
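# Inference sketch (not part of the original module): `recognize` runs beam
# search on a single image. The start/end token indices below (1 and 2) are
# placeholders and must match the vocabulary the model was actually trained on.
#
# model.eval()
# with torch.no_grad():
#     image = torch.randn(1, 1, 128, 384)  # single grayscale image
#     sequence, attention_weights = model.recognize(
#         image, max_length=150, start_token=1, end_token=2, beam_width=5
#     )
# print(sequence)  # list of predicted token indices (start/end tokens removed)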