safinal's picture
Create model.py
bfa1c41 verified
raw
history blame
767 Bytes
import torch
import open_clip
class Model(torch.nn.Module):
    """Hold a frozen open_clip model together with its tokenizer and transform.

    Attributes:
        tokenizer: Tokenizer returned by ``open_clip.get_tokenizer(model_name)``.
        feature_extractor: The open_clip model. All of its parameters are set
            to ``requires_grad=False`` at construction time.
        processor: The third value returned by
            ``open_clip.create_model_and_transforms`` (open_clip documents this
            as the validation/inference preprocessing transform).
    """

    def __init__(self, model_name, pretrained) -> None:
        """Build the open_clip model, tokenizer and transform, then freeze the model.

        Args:
            model_name: open_clip architecture name, passed to both
                ``get_tokenizer`` and ``create_model_and_transforms``.
            pretrained: Pretrained-weights tag forwarded unchanged to
                ``create_model_and_transforms``.
        """
        super().__init__()
        self.tokenizer = open_clip.get_tokenizer(model_name)
        # The middle return value (the training-time transform) is discarded.
        self.feature_extractor, _, self.processor = open_clip.create_model_and_transforms(
            model_name=model_name,
            pretrained=pretrained,
        )
        # Freeze every open_clip parameter.
        # NOTE(review): this only disables gradients; it does not call .eval(),
        # so Model.train() still switches the backbone into training mode.
        # Confirm that is intended.
        self.set_param_trainable_mode(module=self.feature_extractor, status=False)

    def set_param_trainable_mode(self, module, status):
        """Set ``requires_grad`` to ``status`` on every parameter of ``module``.

        Args:
            module: Any ``torch.nn.Module``; its parameters are visited
                recursively, including those of nested submodules.
            status: ``True`` to make the parameters trainable, ``False`` to
                freeze them.
        """
        module.requires_grad_(status)

    def save(self, path):
        """Write this module's ``state_dict`` (frozen backbone included) to ``path``."""
        torch.save(self.state_dict(), path)

    def load(self, path, map_location=None):
        """Load a ``state_dict`` previously written by :meth:`save`.

        Args:
            path: Checkpoint file to read.
            map_location: Forwarded to ``torch.load``. The default ``None``
                keeps the original behavior: tensors are restored onto the
                device they were saved from. Pass e.g. ``"cpu"`` to load a GPU
                checkpoint on a machine without that device.

        Raises:
            RuntimeError: From ``load_state_dict`` (strict mode) when the
                checkpoint's keys or shapes do not match this module.
        """
        # weights_only=True restricts unpickling to tensors, primitive types
        # and plain containers rather than arbitrary pickled objects.
        state_dict = torch.load(path, map_location=map_location, weights_only=True)
        self.load_state_dict(state_dict)