import os
from typing import List

import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from PIL import Image
from torchvision import transforms

from lbm.models.embedders import (
    ConditionerWrapper,
    LatentsConcatEmbedder,
    LatentsConcatEmbedderConfig,
)
from lbm.models.lbm import LBMConfig, LBMModel
from lbm.models.unets import DiffusersUNet2DCondWrapper
from lbm.models.vae import AutoencoderKLDiffusers, AutoencoderKLDiffusersConfig

def get_model_from_config(
    backbone_signature: str = "stabilityai/stable-diffusion-xl-base-1.0",
    vae_num_channels: int = 4,
    unet_input_channels: int = 4,
    timestep_sampling: str = "log_normal",
    selected_timesteps: List[float] = None,
    prob: List[float] = None,
    conditioning_images_keys: List[str] = [],
    conditioning_masks_keys: List[str] = ["mask"],
    source_key: str = "source_image",
    target_key: str = "source_image_paste",
    bridge_noise_sigma: float = 0.0,
):
    """Build the LBM model (UNet denoiser, conditioners, VAE and sampling
    scheduler) for the given configuration and cast it to bfloat16."""

    conditioners = []

    denoiser = DiffusersUNet2DCondWrapper(
        in_channels=unet_input_channels,  # Add downsampled_image
        out_channels=vae_num_channels,
        center_input_sample=False,
        flip_sin_to_cos=True,
        freq_shift=0,
        down_block_types=[
            "DownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
        ],
        mid_block_type="UNetMidBlock2DCrossAttn",
        up_block_types=["CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"],
        only_cross_attention=False,
        block_out_channels=[320, 640, 1280],
        layers_per_block=2,
        downsample_padding=1,
        mid_block_scale_factor=1,
        dropout=0.0,
        act_fn="silu",
        norm_num_groups=32,
        norm_eps=1e-05,
        cross_attention_dim=[320, 640, 1280],
        transformer_layers_per_block=[1, 2, 10],
        reverse_transformer_layers_per_block=None,
        encoder_hid_dim=None,
        encoder_hid_dim_type=None,
        attention_head_dim=[5, 10, 20],
        num_attention_heads=None,
        dual_cross_attention=False,
        use_linear_projection=True,
        class_embed_type=None,
        addition_embed_type=None,
        addition_time_embed_dim=None,
        num_class_embeds=None,
        upcast_attention=None,
        resnet_time_scale_shift="default",
        resnet_skip_time_act=False,
        resnet_out_scale_factor=1.0,
        time_embedding_type="positional",
        time_embedding_dim=None,
        time_embedding_act_fn=None,
        timestep_post_act=None,
        time_cond_proj_dim=None,
        conv_in_kernel=3,
        conv_out_kernel=3,
        projection_class_embeddings_input_dim=None,
        attention_type="default",
        class_embeddings_concat=False,
        mid_block_only_cross_attention=None,
        cross_attention_norm=None,
        addition_embed_type_num_heads=64,
    ).to(torch.bfloat16)

    if conditioning_images_keys != [] or conditioning_masks_keys != []:
        latents_concat_embedder_config = LatentsConcatEmbedderConfig(
            image_keys=conditioning_images_keys,
            mask_keys=conditioning_masks_keys,
        )
        latent_concat_embedder = LatentsConcatEmbedder(latents_concat_embedder_config)
        latent_concat_embedder.freeze()
        conditioners.append(latent_concat_embedder)

    # Wrap conditioners and set to device
    conditioner = ConditionerWrapper(
        conditioners=conditioners,
    )

    ## VAE ##
    # Get VAE model
    vae_config = AutoencoderKLDiffusersConfig(
        version=backbone_signature,
        subfolder="vae",
        tiling_size=(128, 128),
    )
    vae = AutoencoderKLDiffusers(vae_config).to(torch.bfloat16)
    vae.freeze()
    vae.to(torch.bfloat16)

    ## Diffusion Model ##
    # Get diffusion model
    config = LBMConfig(
        source_key=source_key,
        target_key=target_key,
        timestep_sampling=timestep_sampling,
        selected_timesteps=selected_timesteps,
        prob=prob,
        bridge_noise_sigma=bridge_noise_sigma,
    )

    sampling_noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        backbone_signature,
        subfolder="scheduler",
    )

    model = LBMModel(
        config,
        denoiser=denoiser,
        sampling_noise_scheduler=sampling_noise_scheduler,
        vae=vae,
        conditioner=conditioner,
    ).to(torch.bfloat16)

    return model

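# Illustrative usage sketch (not from this file): the checkpoint name, the
# load_state_dict call and the bridge_noise_sigma value below are assumptions,
# shown only to indicate how the builder above could be wired up.
#
#   model = get_model_from_config(bridge_noise_sigma=0.005)
#   state_dict = torch.load("lbm_relighting.ckpt", map_location="cpu")  # hypothetical checkpoint
#   model.load_state_dict(state_dict, strict=True)
#   model = model.to("cuda")
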
| def extract_object(birefnet, img): | |
| # Data settings | |
| image_size = (1024, 1024) | |
| transform_image = transforms.Compose( | |
| [ | |
| transforms.Resize(image_size), | |
| transforms.ToTensor(), | |
| transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), | |
| ] | |
| ) | |
| image = img | |
| input_images = transform_image(image).unsqueeze(0).cuda() | |
| # Prediction | |
| with torch.no_grad(): | |
| preds = birefnet(input_images)[-1].sigmoid().cpu() | |
| pred = preds[0].squeeze() | |
| pred_pil = transforms.ToPILImage()(pred) | |
| mask = pred_pil.resize(image.size) | |
| image = Image.composite(image, Image.new("RGB", image.size, (127, 127, 127)), mask) | |
| return image, mask | |
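# Illustrative usage sketch (assumption: BiRefNet is loaded through transformers'
# AutoModelForImageSegmentation with trust_remote_code, which is not shown in
# this excerpt; any segmentation model with the same output format would do).
#
#   from transformers import AutoModelForImageSegmentation
#   birefnet = AutoModelForImageSegmentation.from_pretrained(
#       "ZhengPeng7/BiRefNet", trust_remote_code=True
#   ).cuda().eval()
#   fg_image, fg_mask = extract_object(birefnet, Image.open("input.png").convert("RGB"))
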
| def resize_and_center_crop(image, target_width, target_height): | |
| original_width, original_height = image.size | |
| scale_factor = max(target_width / original_width, target_height / original_height) | |
| resized_width = int(round(original_width * scale_factor)) | |
| resized_height = int(round(original_height * scale_factor)) | |
| resized_image = image.resize((resized_width, resized_height), Image.LANCZOS) | |
| left = (resized_width - target_width) / 2 | |
| top = (resized_height - target_height) / 2 | |
| right = (resized_width + target_width) / 2 | |
| bottom = (resized_height + target_height) / 2 | |
| cropped_image = resized_image.crop((left, top, right, bottom)) | |
| return cropped_image | |
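
# Illustrative usage sketch (hypothetical file name and sizes): fit an arbitrary
# input image to a square working resolution before segmentation.
#
#   img = Image.open("input.png").convert("RGB")
#   img = resize_and_center_crop(img, 1024, 1024)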