File size: 239 Bytes
9b43cf7
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11

import torch


def save_model(model, path="model_weights.pth"):
    """Serialize ``model``'s ``state_dict`` (parameters and buffers) to disk.

    Args:
        model: The ``torch.nn.Module`` whose ``state_dict()`` is saved.
        path: Destination file, as a str or path-like object. Defaults to
            ``'model_weights.pth'`` in the current working directory, which
            matches the original hard-coded behavior.

    Returns:
        None.

    Note:
        Only the state dict is written, not the module's class or code. To
        restore it, call ``load_state_dict`` on an instance of the same
        architecture. Missing parent directories are not created here.
    """
    torch.save(model.state_dict(), path)


def load_model(model, path='./models/model_weights_27_styles.pth',
               map_location=torch.device('cpu')):
    """Load saved weights from ``path`` into ``model`` in place.

    Args:
        model: The ``torch.nn.Module`` that receives the weights. Its
            architecture must match the checkpoint.
        path: Checkpoint file produced by ``torch.save(model.state_dict(), ...)``.
            Defaults to ``'./models/model_weights_27_styles.pth'``, which
            matches the original hard-coded behavior.
        map_location: Passed to ``torch.load`` to remap where the saved
            tensors are deserialized. Defaults to CPU, so a checkpoint saved
            on a GPU can be read on a CPU-only machine.

    Returns:
        The value returned by ``model.load_state_dict``. In current PyTorch
        this is a NamedTuple with ``missing_keys`` and ``unexpected_keys``.
        Note that this is *not* the model itself; the model is modified in
        place.

    Raises:
        RuntimeError: Raised by ``load_state_dict`` (strict mode) when the
            checkpoint's keys or tensor shapes do not match ``model``.
    """
    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code. Only load trusted checkpoints; on PyTorch >= 1.13, consider
    # passing weights_only=True.
    state_dict = torch.load(path, map_location=map_location)
    return model.load_state_dict(state_dict)