Spaces:
Running
Running
File size: 1,385 Bytes
684943d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
import os
import torch
import torch.nn as nn
def searchForMaxIteration(folder):
    """Return the largest iteration number among checkpoint entries in *folder*.

    Entries are expected to be named like ``iteration_<N>``; the text after the
    last underscore is parsed as the iteration number. Entries whose suffix is
    not an integer (stray files such as ``.DS_Store``) are ignored instead of
    aborting the search.

    Args:
        folder: directory whose entries are scanned (via ``os.listdir``).

    Returns:
        The maximum integer suffix found.

    Raises:
        FileNotFoundError: if *folder* does not exist (propagated from ``os.listdir``).
        ValueError: if *folder* contains no entry with an integer suffix.
    """
    saved_iters = []
    for fname in os.listdir(folder):
        suffix = fname.split("_")[-1]
        try:
            saved_iters.append(int(suffix))
        except ValueError:
            # Not an iteration_<N> checkpoint entry; skip it.
            continue
    if not saved_iters:
        raise ValueError(f"no iteration_<N> checkpoints found in {folder!r}")
    return max(saved_iters)
class AppModel(nn.Module):
    """Learnable per-image appearance parameters bundled with their own Adam optimizer.

    Holds a single parameter ``appear_ab`` of shape ``(num_images, 2)``,
    initialised to zeros. NOTE(review): the name suggests per-image affine
    appearance coefficients (a, b); how they are applied lives outside this
    file — confirm against the caller.

    Args:
        num_images: number of rows in ``appear_ab`` (presumably one per image).
        device: device on which ``appear_ab`` is allocated. Defaults to
            ``"cuda"`` to preserve the original behaviour; pass ``"cpu"`` to use
            the model on a machine without a GPU.
    """

    def __init__(self, num_images=1600, device="cuda"):
        super().__init__()
        # Allocate directly on the target device rather than zeros(...).cuda(),
        # which built a CPU tensor first and made CUDA mandatory.
        self.appear_ab = nn.Parameter(torch.zeros(num_images, 2, device=device))
        # The model owns its optimizer. The extra "name" key is carried in the
        # param group, presumably for lookup by training code elsewhere.
        self.optimizer = torch.optim.Adam(
            [{"params": self.appear_ab, "lr": 0.001, "name": "appear_ab"}],
            betas=(0.9, 0.99),
        )

    @staticmethod
    def _iteration_dir(model_path, iteration):
        """Return the checkpoint directory ``<model_path>/app_model/iteration_<N>``."""
        return os.path.join(model_path, "app_model", f"iteration_{iteration}")

    def save_weights(self, model_path, iteration):
        """Save this module's ``state_dict`` to ``<model_path>/app_model/iteration_<N>/app.pth``.

        Creates the directory if needed and prints the destination path.
        Only module parameters are saved; the optimizer is a plain attribute,
        not a submodule, so its state is not included.

        Args:
            model_path: root output directory.
            iteration: iteration number used to name the checkpoint directory.
        """
        out_weights_path = self._iteration_dir(model_path, iteration)
        os.makedirs(out_weights_path, exist_ok=True)
        print(f"save app model. path: {out_weights_path}")
        torch.save(self.state_dict(), os.path.join(out_weights_path, "app.pth"))

    def load_weights(self, model_path, iteration=-1):
        """Load ``app.pth`` for the given iteration into this module.

        Args:
            model_path: root directory that contains ``app_model/``.
            iteration: checkpoint iteration to load. ``-1`` selects the highest
                ``iteration_<N>`` via ``searchForMaxIteration``.
        """
        if iteration == -1:
            loaded_iter = searchForMaxIteration(os.path.join(model_path, "app_model"))
        else:
            loaded_iter = iteration
        weights_path = os.path.join(self._iteration_dir(model_path, loaded_iter), "app.pth")
        # Map tensors onto this module's device so a checkpoint saved on a GPU
        # can be loaded on CPU (or another GPU) instead of failing.
        state_dict = torch.load(weights_path, map_location=self.appear_ab.device)
        self.load_state_dict(state_dict)

    def freeze(self):
        """Disable gradient tracking on ``appear_ab`` so it is no longer trained."""
        self.appear_ab.requires_grad_(False)
|