Spaces:
Running
Running
import os | |
import torch | |
import torch.nn as nn | |
def searchForMaxIteration(folder):
    """Return the largest iteration number found among the entries of *folder*.

    Entries are expected to be named ``<prefix>_<N>`` (e.g. ``iteration_7000``);
    the integer after the last underscore is taken as the iteration number.
    Entries whose suffix is not an integer (stray files such as ``.DS_Store``)
    are ignored instead of aborting the search.

    Args:
        folder: Directory to scan (not recursive).

    Returns:
        The maximum integer suffix found.

    Raises:
        FileNotFoundError: if *folder* does not exist (from ``os.listdir``).
        ValueError: if no entry in *folder* has an integer suffix.
    """
    saved_iters = []
    for fname in os.listdir(folder):
        suffix = fname.rsplit("_", 1)[-1]
        try:
            saved_iters.append(int(suffix))
        except ValueError:
            # Not a checkpoint entry; skip it rather than crash the whole lookup.
            continue
    if not saved_iters:
        raise ValueError(f"no saved iterations found in {folder!r}")
    return max(saved_iters)
class AppModel(nn.Module):
    """Learnable per-image appearance parameters with a dedicated Adam optimizer.

    Holds one parameter tensor ``appear_ab`` of shape ``(num_images, 2)``,
    initialised to zeros. NOTE(review): the two columns are presumably a
    per-image appearance pair (a, b) consumed elsewhere (e.g. by a renderer);
    confirm against the caller.

    Args:
        num_images: Number of rows (one per image) in ``appear_ab``.
        device: Device on which ``appear_ab`` is allocated. Defaults to
            ``"cuda"``, matching the previous hard-coded ``.cuda()``; pass
            ``"cpu"`` to construct the module on a machine without a GPU.
    """

    def __init__(self, num_images=1600, device="cuda"):
        super().__init__()
        self.appear_ab = nn.Parameter(torch.zeros(num_images, 2, device=device))
        # The optimizer is a plain attribute (not a Parameter/submodule), so it
        # is NOT included in state_dict() and its state is not saved by
        # save_weights().
        self.optimizer = torch.optim.Adam(
            [{"params": self.appear_ab, "lr": 0.001, "name": "appear_ab"}],
            betas=(0.9, 0.99),
        )

    def save_weights(self, model_path, iteration):
        """Save ``state_dict()`` to ``<model_path>/app_model/iteration_<iteration>/app.pth``.

        Creates the target directory if needed and prints its path.
        """
        out_weights_path = os.path.join(model_path, "app_model", f"iteration_{iteration}")
        os.makedirs(out_weights_path, exist_ok=True)
        print(f"save app model. path: {out_weights_path}")
        torch.save(self.state_dict(), os.path.join(out_weights_path, "app.pth"))

    def load_weights(self, model_path, iteration=-1):
        """Load weights previously written by ``save_weights``.

        Args:
            model_path: Root directory that was passed to ``save_weights``.
            iteration: Iteration to load; ``-1`` selects the highest iteration
                number found under ``<model_path>/app_model``.

        Raises:
            FileNotFoundError: if the checkpoint directory or file is missing.
            ValueError: if ``iteration == -1`` and no saved iteration is found.
            RuntimeError: if the stored tensor does not match ``appear_ab``
                (e.g. a different ``num_images``), from ``load_state_dict``.
        """
        if iteration == -1:
            loaded_iter = searchForMaxIteration(os.path.join(model_path, "app_model"))
        else:
            loaded_iter = iteration
        weights_path = os.path.join(
            model_path, "app_model", f"iteration_{loaded_iter}", "app.pth"
        )
        # Deserialize onto the device appear_ab lives on, so a checkpoint saved
        # from GPU can still be read on a CPU-only host.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(weights_path, map_location=self.appear_ab.device)
        self.load_state_dict(state_dict)

    def freeze(self):
        """Disable gradient tracking for ``appear_ab`` in place."""
        self.appear_ab.requires_grad_(False)