import os

# External libraries
import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPTextModel, CLIPTokenizer

# Custom imports
from src.datasets.dresscode import DressCodeDataset
from src.datasets.vitonhd import VitonHDDataset
from src.mgd_pipelines.mgd_pipe import MGDPipe
from src.mgd_pipelines.mgd_pipe_disentangled import MGDPipeDisentangled
from src.utils.image_from_pipe import generate_images_from_mgd_pipe
from src.utils.set_seeds import set_seed

# Ensure the minimum version of diffusers is installed
check_min_version("0.10.0.dev0")

logger = get_logger(__name__, log_level="INFO")
os.environ["TOKENIZERS_PARALLELISM"] = "true"
os.environ["WANDB_START_METHOD"] = "thread"


def main(args):
    # Initialize Accelerator
    accelerator = Accelerator(mixed_precision=args.get("mixed_precision", "fp16"))
    device = accelerator.device

    # Set the seed for reproducible inference
    if args.get("seed") is not None:
        set_seed(args["seed"])

    # Load scheduler, tokenizer, and models
    val_scheduler = DDIMScheduler.from_pretrained(args["pretrained_model_name_or_path"], subfolder="scheduler")
    val_scheduler.set_timesteps(50, device=device)
    tokenizer = CLIPTokenizer.from_pretrained(
        args["pretrained_model_name_or_path"], subfolder="tokenizer", revision=args.get("revision", None)
    )
    text_encoder = CLIPTextModel.from_pretrained(
        args["pretrained_model_name_or_path"], subfolder="text_encoder", revision=args.get("revision", None)
    )
    vae = AutoencoderKL.from_pretrained(
        args["pretrained_model_name_or_path"], subfolder="vae", revision=args.get("revision", None)
    )

    # Load UNet
    unet = torch.hub.load(
        repo_or_dir="aimagelab/multimodal-garment-designer",
        source="github",
        model="mgd",
        pretrained=True,
    )
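    # NOTE: torch.hub downloads and caches the checkpoint on first use
    # (by default under ~/.cache/torch/hub).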

    # Freeze models
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    # Enable memory-efficient attention if requested
    if args.get("enable_xformers_memory_efficient_attention", False):
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Install it to enable memory-efficient attention.")

    # Set dataset category
    category = [args.get("category", "dresses")]

    # Load dataset
    if args["dataset"] == "dresscode":
        test_dataset = DressCodeDataset(
            dataroot_path=args["dataset_path"],
            phase="test",
            order=args.get("test_order", 0),
            radius=5,
            sketch_threshold_range=(20, 20),
            tokenizer=tokenizer,
            category=category,
            size=(512, 384),
        )
    elif args["dataset"] == "vitonhd":
        test_dataset = VitonHDDataset(
            dataroot_path=args["dataset_path"],
            phase="test",
            order=args.get("test_order", 0),
            sketch_threshold_range=(20, 20),
            radius=5,
            tokenizer=tokenizer,
            size=(512, 384),
        )
    else:
        raise NotImplementedError(f"Dataset {args['dataset']} is not supported.")
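    # NOTE: sketch_threshold_range=(20, 20) pins the sketch binarization
    # threshold to a single value, keeping test-time conditioning deterministic.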

    # Prepare dataloader
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        shuffle=False,
        batch_size=args.get("batch_size", 1),
        num_workers=args.get("num_workers_test", 4),
    )

    # Cast models to the appropriate precision
    weight_dtype = torch.float16 if args.get("mixed_precision") == "fp16" else torch.float32
    text_encoder.to(device, dtype=weight_dtype)
    vae.to(device, dtype=weight_dtype)
    unet.eval()
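    # NOTE: the UNet is moved to the right device and dtype when the pipeline
    # is assembled below (unet.to(vae.dtype), then pipeline.to(device)).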

    # Select the pipeline and run everything below without autograd overhead
    with torch.inference_mode():
        pipeline_class = MGDPipeDisentangled if args.get("disentagle", False) else MGDPipe
        val_pipe = pipeline_class(
            text_encoder=text_encoder,
            vae=vae,
            unet=unet.to(vae.dtype),
            tokenizer=tokenizer,
            scheduler=val_scheduler,
        ).to(device)
        val_pipe.enable_attention_slicing()

        # Prepare dataloader with accelerator
        test_dataloader = accelerator.prepare(test_dataloader)
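        # NOTE: the guidance_scale* values follow the classifier-free guidance
        # convention; larger values make samples adhere more closely to the
        # text, pose, and sketch conditioning.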

        # Generate images
        os.makedirs(args["output_dir"], exist_ok=True)
        generate_images_from_mgd_pipe(
            test_order=args.get("test_order", 0),
            pipe=val_pipe,
            test_dataloader=test_dataloader,
            save_name=args.get("save_name", "generated_image"),
            dataset=args["dataset"],
            output_dir=args["output_dir"],
            guidance_scale=args.get("guidance_scale", 7.5),
            guidance_scale_pose=args.get("guidance_scale_pose", 0.5),
            guidance_scale_sketch=args.get("guidance_scale_sketch", 7.5),
            sketch_cond_rate=args.get("sketch_cond_rate", 1.0),
            start_cond_rate=args.get("start_cond_rate", 0.0),
            no_pose=False,
            disentagle=args.get("disentagle", False),
            seed=args.get("seed", None),
        )

    # Return the output image path for verification
    output_path = os.path.join(args["output_dir"], f"{args.get('save_name', 'generated_image')}.png")
    return output_path


if __name__ == "__main__":
    # Example usage for debugging
    example_args = {
        "pretrained_model_name_or_path": "./models",
        "dataset": "dresscode",
        "dataset_path": "./datasets/dresscode",
        "output_dir": "./outputs",
        "guidance_scale": 7.5,
        "guidance_scale_sketch": 7.5,
        "mixed_precision": "fp16",
        "batch_size": 1,
        "seed": 42,
    }
    output_image = main(example_args)
    print(f"Image generated at: {output_image}")