Clone04 commited on
Commit
a8c7acd
·
verified ·
1 Parent(s): 6363636

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -0
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import cv2
4
+ import numpy as np
5
+ from PIL import Image
6
+ import os
7
+ # Import ACGPN-specific modules (adjust based on actual repository structure)
8
+ # Note: You may need to copy relevant ACGPN code into the Space or modify imports
9
+ from models.acgpn import ACGPN # Hypothetical import; replace with actual model class
10
+ from utils.preprocessing import preprocess_image, parse_human # Hypothetical preprocessing utilities
11
+
12
+ # Set device to CPU
13
+ device = torch.device("cpu")
14
+
15
+ # Load pre-trained ACGPN model
16
+ def load_model():
17
+ model = ACGPN() # Initialize model (adjust parameters as per ACGPN docs)
18
+ checkpoint_path = "checkpoints/acgpn_checkpoint.pth" # Path to pre-trained weights
19
+ model.load_state_dict(torch.load(checkpoint_path, map_location=device))
20
+ model.to(device)
21
+ model.eval()
22
+ return model
23
+
24
+ model = load_model()
25
+
26
+ # Function to process images and generate try-on
27
+ def virtual_try_on(person_image, cloth_image):
28
+ try:
29
+ # Convert Gradio inputs (PIL Images) to numpy arrays
30
+ person_img = np.array(person_image)
31
+ cloth_img = np.array(cloth_image)
32
+
33
+ # Preprocess images (resize, normalize, etc.)
34
+ person_processed, person_mask = preprocess_image(person_img, is_person=True)
35
+ cloth_processed = preprocess_image(cloth_img, is_person=False)
36
+
37
+ # Parse human pose and segmentation (using ACGPN utilities)
38
+ pose_map, parse_map = parse_human(person_processed)
39
+
40
+ # Convert to tensors
41
+ person_tensor = torch.from_numpy(person_processed).float().to(device)
42
+ cloth_tensor = torch.from_numpy(cloth_processed).float().to(device)
43
+ pose_tensor = torch.from_numpy(pose_map).float().to(device)
44
+ parse_tensor = torch.from_numpy(parse_map).float().to(device)
45
+
46
+ # Run inference
47
+ with torch.no_grad():
48
+ output = model(person_tensor, cloth_tensor, pose_tensor, parse_tensor)
49
+ output = output.cpu().numpy()
50
+
51
+ # Post-process output
52
+ output_img = (output * 255).astype(np.uint8)
53
+ output_img = Image.fromarray(output_img)
54
+
55
+ return output_img
56
+ except Exception as e:
57
+ return f"Error: {str(e)}"
58
+
59
+ # Gradio interface
60
+ iface = gr.Interface(
61
+ fn=virtual_try_on,
62
+ inputs=[
63
+ gr.Image(type="pil", label="Upload Person Image"),
64
+ gr.Image(type="pil", label="Upload Clothing Image"),
65
+ ],
66
+ outputs=gr.Image(type="pil", label="Try-On Result"),
67
+ title="ACGPN Virtual Try-On",
68
+ description="Upload a person image and a clothing image to see the virtual try-on result.",
69
+ )
70
+
71
+ # Launch the app
72
+ if __name__ == "__main__":
73
+ iface.launch(server_name="0.0.0.0", server_port=7860)