# File size: 4,852 Bytes | d2bc90c
import gradio as gr
import os
from PIL import Image
import requests
import base64
import io
from dotenv import load_dotenv
# Load variables from a .env file into the environment so that settings read
# later via os.environ (such as SERVER_URL in run_viton) can be supplied there.
load_dotenv()

# Directory holding the bundled sample model/garment images, located next to
# this script.
example_path = os.path.join(
    os.path.dirname(__file__),
    "examples",
)
def image_to_base64(image_path):
    """Return the raw bytes of the file at *image_path* as a base64 text string.

    Args:
        image_path: Path of the file to read (opened in binary mode).

    Returns:
        The base64 encoding of the file contents, decoded to ``str``.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode()
def base64_to_image(base64_str, output_path):
    """Decode a base64-encoded image, save it to *output_path*, and return it.

    Args:
        base64_str: Base64 text of an image file's bytes.
        output_path: Destination path handed to ``Image.save``.

    Returns:
        The PIL image opened from the decoded bytes.
    """
    decoded_bytes = base64.b64decode(base64_str)
    picture = Image.open(io.BytesIO(decoded_bytes))
    picture.save(output_path)
    return picture
def run_viton(model_image_path, garment_image_path,
              n_steps=20, image_scale=2.0, seed=-1):
    """Run virtual try-on by sending a model image and a garment image to the VITON API.

    The API base URL is read from the ``SERVER_URL`` environment variable and the
    request is POSTed to ``<SERVER_URL>/viton``. Each returned image is also saved
    as ``ootd_output_<i>.png`` in the current working directory.

    Args:
        model_image_path: Path to the image of the person (model) to dress.
        garment_image_path: Path to the image of the garment to try on.
        n_steps: Number of steps requested from the server (UI label "Steps").
        image_scale: Guidance scale requested from the server (UI label "Guidance scale").
        seed: Seed forwarded to the server; -1 is passed through unchanged and its
            meaning is decided by the server.

    Returns:
        A list of generated PIL images, or an empty list on any failure
        (the Gradio gallery output expects a list).
    """
    api_url = os.environ.get("SERVER_URL")
    if not api_url:
        # Without this check the URL would silently become "None/viton".
        print("SERVER_URL is not set; cannot contact the try-on API")
        return []
    if not model_image_path or not garment_image_path:
        # Gradio passes None when an image input has been cleared.
        print("Both a model image and a garment image are required")
        return []

    # Avoid a double slash when SERVER_URL is configured with a trailing "/".
    endpoint = f"{api_url.rstrip('/')}/viton"
    print(f"Using API URL: {api_url}")

    try:
        request_data = {
            "model_image_base64": image_to_base64(model_image_path),
            "garment_image_base64": image_to_base64(garment_image_path),
            "n_samples": 1,
            # Slider/MCP inputs may arrive as float, int or str; presumably the
            # API expects integer steps/seed and a float scale — confirm server-side.
            "n_steps": int(n_steps),
            "image_scale": float(image_scale),
            "seed": int(seed),
        }

        response = requests.post(endpoint, json=request_data, timeout=300)
        print(f"Request sent to {endpoint}")
        print(f"Response status code: {response.status_code}")

        if response.status_code != 200:
            # Include a slice of the body so server-side errors are diagnosable.
            print(f"Request failed with status code: {response.status_code}: "
                  f"{response.text[:500]}")
            return []

        result = response.json()
        if result.get("error"):
            print(f"Error: {result['error']}")
            return []

        generated_images = [
            base64_to_image(img_b64, f"ootd_output_{i}.png")
            for i, img_b64 in enumerate(result.get("images_base64", []))
        ]
    except Exception as e:
        # UI/MCP callback boundary: log and return an empty gallery rather than
        # failing the whole request.
        print(f"Exception occurred: {type(e).__name__}: {e}")
        return []

    print(f"Successfully generated {len(generated_images)} images")
    return generated_images
# ---------------------------------------------------------------------------
# Gradio UI: two image inputs with example pickers, an output gallery, and
# the sampling controls that feed run_viton.
# ---------------------------------------------------------------------------
block = gr.Blocks().queue()

# Example image paths; the first entry of each list doubles as the default
# value of the matching input.
model_examples = [
    os.path.join(example_path, f"model/{name}")
    for name in ("model_8.png", "model_2.png", "model_7.png",
                 "model_4.png", "model_5.png")
]
garment_examples = [
    os.path.join(example_path, f"garment/{name}")
    for name in ("00055_00.jpg", "07764_00.jpg", "03032_00.jpg",
                 "048554_1.jpg", "049805_1.jpg")
]
default_model = model_examples[0]
default_garment = garment_examples[0]

with block:
    with gr.Row():
        gr.Markdown("# Virtual Try-On")
    with gr.Row():
        # Person to dress.
        with gr.Column():
            vton_img = gr.Image(
                label="Model",
                sources=['upload', 'webcam'],
                type="filepath",
                height=384,
                value=default_model,
            )
            example = gr.Examples(
                inputs=vton_img,
                examples_per_page=5,
                examples=model_examples,
            )
        # Garment to try on.
        with gr.Column():
            garm_img = gr.Image(
                label="Garment",
                sources=['upload', 'webcam'],
                type="filepath",
                height=384,
                value=default_garment,
            )
            example = gr.Examples(
                inputs=garm_img,
                examples_per_page=5,
                examples=garment_examples,
            )
        # Generated results.
        with gr.Column():
            result_gallery = gr.Gallery(
                label='Output',
                show_label=False,
                elem_id="gallery",
                preview=True,
                scale=1,
            )
        # Run button and sampling parameters.
        with gr.Column():
            run_button = gr.Button(value="Run")
            n_steps = gr.Slider(label="Steps", minimum=20, maximum=40,
                                value=20, step=1)
            image_scale = gr.Slider(label="Guidance scale", minimum=1.0,
                                    maximum=5.0, value=2.0, step=0.1)
            seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647,
                             step=1, value=-1)

    # Input order must match run_viton's positional parameters.
    ips = [vton_img, garm_img, n_steps, image_scale, seed]
    run_button.click(fn=run_viton, inputs=ips, outputs=result_gallery)

block.launch(server_name='0.0.0.0', server_port=7865, mcp_server=True)
|