# Source: Wan2GP/preprocessing/midas/base_model.py (uploaded by zxymimi23451 in "Upload 258 files", commit 78360e7, 478 bytes)
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
class BaseModel(torch.nn.Module):
    """``torch.nn.Module`` base class that can restore its weights from a file."""

    def load(self, path, *, strict=True):
        """Load model weights from file.

        Args:
            path (str): file path of a checkpoint written with ``torch.save``.
                It may hold either a bare state dict or a training checkpoint
                that has an ``'optimizer'`` entry; in the latter case the
                weights are taken from its ``'model'`` entry.
            strict (bool): forwarded to ``load_state_dict``. When True (the
                default, matching the previous behavior) the checkpoint keys
                must match this module's keys exactly.

        Returns:
            The ``missing_keys`` / ``unexpected_keys`` named tuple returned by
            ``torch.nn.Module.load_state_dict``.
        """
        # Deserialize onto the CPU regardless of the device the checkpoint was
        # saved from; weights_only=True restricts unpickling to tensors and
        # plain containers instead of arbitrary Python objects.
        parameters = torch.load(path, map_location=torch.device('cpu'), weights_only=True)
        # Training checkpoints bundle optimizer state alongside the weights;
        # the model weights themselves live under the 'model' key.
        if 'optimizer' in parameters:
            parameters = parameters['model']
        return self.load_state_dict(parameters, strict=strict)