import os

import torch
from omegaconf import OmegaConf
from safetensors.torch import load_file
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddpm import LatentDiffusion, DDIMSampler
from huggingface_hub import hf_hub_download

DEBUG = False

def load_model_from_config(config_path, model_name, device='cuda', load=True):
    """Instantiate the model described by the config and optionally load its weights."""
    config = OmegaConf.load(config_path)
    model = instantiate_from_config(config.model)

    if load:
        # Download the checkpoint from the Hugging Face Hub. The file is in
        # safetensors format, so read it with safetensors rather than torch.load.
        model_file = hf_hub_download(repo_id=model_name, filename="model.safetensors", token=os.getenv('HF_TOKEN'))
        print(f"Loading model from {model_name}")
        state_dict = load_file(model_file, device='cpu')
        model.load_state_dict(state_dict, strict=False)

    model.to(device)
    model.eval()
    return model

def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tensor, pos_maps=None, leftclick_maps=None):
    """Sample one frame conditioned on a sequence of context frames and optional cursor maps."""
    sampler = DDIMSampler(model)

    with torch.no_grad():
        model.eval()

        # Reorder the conditioning frames from HWC to CHW and add a batch
        # dimension, giving a (1, L * C, H, W) tensor for channel-wise
        # concatenation conditioning.
        image_sequence = torch.einsum('hwc->chw', image_sequence).unsqueeze(0)
        c = {'c_concat': image_sequence}

        if pos_maps is not None:
            # Append the cursor-position map and the stacked left-click maps
            # as extra conditioning channels.
            leftclick_map = torch.cat(leftclick_maps, dim=0)
            c['c_concat'] = torch.cat(
                [c['c_concat'],
                 pos_maps[0].to(c['c_concat'].device).unsqueeze(0),
                 leftclick_map.to(c['c_concat'].device).unsqueeze(0)],
                dim=1)

        # Toggle between full DDPM sampling (slow) and DDIM sampling (fast).
        DDPM = False

        if DEBUG:
            print('utils prompt', prompt, c['c_concat'].shape, c.keys())
            print(c['c_concat'].nonzero())

        if DDPM:
            samples_ddim = model.p_sample_loop(cond=c, shape=[1, 4, 48, 64],
                                               return_intermediates=False, verbose=True)
        else:
            # Classifier-free guidance is disabled here; DDIMSampler.sample also
            # accepts unconditional_guidance_scale / unconditional_conditioning.
            samples_ddim, _ = sampler.sample(S=16,
                                             conditioning=c,
                                             batch_size=1,
                                             shape=[4, 48, 64],
                                             verbose=False)

        # Keep the raw latents to feed back as context for the next frame, then
        # un-normalize with the dataset's latent statistics and decode to pixels.
        data_mean = -0.54
        data_std = 6.78
        x_samples_ddim_feedback = samples_ddim
        x_samples_ddim = samples_ddim * data_std + data_mean
        x_samples_ddim = model.decode_first_stage(x_samples_ddim)
        x_samples_ddim = torch.clamp(x_samples_ddim, min=-1.0, max=1.0)

        return x_samples_ddim.squeeze(0).cpu().numpy(), x_samples_ddim_feedback.squeeze(0)

def initialize_model(config_path, model_name):
    """Build the model on the GPU if one is available, otherwise on the CPU."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model_from_config(config_path, model_name, device)
    return model
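
# Example usage (a minimal sketch): drives one sampling step end to end. The
# config path, repo id, and conditioning shapes below are placeholder
# assumptions, not values taken from this project.
if __name__ == "__main__":
    model = initialize_model("configs/model_config.yaml", "some-user/some-model-repo")
    device = next(model.parameters()).device
    # The sampler above works at a 48x64 latent resolution; the channel count
    # (number of context frames * latent channels per frame) is assumed to be
    # 12 here, and the all-zeros tensor stands in for real encoded frames.
    context = torch.zeros(48, 64, 12, device=device)
    frame, latent = sample_frame(model, prompt="", image_sequence=context)
    print(frame.shape, latent.shape)  # decoded frame (numpy) and raw latent (tensor)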