|
from PIL import Image
import torch.nn as nn

from .clip import clip

# Output dimensionality of the image embedding for each supported CLIP backbone.
CHANNELS = {
    "RN50": 1024,
    "ViT-L/14": 768,
}
|
|
class CLIPModel(nn.Module):
    """A CLIP image encoder with a single linear classification head."""

    def __init__(self, name, num_classes=1):
        super().__init__()

        # Load the (locally modified) CLIP backbone together with its
        # matching image preprocessing transform.
        self.model, self.preprocess = clip.load(name, device="cpu")
        # Linear head mapping the image embedding to class logits.
        self.fc = nn.Linear(CHANNELS[name], num_classes)
|
    def forward(self, x, return_feature=False):
        # The modified encode_image in .clip returns a dict of named
        # intermediate features rather than a single tensor.
        features = self.model.encode_image(x)

        # For the ViT-Large backbone (24 transformer blocks in total),
        # the [CLS] features of blocks 24, 22, and 20 are combined by a
        # weighted average (see the standalone sketch after this class).
        if return_feature:
            return features['after_projection']

        # Otherwise, classify from the 'res_output' features.
        features = features['res_output']
        return self.fc(features)
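

# The weighted [CLS] average mentioned in forward() happens inside the
# modified .clip package and is not visible in this file. The function
# below is a minimal standalone sketch of that idea, NOT the actual
# implementation: the layer indices follow the comment above (blocks
# 24, 22, 20 -> zero-based 23, 21, 19) and the fixed weights are
# hypothetical placeholders.
def weighted_cls_average(hidden_states, layers=(23, 21, 19),
                         weights=(0.5, 0.3, 0.2)):
    """Combine the [CLS] token from selected transformer blocks.

    hidden_states: per-block outputs, each of shape (seq_len, batch,
    width) as in CLIP's ViT, with the [CLS] token at position 0.
    Returns a (batch, width) tensor.
    """
    cls_tokens = [hidden_states[i][0] for i in layers]
    return sum(w * t for w, t in zip(weights, cls_tokens))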
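

# A minimal usage sketch, not part of the model code. It assumes torch
# is installed, that clip.load can fetch the ViT-L/14 weights, and that
# this file is run as a module (python -m ...) so the relative import
# of .clip resolves.
if __name__ == "__main__":
    import torch

    model = CLIPModel("ViT-L/14")
    model.eval()
    # Dummy batch of one 224x224 RGB image, standing in for a real
    # input that would normally be produced by model.preprocess.
    x = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        logits = model(x)                      # (1, num_classes)
        feats = model(x, return_feature=True)  # projected image features
    print(logits.shape, feats.shape)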