# File size: 592 Bytes
# 8c0b0b5
from ...sgm.models.diffusion import DiffusionEngine
from ...sgm.util import instantiate_from_config
import copy

class SUPIRModel(DiffusionEngine):
    """``DiffusionEngine`` extended with a control model and a duplicated VAE encoder.

    Args:
        control_stage_config: Config handed to ``instantiate_from_config``; the
            resulting object is attached via ``self.model.load_control_model``.
        ae_dtype: Accepted for signature/config compatibility; not used by
            this constructor.
        diffusion_dtype: Accepted for signature/config compatibility; not used
            by this constructor.
        p_p: Accepted for signature/config compatibility; not used by this
            constructor.
        n_p: Accepted for signature/config compatibility; not used by this
            constructor.
        *args: Forwarded unchanged to ``DiffusionEngine.__init__``.
        **kwargs: Forwarded unchanged to ``DiffusionEngine.__init__``. Must
            contain ``sampler_config``, which is also stored on the instance.

    Raises:
        KeyError: If ``sampler_config`` is not supplied as a keyword argument
            (it is looked up by key after the parent constructor has run).
    """

    def __init__(
        self,
        control_stage_config,
        ae_dtype='fp32',
        diffusion_dtype='fp32',
        p_p='',
        n_p='',
        *args,
        **kwargs,
    ):
        # The parent constructor must run first: the attributes used below
        # (self.model, self.first_stage_model) are read from the instance it
        # builds.
        super().__init__(*args, **kwargs)

        # Build the control model from its config, then hand it to self.model.
        ctrl = instantiate_from_config(control_stage_config)
        self.model.load_control_model(ctrl)

        # Attach an independent (deep) copy of the first-stage encoder under
        # the name ``denoise_encoder``; the original ``encoder`` is untouched.
        autoencoder = self.first_stage_model
        autoencoder.denoise_encoder = copy.deepcopy(autoencoder.encoder)

        # Keep the sampler config on the instance, presumably for sampling
        # code elsewhere — confirm against callers.
        sampler_cfg = kwargs['sampler_config']
        self.sampler_config = sampler_cfg