"""Gradio demo app that runs ACGPN virtual try-on inference on CPU."""

import os

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image

# Import ACGPN-specific modules (adjust based on actual repository structure).
# Note: You may need to copy relevant ACGPN code into the Space or modify
# these imports.
from models.acgpn import ACGPN  # Hypothetical import; replace with actual model class
from utils.preprocessing import preprocess_image, parse_human  # Hypothetical preprocessing utilities

# Set device to CPU
device = torch.device("cpu")


def load_model():
    """Build the ACGPN model, load its pre-trained weights and switch to eval mode.

    Returns:
        The ACGPN model instance placed on ``device`` with weights loaded from
        ``checkpoints/acgpn_checkpoint.pth``.
    """
    model = ACGPN()  # Initialize model (adjust parameters as per ACGPN docs)
    checkpoint_path = "checkpoints/acgpn_checkpoint.pth"  # Path to pre-trained weights
    # map_location keeps the load working when the checkpoint was saved on GPU.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.to(device)
    model.eval()
    return model


# Loaded once at import time so every request reuses the same model instance.
model = load_model()


def _to_tensor(array):
    """Convert a numpy array to a float32 tensor on the configured device."""
    return torch.from_numpy(array).float().to(device)


def virtual_try_on(person_image, cloth_image):
    """Generate a try-on image from a person photo and a clothing photo.

    Args:
        person_image: PIL image of the person, as supplied by Gradio.
        cloth_image: PIL image of the garment, as supplied by Gradio.

    Returns:
        A PIL image built from the model output.

    Raises:
        gr.Error: If either image is missing or any processing step fails.
            Gradio displays the message in the UI; returning a plain string
            to an image output component would not be rendered as an error.
    """
    # Gradio passes None for an input the user left empty.
    if person_image is None or cloth_image is None:
        raise gr.Error("Please upload both a person image and a clothing image.")

    try:
        # Convert Gradio inputs (PIL Images) to numpy arrays
        person_img = np.array(person_image)
        cloth_img = np.array(cloth_image)

        # Preprocess images (resize, normalize, etc.)
        # The person branch is unpacked as a pair; the mask is not used here.
        person_processed, _person_mask = preprocess_image(person_img, is_person=True)
        cloth_processed = preprocess_image(cloth_img, is_person=False)

        # Parse human pose and segmentation (using ACGPN utilities)
        pose_map, parse_map = parse_human(person_processed)

        # Run inference without tracking gradients
        with torch.no_grad():
            output = model(
                _to_tensor(person_processed),
                _to_tensor(cloth_processed),
                _to_tensor(pose_map),
                _to_tensor(parse_map),
            )
        output = output.cpu().numpy()

        # Post-process output. Clip before the uint8 cast so out-of-range
        # values saturate instead of wrapping around.
        # NOTE(review): assumes the model output is an image-shaped array
        # scaled to [0, 1] -- confirm against the real ACGPN output.
        output_img = (np.clip(output, 0.0, 1.0) * 255).astype(np.uint8)
        return Image.fromarray(output_img)
    except Exception as exc:
        # Top-level UI boundary: surface the failure to the user via Gradio.
        raise gr.Error(f"Error: {exc}") from exc


# Gradio interface
iface = gr.Interface(
    fn=virtual_try_on,
    inputs=[
        gr.Image(type="pil", label="Upload Person Image"),
        gr.Image(type="pil", label="Upload Clothing Image"),
    ],
    outputs=gr.Image(type="pil", label="Try-On Result"),
    title="ACGPN Virtual Try-On",
    description="Upload a person image and a clothing image to see the virtual try-on result.",
)

# Launch the app
if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)