|
import gradio as gr |
|
import torch |
|
import cv2 |
|
import numpy as np |
|
from PIL import Image |
|
import os |
|
|
|
|
|
from models.acgpn import ACGPN |
|
from utils.preprocessing import preprocess_image, parse_human |
|
|
|
|
|
# All tensors and the model are placed on the CPU.
device = torch.device("cpu")
|
|
|
|
|
def load_model(checkpoint_path="checkpoints/acgpn_checkpoint.pth"):
    """Build the ACGPN network and restore its weights for inference.

    Args:
        checkpoint_path: Filesystem path of the saved state dict. Defaults
            to the location this app has always used, so calling
            ``load_model()`` with no arguments behaves as before.

    Returns:
        The ACGPN model with the checkpoint weights loaded, moved to the
        module-level ``device`` and switched to evaluation mode.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` is not an existing file.
    """
    # Fail fast, before constructing the network, with a message that names
    # the missing file.
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(
            f"ACGPN checkpoint not found: {checkpoint_path!r}"
        )

    model = ACGPN()
    # map_location remaps the stored tensors onto the configured device.
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model
|
|
|
# Built once at import time; virtual_try_on reads this module-level instance.
model = load_model()
|
|
|
|
|
def virtual_try_on(person_image, cloth_image):
    """Run the try-on pipeline and return the generated image.

    Args:
        person_image: PIL image of the person (the Gradio inputs are
            declared with ``type="pil"``); may be ``None`` when nothing was
            uploaded.
        cloth_image: PIL image of the clothing item; may be ``None`` when
            nothing was uploaded.

    Returns:
        A ``PIL.Image.Image`` built from the model output.

    Raises:
        gr.Error: If an image is missing or any pipeline stage fails. The
            output component is an image, so a failure is reported by
            raising rather than by returning an error string.
    """
    # Kept outside the try block so this message is not re-wrapped below.
    if person_image is None or cloth_image is None:
        raise gr.Error("Please upload both a person image and a clothing image.")

    try:
        person_img = np.array(person_image)
        cloth_img = np.array(cloth_image)

        # The second value returned for the person image is not used here.
        person_processed, _ = preprocess_image(person_img, is_person=True)
        cloth_processed = preprocess_image(cloth_img, is_person=False)

        pose_map, parse_map = parse_human(person_processed)

        # Same order the model is called with: person, cloth, pose, parse.
        model_inputs = [
            torch.from_numpy(array).float().to(device)
            for array in (person_processed, cloth_processed, pose_map, parse_map)
        ]

        with torch.no_grad():
            output = model(*model_inputs)
        output = output.cpu().numpy()

        # Clip before the uint8 cast so out-of-range values saturate instead
        # of wrapping around. Assumes the model emits values scaled to
        # [0, 1], as the original * 255 implies -- TODO confirm.
        output_img = np.clip(output * 255, 0, 255).astype(np.uint8)
        return Image.fromarray(output_img)
    except Exception as e:
        # Top-level UI boundary: convert any failure into a Gradio alert.
        raise gr.Error(f"Error: {e}") from e
|
|
|
|
|
# Web UI definition: two PIL image uploads feed virtual_try_on, and its
# result is shown in a single PIL image output.
iface = gr.Interface(
    title="ACGPN Virtual Try-On",
    description="Upload a person image and a clothing image to see the virtual try-on result.",
    fn=virtual_try_on,
    inputs=[
        gr.Image(label="Upload Person Image", type="pil"),
        gr.Image(label="Upload Clothing Image", type="pil"),
    ],
    outputs=gr.Image(label="Try-On Result", type="pil"),
)
|
|
|
|
|
if __name__ == "__main__":
    # Start the server only when run as a script, listening on 0.0.0.0:7860.
    iface.launch(server_port=7860, server_name="0.0.0.0")