File size: 3,102 Bytes
3026a03
 
7970c81
 
3026a03
 
 
 
cfbf9d9
bde4ec6
3026a03
 
 
 
 
 
 
 
 
1729d9d
3026a03
 
 
 
 
 
 
 
 
 
d82c9de
3026a03
 
 
c520c64
 
 
3026a03
 
 
 
d82c9de
 
dbf3e56
87d19ad
bde4ec6
f1b5f5a
bde4ec6
cd40774
87d19ad
 
 
 
 
 
 
 
3026a03
ad1b655
b6e6a40
 
9202e10
 
3026a03
 
 
 
ef2044c
 
3026a03
 
ef2044c
3026a03
ef2044c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddpm import LatentDiffusion, DDIMSampler
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
import json
import os
import time

def load_model_from_config(config_path, model_name, device='cuda'):
    """Build the model described by a config file and load Hub weights into it.

    The model is instantiated from the ``model`` section of the OmegaConf
    config at ``config_path``. The file ``model.safetensors`` is then
    downloaded from the Hugging Face repo ``model_name`` (passing the
    ``HF_TOKEN`` environment variable as the token) and loaded into the model
    non-strictly.

    Args:
        config_path: Path to the config file; it must provide a ``model`` entry.
        model_name: Hugging Face repo id that hosts ``model.safetensors``.
        device: Device the model is moved to after the weights are loaded.

    Returns:
        The instantiated model, moved to ``device`` and put in eval mode.
    """
    config = OmegaConf.load(config_path)
    model = instantiate_from_config(config.model)

    model_file = hf_hub_download(
        repo_id=model_name,
        filename="model.safetensors",
        token=os.getenv('HF_TOKEN'),
    )

    print(f"Loading model from {model_name}")
    # NOTE(review): torch.load unpickles the downloaded file, so only point
    # this at trusted repos. The file is named *.safetensors but is read as a
    # torch-serialized checkpoint -- confirm the format of the uploaded
    # weights.
    state_dict = torch.load(model_file, map_location='cpu')

    # If the checkpoint wraps the weights under a "state_dict" key, unwrap it;
    # otherwise the non-strict load below would match no keys and silently
    # leave the model at its initial weights.
    if isinstance(state_dict, dict) and isinstance(state_dict.get("state_dict"), dict):
        state_dict = state_dict["state_dict"]

    # strict=False tolerates key mismatches; report them so an incompatible
    # checkpoint does not go unnoticed.
    load_result = model.load_state_dict(state_dict, strict=False)
    missing = list(getattr(load_result, "missing_keys", None) or [])
    unexpected = list(getattr(load_result, "unexpected_keys", None) or [])
    if missing:
        print(f"Warning: {len(missing)} missing keys, e.g. {missing[:5]}")
    if unexpected:
        print(f"Warning: {len(unexpected)} unexpected keys, e.g. {unexpected[:5]}")

    model.to(device)
    model.eval()
    return model

def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tensor, pos_map=None) -> np.ndarray:
    """Sample one frame from the model, conditioned on a prompt and prior frames.

    Args:
        model: Model providing ``get_learned_conditioning``,
            ``enc_concat_seq`` and ``p_sample_loop``.
        prompt: Text passed as the ``c_crossattn`` conditioning.
        image_sequence: Tensor passed as the ``c_concat`` conditioning.
        pos_map: Optional tensor. When given, it is moved to the device of
            ``c_concat``, given a leading dimension, and concatenated to
            ``c_concat`` along dim 1.

    Returns:
        The sample as a numpy array, clamped to [-1, 1] with the leading
        dimension squeezed out. Sampling is requested with shape
        [1, 3, 256, 256], so the result is presumably (3, 256, 256) --
        TODO confirm against ``p_sample_loop``.
    """
    with torch.no_grad():
        cond_inputs = {'c_crossattn': prompt, 'c_concat': image_sequence}
        cond = model.get_learned_conditioning(cond_inputs)
        cond = model.enc_concat_seq(cond, cond_inputs, 'c_concat')

        if pos_map is not None:
            concat = cond['c_concat']
            print (pos_map.shape, concat.shape)
            pos = pos_map.to(concat.device).unsqueeze(0)
            cond['c_concat'] = torch.cat([concat, pos], dim=1)

        # The output of p_sample_loop is used directly; no first-stage
        # decoding is applied to it here.
        samples = model.p_sample_loop(
            cond=cond,
            shape=[1, 3, 256, 256],
            return_intermediates=False,
            verbose=True,
        )
        samples = torch.clamp(samples, min=-1.0, max=1.0)

        return samples.squeeze(0).cpu().numpy()

# Global variables for model and device (disabled; initialize_model returns the model instead)
#model = None
#device = None

def initialize_model(config_path, model_name):
    """Load the model onto CUDA when it is available, otherwise onto the CPU.

    Args:
        config_path: Config path forwarded to ``load_model_from_config``.
        model_name: Model name forwarded to ``load_model_from_config``.

    Returns:
        The model returned by ``load_model_from_config``.
    """
    if torch.cuda.is_available():
        target = torch.device("cuda")
    else:
        target = torch.device("cpu")
    return load_model_from_config(config_path, model_name, target)