yuntian-deng commited on
Commit
3026a03
·
1 Parent(s): a677593

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +62 -0
utils.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from omegaconf import OmegaConf
3
+ from ldm.util import instantiate_from_config
4
+ from ldm.models.diffusion.ddpm import LatentDiffusion, DDIMSampler
5
+ import numpy as np
6
+ from PIL import Image
7
+ from huggingface_hub import hf_hub_download
8
+ import json
9
+
10
+ def load_model_from_config(config_path, model_name, device='cuda'):
11
+ # Load the config file
12
+ config = OmegaConf.load(config_path)
13
+
14
+ # Instantiate the model
15
+ model = instantiate_from_config(config.model)
16
+
17
+ # Download the model file from Hugging Face
18
+ model_file = hf_hub_download(repo_id=model_name, filename="model.safetensors")
19
+
20
+ print(f"Loading model from {model_name}")
21
+ # Load the state dict
22
+ state_dict = torch.load(model_file, map_location='cpu')
23
+ model.load_state_dict(state_dict, strict=False)
24
+
25
+ model.to(device)
26
+ model.eval()
27
+ return model
28
+
29
+ def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tensor):
30
+ sampler = DDIMSampler(model)
31
+
32
+ with torch.no_grad():
33
+ u_dict = {'c_crossattn': "", 'c_concat': image_sequence}
34
+ uc = model.get_learned_conditioning(u_dict)
35
+ uc = model.enc_concat_seq(uc, u_dict, 'c_concat')
36
+
37
+ c_dict = {'c_crossattn': prompt, 'c_concat': image_sequence}
38
+ c = model.get_learned_conditioning(c_dict)
39
+ c = model.enc_concat_seq(c, c_dict, 'c_concat')
40
+
41
+ samples_ddim, _ = sampler.sample(S=200,
42
+ conditioning=c,
43
+ batch_size=1,
44
+ shape=[3, 64, 64],
45
+ verbose=False,
46
+ unconditional_guidance_scale=5.0,
47
+ unconditional_conditioning=uc,
48
+ eta=0)
49
+
50
+ x_samples_ddim = model.decode_first_stage(samples_ddim)
51
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
52
+
53
+ return x_samples_ddim.squeeze(0).cpu().numpy()
54
+
55
+ # Global variables for model and device
56
+ model = None
57
+ device = None
58
+
59
+ def initialize_model(config_path, model_name):
60
+ global model, device
61
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
62
+ model = load_model_from_config(config_path, model_name, device)