File size: 2,343 Bytes
b7f710c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
import torch.nn as nn

class FaceClassifier(nn.Module):
    """Face classification model: a backbone followed by a classifier head.

    The head architecture is chosen by probing ``base_model`` once with an
    all-zeros image of side ``model_configs[model_name]['resolution']``:

    * 2-D output ``(batch, features)`` -> MLP head
      (Linear/BatchNorm1d/ReLU/Dropout -> Linear/BatchNorm1d/ReLU -> Linear).
    * 4-D output ``(batch, channels, height, width)`` -> convolutional head
      (two 3x3 Conv2d/BatchNorm2d/ReLU stages, global average pooling, Linear).

    In both cases the head is stored as ``self.conv_head``; the attribute name
    is kept as-is because renaming it would change the ``state_dict`` keys.

    Args:
        base_model: Backbone module taking ``(batch, 3, res, res)`` images.
        num_classes: Number of output logits produced by ``forward``.
        model_name: Key into ``model_configs``; stored as ``self.model_name``.
        model_configs: Mapping of model name to a config dict; only the
            ``'resolution'`` entry is read here.

    Raises:
        ValueError: If the backbone output is neither 2-D nor 4-D.
    """

    def __init__(self, base_model, num_classes, model_name, model_configs):
        super().__init__()
        self.base_model = base_model
        self.model_name = model_name

        resolution = model_configs[model_name]['resolution']
        self.feature_type, self.feature_dim = self._probe_features(
            base_model, resolution, model_name
        )

        # The head is selected from the observed feature rank only. A
        # name-based override (e.g. "'vit' in model_name") would hand a Linear
        # head to backbones that emit 4-D maps, which fails in forward().
        if self.feature_type == 'flat':
            # NOTE: BatchNorm1d needs batch size > 1 in training mode.
            self.conv_head = nn.Sequential(
                nn.Linear(self.feature_dim, 512),
                nn.BatchNorm1d(512),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(512, 256),
                nn.BatchNorm1d(256),
                nn.ReLU(),
                nn.Linear(256, num_classes),
            )
        else:
            self.conv_head = nn.Sequential(
                nn.Conv2d(self.feature_dim, 512, kernel_size=3, padding=1),
                nn.BatchNorm2d(512),
                nn.ReLU(),
                nn.Dropout2d(0.5),
                nn.Conv2d(512, 256, kernel_size=3, padding=1),
                nn.BatchNorm2d(256),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(),
                nn.Linear(256, num_classes),
            )

    @staticmethod
    def _probe_features(base_model, resolution, model_name):
        """Run one dummy forward pass; return ``(feature_type, feature_dim)``.

        ``feature_type`` is ``'spatial'`` for 4-D output and ``'flat'`` for
        2-D output; ``feature_dim`` is the size of dimension 1 in both cases.
        """
        # Build the dummy on the backbone's device/dtype so GPU or
        # half-precision backbones can be probed. Parameter-free backbones fall
        # back to torch's defaults.
        param = next(base_model.parameters(), None)
        device = param.device if param is not None else None
        dtype = (
            param.dtype
            if param is not None and param.is_floating_point()
            else None
        )
        dummy_input = torch.zeros(
            1, 3, resolution, resolution, device=device, dtype=dtype
        )

        # Probe in eval mode. In training mode BatchNorm would fold this
        # all-zeros batch into its running statistics (torch.no_grad() does not
        # prevent that), and it can raise on a batch of one. Each submodule's
        # previous mode is restored exactly afterwards.
        saved_modes = [(module, module.training) for module in base_model.modules()]
        base_model.eval()
        try:
            with torch.no_grad():
                features = base_model(dummy_input)
        finally:
            for module, was_training in saved_modes:
                module.training = was_training

        if features.dim() == 4:  # (batch, channels, height, width)
            return 'spatial', int(features.shape[1])
        if features.dim() == 2:  # (batch, features)
            return 'flat', int(features.shape[1])
        raise ValueError(f"Unexpected feature shape from base model {model_name}: {features.shape}")

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for images ``x``."""
        features = self.base_model(x)
        return self.conv_head(features)