import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddpm import LatentDiffusion, DDIMSampler
from safetensors.torch import load_file as load_safetensors
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
import json
import os
import time

DEBUG = True

def load_model_from_config(config_path, model_name, device='cuda'):
    # Load the config file
    config = OmegaConf.load(config_path)
    
    # Instantiate the model
    model = instantiate_from_config(config.model)
    
    # Download the checkpoint from Hugging Face
    model_file = hf_hub_download(repo_id=model_name, filename="model.safetensors", token=os.getenv('HF_TOKEN'))
    
    print(f"Loading model from {model_name}")
    # The checkpoint is a safetensors file, so load it with safetensors rather than torch.load
    state_dict = load_safetensors(model_file, device='cpu')
    model.load_state_dict(state_dict, strict=True)
    
    model.to(device)
    model.eval()
    return model

def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tensor, pos_maps=None, leftclick_maps=None):
    """Sample a single frame from the latent diffusion model, conditioned on a text prompt,
    a sequence of previous frames, and optional cursor-position / left-click maps."""
    sampler = DDIMSampler(model)
    
    with torch.no_grad():
        #u_dict = {'c_crossattn': "", 'c_concat': image_sequence}
        #uc = model.get_learned_conditioning(u_dict)
        #uc = model.enc_concat_seq(uc, u_dict, 'c_concat')
        
        c_dict = {'c_crossattn': prompt, 'c_concat': image_sequence}
        model.eval()
        c = model.get_learned_conditioning(c_dict)
        print(c['c_crossattn'].shape)
        print(c['c_crossattn'][0])
        print(prompt)
        c = model.enc_concat_seq(c, c_dict, 'c_concat')
        # Zero out the corresponding subtensors in c_concat for padding images
        padding_mask = torch.isclose(image_sequence, torch.tensor(-1.0), rtol=1e-5, atol=1e-5).all(dim=(1, 2, 3)).unsqueeze(0)
        print(padding_mask)
        padding_mask = padding_mask.repeat(1, 4)  # Repeat mask 4 times, once per projected latent channel
        print(image_sequence.shape, padding_mask.shape, c['c_concat'].shape)
        c['c_concat'] = c['c_concat'] * (~padding_mask.unsqueeze(-1).unsqueeze(-1))  # Zero out the corresponding features
        
        if pos_maps is not None:
            pos_map = pos_maps[0]
            leftclick_map = torch.cat(leftclick_maps, dim=0)
            print(pos_map.shape, c['c_concat'].shape, leftclick_map.shape)
            if DEBUG:
                c['c_concat'] = c['c_concat'] * 0
            # Append the cursor-position and left-click maps as extra conditioning channels
            c['c_concat'] = torch.cat([c['c_concat'],
                                       pos_map.to(c['c_concat'].device).unsqueeze(0),
                                       leftclick_map.to(c['c_concat'].device).unsqueeze(0)], dim=1)

        #time.sleep(120)

        # Toggle between full DDPM sampling and DDIM sampling
        DDPM = True

        if DEBUG:
            #c['c_concat'] = c['c_concat']*0
            print('utils prompt', prompt, c['c_concat'].shape, c.keys())
            print(c['c_concat'].nonzero())
            #print(c['c_concat'][0, 0, :, :])

        if DDPM:
            samples_ddim = model.p_sample_loop(cond=c, shape=[1, 4, 48, 64], return_intermediates=False, verbose=True)
        else:
            samples_ddim, _ = sampler.sample(S=8,
                                             conditioning=c,
                                             batch_size=1,
                                             shape=[4, 48, 64],
                                             verbose=False)
        #                                 unconditional_guidance_scale=5.0,
        #                                 unconditional_conditioning=uc,
        #                                 eta=0)
        if DEBUG:
            print('samples_ddim.shape', samples_ddim.shape)
            x_samples_ddim = samples_ddim[:, :3]
            # upsample to 512 x 384
            #x_samples_ddim = torch.nn.functional.interpolate(x_samples_ddim, size=(384, 512), mode='bilinear')
            # create a 512 x 384 image and paste the samples_ddim into the center
            x_samples_ddim = torch.zeros((1, 3, 384, 512))
            x_samples_ddim[:, :, 128:128+48, 160:160+64] = samples_ddim[:, :3]
        else:
            x_samples_ddim = model.decode_first_stage(samples_ddim)
        #x_samples_ddim = pos_map.to(c['c_concat'].device).unsqueeze(0).expand(-1, 3, -1, -1)
        #x_samples_ddim = model.decode_first_stage(x_samples_ddim)
        #x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
        x_samples_ddim = torch.clamp(x_samples_ddim, min=-1.0, max=1.0)
        
        return x_samples_ddim.squeeze(0).cpu().numpy()

# Global variables for model and device
#model = None
#device = None

def initialize_model(config_path, model_name):
    #global model, device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model_from_config(config_path, model_name, device)
    return model
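
# Minimal usage sketch (not part of the original pipeline): illustrates how
# initialize_model and sample_frame are expected to be called. The config path,
# Hugging Face repo id, prompt, and frame dimensions below are placeholders /
# assumptions, not values confirmed by this file.
if __name__ == "__main__":
    config_path = "configs/model_config.yaml"   # hypothetical path
    model_name = "your-username/your-model"     # hypothetical HF repo id

    model = initialize_model(config_path, model_name)
    device = next(model.parameters()).device

    # Assumed input: a short sequence of previous frames in [-1, 1]; frames filled
    # entirely with -1 are treated as padding and masked out inside sample_frame.
    image_sequence = -torch.ones(4, 3, 384, 512, device=device)

    frame = sample_frame(model, prompt="example prompt", image_sequence=image_sequence)
    print("sampled frame shape:", frame.shape)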