File size: 801 Bytes
bfa1c41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38b5db9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
import open_clip


class Model(torch.nn.Module):
    """open_clip wrapper whose visual backbone is frozen at construction.

    Holds three pieces built from ``model_name``/``pretrained``:
    the text tokenizer, the CLIP backbone (``feature_extractor``), and the
    image preprocessing transform (``processor``). All backbone parameters
    have ``requires_grad=False`` after ``__init__``.
    """

    def __init__(self, model_name, pretrained) -> None:
        super().__init__()
        self.tokenizer = open_clip.get_tokenizer(model_name)
        backbone, _, preprocess = open_clip.create_model_and_transforms(
            model_name=model_name,
            pretrained=pretrained,
        )
        self.feature_extractor = backbone
        self.processor = preprocess
        # Freeze the backbone so only downstream modules receive gradients.
        self.set_param_trainable_mode(module=self.feature_extractor, status=False)

    def set_param_trainable_mode(self, module, status):
        """Set ``requires_grad`` to ``status`` on every parameter of ``module``."""
        module.requires_grad_(status)

    def save(self, path):
        """Serialize this model's ``state_dict`` to ``path``."""
        torch.save(self.state_dict(), path)

    def load(self, path):
        """Restore a ``state_dict`` written by :meth:`save`, mapped onto CPU."""
        state = torch.load(path, weights_only=True, map_location=torch.device('cpu'))
        self.load_state_dict(state)