safinal commited on
Commit
bfa1c41
Β·
verified Β·
1 Parent(s): ffd2453

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +24 -0
model.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import open_clip
3
+
4
+
5
+ class Model(torch.nn.Module):
6
+ def __init__(self, model_name, pretrained) -> None:
7
+ super().__init__()
8
+ self.tokenizer = open_clip.get_tokenizer(model_name)
9
+ self.feature_extractor, _, self.processor = open_clip.create_model_and_transforms(
10
+ model_name=model_name,
11
+ pretrained=pretrained
12
+ )
13
+ self.set_param_trainable_mode(module=self.feature_extractor, status=False)
14
+
15
+
16
+ def set_param_trainable_mode(self, module, status):
17
+ for param in module.parameters():
18
+ param.requires_grad = status
19
+
20
+ def save(self, path):
21
+ torch.save(self.state_dict(), path)
22
+
23
+ def load(self, path):
24
+ self.load_state_dict(torch.load(path, weights_only=True))