import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddpm import LatentDiffusion, DDIMSampler
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import json
import os
import time

def load_model_from_config(config_path, model_name, device='cuda'):
    # Load the model architecture from the OmegaConf config file
    config = OmegaConf.load(config_path)
    # Instantiate the (still uninitialised) model
    model = instantiate_from_config(config.model)
    # Download the checkpoint from the Hugging Face Hub (HF_TOKEN is needed for gated/private repos)
    model_file = hf_hub_download(repo_id=model_name, filename="model.safetensors", token=os.getenv('HF_TOKEN'))
    print(f"Loading model from {model_name}")
    # Load the safetensors checkpoint (torch.load cannot read .safetensors files)
    state_dict = load_file(model_file, device='cpu')
    model.load_state_dict(state_dict, strict=False)
    model.to(device)
    model.eval()
    return model

def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tensor, pos_map=None):
    sampler = DDIMSampler(model)
    with torch.no_grad():
        # Unconditional conditioning for classifier-free guidance (only used by the DDIM path below)
        #u_dict = {'c_crossattn': "", 'c_concat': image_sequence}
        #uc = model.get_learned_conditioning(u_dict)
        #uc = model.enc_concat_seq(uc, u_dict, 'c_concat')

        # Conditioning: text prompt via cross-attention, image sequence via channel concatenation
        c_dict = {'c_crossattn': prompt, 'c_concat': image_sequence}
        c = model.get_learned_conditioning(c_dict)
        c = model.enc_concat_seq(c, c_dict, 'c_concat')

        # Optionally append the position map as extra conditioning channels
        if pos_map is not None:
            print(pos_map.shape, c['c_concat'].shape)
            c['c_concat'] = torch.cat([c['c_concat'], pos_map.to(c['c_concat'].device).unsqueeze(0)], dim=1)

        print('sleeping')
        #time.sleep(120)
        print('finished sleeping')

        # Full diffusion sampling loop producing a 256x256 RGB frame
        samples_ddim = model.p_sample_loop(cond=c, shape=[1, 3, 256, 256], return_intermediates=False, verbose=True)

        # Alternative: DDIM sampling with classifier-free guidance
        #samples_ddim, _ = sampler.sample(S=999,
        #                                 conditioning=c,
        #                                 batch_size=1,
        #                                 shape=[3, 64, 64],
        #                                 verbose=False,
        #                                 unconditional_guidance_scale=5.0,
        #                                 unconditional_conditioning=uc,
        #                                 eta=0)

        x_samples_ddim = samples_ddim  #model.decode_first_stage(samples_ddim)
        #x_samples_ddim = pos_map.to(c['c_concat'].device).unsqueeze(0).expand(-1, 3, -1, -1)
        #x_samples_ddim = model.decode_first_stage(x_samples_ddim)
        #x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
        x_samples_ddim = torch.clamp(x_samples_ddim, min=-1.0, max=1.0)
    return x_samples_ddim.squeeze(0).cpu().numpy()
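
# Hedged addition (not in the original script): a small helper that assumes sample_frame
# returns a CHW float array clamped to [-1, 1], as above, and converts it into a PIL image
# using the numpy and PIL imports already present in this file.
def frame_to_pil(frame: np.ndarray) -> Image.Image:
    # Map [-1, 1] -> [0, 255] and reorder channels from CHW to HWC for PIL
    frame = ((frame + 1.0) / 2.0 * 255.0).clip(0, 255).astype(np.uint8)
    return Image.fromarray(frame.transpose(1, 2, 0))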
# Global variables for model and device
#model = None
#device = None

def initialize_model(config_path, model_name):
    #global model, device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model_from_config(config_path, model_name, device)
    return model
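
# Example usage sketch (illustrative only; the config path, Hugging Face repo id, prompt,
# and conditioning tensor shape below are assumptions, not values taken from this repo):
#
#   model = initialize_model("configs/model.yaml", "some-user/some-model-repo")
#   context = torch.zeros(1, 3, 256, 256)               # placeholder conditioning frames
#   frame = sample_frame(model, "a text prompt", context)
#   frame_to_pil(frame).save("frame.png")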