Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
import os
|
7 |
+
# Import ACGPN-specific modules (adjust based on actual repository structure)
|
8 |
+
# Note: You may need to copy relevant ACGPN code into the Space or modify imports
|
9 |
+
from models.acgpn import ACGPN # Hypothetical import; replace with actual model class
|
10 |
+
from utils.preprocessing import preprocess_image, parse_human # Hypothetical preprocessing utilities
|
11 |
+
|
12 |
+
# Set device to CPU
|
13 |
+
device = torch.device("cpu")
|
14 |
+
|
15 |
+
# Load pre-trained ACGPN model
|
16 |
+
def load_model():
|
17 |
+
model = ACGPN() # Initialize model (adjust parameters as per ACGPN docs)
|
18 |
+
checkpoint_path = "checkpoints/acgpn_checkpoint.pth" # Path to pre-trained weights
|
19 |
+
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
|
20 |
+
model.to(device)
|
21 |
+
model.eval()
|
22 |
+
return model
|
23 |
+
|
24 |
+
model = load_model()
|
25 |
+
|
26 |
+
# Function to process images and generate try-on
|
27 |
+
def virtual_try_on(person_image, cloth_image):
|
28 |
+
try:
|
29 |
+
# Convert Gradio inputs (PIL Images) to numpy arrays
|
30 |
+
person_img = np.array(person_image)
|
31 |
+
cloth_img = np.array(cloth_image)
|
32 |
+
|
33 |
+
# Preprocess images (resize, normalize, etc.)
|
34 |
+
person_processed, person_mask = preprocess_image(person_img, is_person=True)
|
35 |
+
cloth_processed = preprocess_image(cloth_img, is_person=False)
|
36 |
+
|
37 |
+
# Parse human pose and segmentation (using ACGPN utilities)
|
38 |
+
pose_map, parse_map = parse_human(person_processed)
|
39 |
+
|
40 |
+
# Convert to tensors
|
41 |
+
person_tensor = torch.from_numpy(person_processed).float().to(device)
|
42 |
+
cloth_tensor = torch.from_numpy(cloth_processed).float().to(device)
|
43 |
+
pose_tensor = torch.from_numpy(pose_map).float().to(device)
|
44 |
+
parse_tensor = torch.from_numpy(parse_map).float().to(device)
|
45 |
+
|
46 |
+
# Run inference
|
47 |
+
with torch.no_grad():
|
48 |
+
output = model(person_tensor, cloth_tensor, pose_tensor, parse_tensor)
|
49 |
+
output = output.cpu().numpy()
|
50 |
+
|
51 |
+
# Post-process output
|
52 |
+
output_img = (output * 255).astype(np.uint8)
|
53 |
+
output_img = Image.fromarray(output_img)
|
54 |
+
|
55 |
+
return output_img
|
56 |
+
except Exception as e:
|
57 |
+
return f"Error: {str(e)}"
|
58 |
+
|
59 |
+
# Gradio interface
|
60 |
+
iface = gr.Interface(
|
61 |
+
fn=virtual_try_on,
|
62 |
+
inputs=[
|
63 |
+
gr.Image(type="pil", label="Upload Person Image"),
|
64 |
+
gr.Image(type="pil", label="Upload Clothing Image"),
|
65 |
+
],
|
66 |
+
outputs=gr.Image(type="pil", label="Try-On Result"),
|
67 |
+
title="ACGPN Virtual Try-On",
|
68 |
+
description="Upload a person image and a clothing image to see the virtual try-on result.",
|
69 |
+
)
|
70 |
+
|
71 |
+
# Launch the app
|
72 |
+
if __name__ == "__main__":
|
73 |
+
iface.launch(server_name="0.0.0.0", server_port=7860)
|