import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddpm import LatentDiffusion, DDIMSampler
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import json
import os
import time

def load_model_from_config(config_path, model_name, device='cuda'):
    """Instantiate the LatentDiffusion model from a config file and load its weights from the Hub."""
    # Load the model config
    config = OmegaConf.load(config_path)
    # Instantiate the (untrained) model described by the config
    model = instantiate_from_config(config.model)
    # Download the checkpoint from Hugging Face
    model_file = hf_hub_download(repo_id=model_name, filename="model.safetensors", token=os.getenv('HF_TOKEN'))
    print(f"Loading model from {model_name}")
    # The checkpoint is a .safetensors file, so load it with safetensors rather than torch.load
    state_dict = load_file(model_file, device='cpu')
    model.load_state_dict(state_dict, strict=False)
    model.to(device)
    model.eval()
    return model

def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tensor, pos_maps=None, leftclick_maps=None):
    """Sample one frame conditioned on a text prompt and a sequence of context frames."""
    sampler = DDIMSampler(model)
    with torch.no_grad():
        # Unconditional conditioning (kept for optional classifier-free guidance):
        # u_dict = {'c_crossattn': "", 'c_concat': image_sequence}
        # uc = model.get_learned_conditioning(u_dict)
        # uc = model.enc_concat_seq(uc, u_dict, 'c_concat')

        # Conditioning: cross-attention on the prompt, concat-conditioning on the frame sequence
        c_dict = {'c_crossattn': prompt, 'c_concat': image_sequence}
        c = model.get_learned_conditioning(c_dict)
        c = model.enc_concat_seq(c, c_dict, 'c_concat')

        # Frames filled entirely with -1 are padding; zero out their subtensors in c_concat
        padding_mask = torch.isclose(image_sequence,
                                     torch.tensor(-1.0, device=image_sequence.device),
                                     rtol=1e-5, atol=1e-5).all(dim=(1, 2, 3)).unsqueeze(0)
        padding_mask = padding_mask.repeat(1, 4)  # repeat mask 4 times, once per projected channel
        c['c_concat'] = c['c_concat'] * (~padding_mask.unsqueeze(-1).unsqueeze(-1))  # zero out the corresponding features

        # Optionally append cursor-position and left-click maps as extra conditioning channels
        if pos_maps is not None:
            leftclick_map = torch.cat(leftclick_maps, dim=0)
            c['c_concat'] = torch.cat([c['c_concat'],
                                       pos_maps[0].to(c['c_concat'].device).unsqueeze(0),
                                       leftclick_map.to(c['c_concat'].device).unsqueeze(0)], dim=1)

        DDPM = False
        if DDPM:
            # Full DDPM ancestral sampling (slow)
            samples_ddim = model.p_sample_loop(cond=c, shape=[1, 4, 64, 64],
                                               return_intermediates=False, verbose=True)
        else:
            # DDIM sampling with a small number of steps
            samples_ddim, _ = sampler.sample(S=8,
                                             conditioning=c,
                                             batch_size=1,
                                             shape=[4, 64, 64],
                                             verbose=False)
                                             # unconditional_guidance_scale=5.0,
                                             # unconditional_conditioning=uc,
                                             # eta=0)

        # Decode the latent sample back to pixel space and clamp to the model's [-1, 1] range
        x_samples_ddim = model.decode_first_stage(samples_ddim)
        x_samples_ddim = torch.clamp(x_samples_ddim, min=-1.0, max=1.0)

    return x_samples_ddim.squeeze(0).cpu().numpy()

# Global variables for model and device
# model = None
# device = None
def initialize_model(config_path, model_name):
    """Create the model on the best available device."""
    # global model, device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model_from_config(config_path, model_name, device)
    return model
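

# Hedged usage sketch (illustrative only): one way these helpers might be wired
# together from an app entry point. The config path, repo id, prompt, and frame
# shape below are placeholder assumptions, not values taken from this Space.
#
#   model = initialize_model("config.yaml", "username/model-repo")  # hypothetical paths
#   device = next(model.parameters()).device
#   # Context of 4 frames in [-1, 1]; frames that are all -1 are treated as padding above.
#   frames = torch.full((4, 3, 512, 512), -1.0, device=device)      # assumed frame size
#   frame = sample_frame(model, "example prompt", frames)           # (3, H, W) numpy array in [-1, 1]
#   img = Image.fromarray(((frame.transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8))
#   img.save("sample.png")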