import os

import torch
import numpy as np
from PIL import Image
from omegaconf import OmegaConf
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddpm import LatentDiffusion, DDIMSampler

def load_model_from_config(config_path, model_name, device='cuda'):
    # Load the model config and instantiate the (untrained) model from it.
    config = OmegaConf.load(config_path)
    model = instantiate_from_config(config.model)

    # Download the checkpoint from the Hugging Face Hub.
    model_file = hf_hub_download(repo_id=model_name, filename="model.safetensors", token=os.getenv('HF_TOKEN'))

    print(f"Loading model from {model_name}")
    # The checkpoint is in safetensors format, so it must be read with
    # safetensors.torch.load_file rather than torch.load.
    state_dict = load_file(model_file, device='cpu')
    model.load_state_dict(state_dict, strict=False)

    model.to(device)
    model.eval()
    return model

def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tensor, pos_maps=None, leftclick_maps=None):
    sampler = DDIMSampler(model)

    with torch.no_grad():
        # Unconditional conditioning for classifier-free guidance (currently
        # disabled; see the commented-out sampler arguments below):
        #u_dict = {'c_crossattn': "", 'c_concat': image_sequence}
        #uc = model.get_learned_conditioning(u_dict)
        #uc = model.enc_concat_seq(uc, u_dict, 'c_concat')

        # Build the conditioning: cross-attention on the prompt, channel-wise
        # concatenation of the encoded context frames.
        c_dict = {'c_crossattn': prompt, 'c_concat': image_sequence}
        c = model.get_learned_conditioning(c_dict)
        c = model.enc_concat_seq(c, c_dict, 'c_concat')

        # Padding frames are filled with -1; zero out their projected features
        # in c_concat so they do not contribute to the conditioning.
        padding_mask = torch.isclose(image_sequence, torch.tensor(-1.0), rtol=1e-5, atol=1e-5).all(dim=(1, 2, 3)).unsqueeze(0)
        padding_mask = padding_mask.repeat(1, 4)  # one mask entry per projected channel
        c['c_concat'] = c['c_concat'] * (~padding_mask.unsqueeze(-1).unsqueeze(-1))

        # Optionally append cursor-position and left-click maps as extra
        # conditioning channels.
        if pos_maps is not None:
            pos_map = pos_maps[0]
            leftclick_map = torch.cat(leftclick_maps, dim=0)
            c['c_concat'] = torch.cat([c['c_concat'],
                                       pos_map.to(c['c_concat'].device).unsqueeze(0),
                                       leftclick_map.to(c['c_concat'].device).unsqueeze(0)], dim=1)

        use_ddpm = False
        if use_ddpm:
            # Full DDPM sampling loop (slow; kept for debugging).
            samples_ddim = model.p_sample_loop(cond=c, shape=[1, 4, 64, 64], return_intermediates=False, verbose=True)
        else:
            # Fast DDIM sampling with 8 steps.
            samples_ddim, _ = sampler.sample(S=8,
                                             conditioning=c,
                                             batch_size=1,
                                             shape=[4, 64, 64],
                                             verbose=False)
            #                                unconditional_guidance_scale=5.0,
            #                                unconditional_conditioning=uc,
            #                                eta=0)

        # Decode the latents back to image space and clamp to the model's
        # [-1, 1] output range.
        x_samples_ddim = model.decode_first_stage(samples_ddim)
        x_samples_ddim = torch.clamp(x_samples_ddim, min=-1.0, max=1.0)

        return x_samples_ddim.squeeze(0).cpu().numpy()
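
# Hypothetical convenience helper (not part of the original script): a minimal
# sketch that converts the (C, H, W) float array in [-1, 1] returned by
# sample_frame into a PIL image, assuming a 3-channel CHW layout.
def frame_to_image(frame: np.ndarray) -> Image.Image:
    # Map [-1, 1] -> [0, 255] and reorder CHW -> HWC for PIL.
    arr = ((frame.transpose(1, 2, 0) + 1.0) / 2.0 * 255.0).clip(0, 255).astype(np.uint8)
    return Image.fromarray(arr)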


def initialize_model(config_path, model_name):
    # Pick the GPU when available, then load the model onto it.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model_from_config(config_path, model_name, device)
    return model
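
# Minimal usage sketch. The config path, Hugging Face repo id, prompt, and
# context-frame shape below are hypothetical placeholders; the real values are
# not shown in this script:
#
#   model = initialize_model("configs/model.yaml", "user/model-repo")
#   device = next(model.parameters()).device
#   context = torch.full((8, 3, 512, 512), -1.0, device=device)  # all-padding context
#   frame = sample_frame(model, "some prompt", context)
#   frame_to_image(frame).save("frame.png")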