Spaces:
Runtime error
Runtime error
import torch | |
import open_clip | |
class Model(torch.nn.Module):
    """Frozen OpenCLIP model bundled with its tokenizer and preprocessing transform.

    Attributes:
        tokenizer: Tokenizer returned by ``open_clip.get_tokenizer(model_name)``.
        feature_extractor: The OpenCLIP model. All of its parameters are frozen
            at construction time and, while they remain frozen, the module is
            kept in eval mode (see ``train``).
        processor: The evaluation-time preprocessing transform returned by
            ``open_clip.create_model_and_transforms``; the training-time
            transform it also returns is discarded.
    """

    def __init__(self, model_name: str, pretrained) -> None:
        """Create the OpenCLIP model, tokenizer and transform, then freeze the model.

        Args:
            model_name: OpenCLIP architecture name, passed to both
                ``open_clip.get_tokenizer`` and
                ``open_clip.create_model_and_transforms``.
            pretrained: Pretrained-weights identifier, forwarded unchanged to
                ``open_clip.create_model_and_transforms``.
        """
        super().__init__()
        self.tokenizer = open_clip.get_tokenizer(model_name)
        # Returns (model, train_transform, eval_transform); only the
        # evaluation transform is kept.
        self.feature_extractor, _, self.processor = open_clip.create_model_and_transforms(
            model_name=model_name,
            pretrained=pretrained,
        )
        self.set_param_trainable_mode(module=self.feature_extractor, status=False)
        # OpenCLIP creates models in train mode. A frozen backbone must not run
        # dropout / stochastic depth or keep updating BatchNorm running stats
        # (buffers are unaffected by requires_grad=False), so switch it to eval.
        self.feature_extractor.eval()

    def train(self, mode: bool = True):
        """Set train/eval mode like ``torch.nn.Module.train``, keeping a frozen backbone in eval.

        ``torch.nn.Module.train`` propagates ``mode`` to every submodule, which
        would put the frozen ``feature_extractor`` back into training mode.
        While none of its parameters require gradients it is forced back into
        eval mode; once it has been unfrozen (e.g. with
        ``set_param_trainable_mode(..., True)``) it follows ``mode`` normally.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, matching ``torch.nn.Module.train``.
        """
        super().train(mode)
        extractor = getattr(self, "feature_extractor", None)
        if extractor is not None and not any(
            param.requires_grad for param in extractor.parameters()
        ):
            extractor.eval()
        return self

    def set_param_trainable_mode(self, module, status):
        """Set ``requires_grad`` to ``status`` on every parameter of ``module``.

        Parameters of nested submodules are included. The module's train/eval
        mode is not changed here.

        Args:
            module: Module whose parameters are updated.
            status: ``True`` to make the parameters trainable, ``False`` to
                freeze them.
        """
        for param in module.parameters():
            param.requires_grad = status

    def save(self, path) -> None:
        """Write this module's full ``state_dict`` (backbone included) to ``path`` with ``torch.save``."""
        torch.save(self.state_dict(), path)

    def load(self, path, map_location="cpu", strict: bool = True):
        """Load a state dict previously written by ``save`` into this module.

        Args:
            path: File to read, passed to ``torch.load``.
            map_location: Device mapping forwarded to ``torch.load``; defaults
                to CPU.
            strict: Forwarded to ``load_state_dict``; when ``True``, missing or
                unexpected keys raise an error.

        Returns:
            The ``(missing_keys, unexpected_keys)`` named tuple returned by
            ``torch.nn.Module.load_state_dict``.
        """
        # weights_only=True limits unpickling to tensors and basic containers
        # instead of arbitrary Python objects.
        state_dict = torch.load(path, weights_only=True, map_location=map_location)
        return self.load_state_dict(state_dict, strict=strict)