File size: 4,852 Bytes
d2bc90c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import gradio as gr
import os
from PIL import Image
import requests
import base64
import io
from dotenv import load_dotenv

# Load variables from a local .env file (if present) so that run_viton can
# read SERVER_URL through os.environ.
load_dotenv()

# Folder next to this script that holds the bundled example images.
example_path = os.path.join(os.path.split(__file__)[0], 'examples')

def image_to_base64(image_path):
    """Return the raw bytes of the file at *image_path* as base64 text.

    The file is read verbatim (it is not decoded as an image), so the
    original encoding (PNG, JPEG, ...) is passed through unchanged.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")

def base64_to_image(base64_str, output_path):
    """Decode base64 image data, save it to *output_path*, and return it.

    Args:
        base64_str: Base64 text of an encoded image file.
        output_path: Destination file. No explicit format is passed to
            ``save``, so PIL infers it from the path's extension.

    Returns:
        The decoded ``PIL.Image.Image``.
    """
    buffer = io.BytesIO(base64.b64decode(base64_str))
    decoded = Image.open(buffer)
    decoded.save(output_path)
    return decoded

def run_viton(model_image_path, garment_image_path,
              n_steps=20, image_scale=2.0, seed=-1):
    """Send a model/garment image pair to the try-on server and return results.

    The server base URL comes from the ``SERVER_URL`` environment variable.
    Both images are base64-encoded and POSTed as JSON to
    ``{SERVER_URL}/viton``.

    Args:
        model_image_path: Path to the image of the person (model).
        garment_image_path: Path to the garment image.
        n_steps: Number of sampling steps forwarded to the server.
        image_scale: Guidance scale forwarded to the server.
        seed: Random seed forwarded to the server (the UI default is -1).

    Returns:
        A list of PIL images decoded from the server's ``images_base64``
        field; each is also saved as ``ootd_output_{i}.png`` in the current
        working directory. On any failure (missing SERVER_URL, network error,
        non-200 status, server-reported error) an empty list is returned so
        the gallery stays empty; details are printed to stdout.
    """
    api_url = os.environ.get("SERVER_URL")
    if not api_url:
        # Without this check the request would go to the literal "None/viton".
        print("Error: SERVER_URL is not set; cannot contact the try-on server")
        return []
    # Avoid a "//viton" path when SERVER_URL is configured with a trailing slash.
    endpoint = f"{api_url.rstrip('/')}/viton"
    print(f"Using API URL: {api_url}")

    try:
        request_data = {
            "model_image_base64": image_to_base64(model_image_path),
            "garment_image_base64": image_to_base64(garment_image_path),
            "n_samples": 1,
            # UI sliders may deliver numbers as floats; the step count and the
            # seed are integral quantities, so send them as ints.
            "n_steps": int(n_steps),
            "image_scale": image_scale,
            "seed": int(seed),
        }

        response = requests.post(endpoint, json=request_data, timeout=300)
        print(f"Request sent to {endpoint}")
        print(f"Response status code: {response.status_code}")

        if response.status_code != 200:
            # Keep a snippet of the body: it usually explains the failure.
            print(f"Request failed with status code: {response.status_code}: "
                  f"{response.text[:500]}")
            return []

        result = response.json()
        if result.get("error"):
            print(f"Error: {result['error']}")
            return []

        generated_images = [
            base64_to_image(img_b64, f"ootd_output_{i}.png")
            for i, img_b64 in enumerate(result.get("images_base64", []))
        ]
        print(f"Successfully generated {len(generated_images)} images")
        return generated_images
    except Exception as e:
        # UI-handler boundary: report the problem and hand the gallery an
        # empty list instead of propagating into Gradio.
        print(f"Exception occurred: {type(e).__name__}: {e}")
        return []

block = gr.Blocks().queue()

# Example images offered under each input. Building each path list once keeps
# the initial input values (the first entry of each list) in sync with the
# example galleries instead of repeating the literals.
model_examples = [
    os.path.join(example_path, f'model/{name}')
    for name in ('model_8.png', 'model_2.png', 'model_7.png',
                 'model_4.png', 'model_5.png')
]
garment_examples = [
    os.path.join(example_path, f'garment/{name}')
    for name in ('00055_00.jpg', '07764_00.jpg', '03032_00.jpg',
                 '048554_1.jpg', '049805_1.jpg')
]
default_model = model_examples[0]
default_garment = garment_examples[0]

with block:
    with gr.Row():
        gr.Markdown("# Virtual Try-On")
    with gr.Row():
        with gr.Column():
            # type="filepath" so run_viton receives a path it can read from disk.
            vton_img = gr.Image(label="Model", sources=['upload', 'webcam'],
                                type="filepath", height=384,
                                value=default_model)
            gr.Examples(inputs=vton_img, examples_per_page=5,
                        examples=model_examples)
        with gr.Column():
            garm_img = gr.Image(label="Garment", sources=['upload', 'webcam'],
                                type="filepath", height=384,
                                value=default_garment)
            gr.Examples(inputs=garm_img, examples_per_page=5,
                        examples=garment_examples)
        with gr.Column():
            result_gallery = gr.Gallery(label='Output', show_label=False,
                                        elem_id="gallery", preview=True,
                                        scale=1)
    with gr.Column():
        run_button = gr.Button(value="Run")
        n_steps = gr.Slider(label="Steps", minimum=20, maximum=40, value=20, step=1)
        image_scale = gr.Slider(label="Guidance scale", minimum=1.0, maximum=5.0, value=2.0, step=0.1)
        # -1 is the default; its meaning is decided by the server
        # (NOTE(review): presumably "random seed" — confirm server-side).
        seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=-1)

    # Order must match run_viton's positional parameters.
    ips = [vton_img, garm_img, n_steps, image_scale, seed]
    run_button.click(fn=run_viton, inputs=ips, outputs=result_gallery)

# Listen on all interfaces; mcp_server=True asks Gradio to also expose the
# app's functions as MCP tools (requires an MCP-capable Gradio install).
block.launch(server_name='0.0.0.0', server_port=7865, mcp_server=True)