|
import gradio as gr |
|
import os |
|
from PIL import Image |
|
import requests |
|
import base64 |
|
import io |
|
from dotenv import load_dotenv |
|
|
|
# Copy variables from a local .env file (if present) into os.environ, so that
# settings read later through os.environ.get (e.g. SERVER_URL) can be kept
# outside the source code.
load_dotenv()

# Folder holding the bundled example images, resolved next to this script.
example_path = os.path.join(
    os.path.dirname(__file__),
    "examples",
)
|
|
|
def image_to_base64(image_path):
    """Return the contents of the file at ``image_path`` as a base64 ``str``.

    The file is read in binary mode, so the bytes are encoded exactly as
    stored on disk; the resulting ASCII base64 bytes are decoded to ``str``.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode()
|
|
|
def base64_to_image(base64_str, output_path):
    """Decode a base64-encoded image, save it to disk, and return it.

    Args:
        base64_str: Base64 text of an encoded image file.
        output_path: Where to write the image; PIL picks the output format
            from the file extension.

    Returns:
        The decoded ``PIL.Image.Image``.
    """
    buffer = io.BytesIO(base64.b64decode(base64_str))
    decoded = Image.open(buffer)
    decoded.save(output_path)
    return decoded
|
|
|
def run_viton(model_image_path, garment_image_path,
              n_steps=20, image_scale=2.0, seed=-1):
    """Run virtual try-on for one model/garment pair via the remote server.

    Both images are read from disk, base64-encoded, and POSTed as JSON to
    ``{SERVER_URL}/viton``, where SERVER_URL comes from the environment.
    Every entry of the response's ``images_base64`` list is decoded, saved as
    ``ootd_output_{i}.png`` in the current working directory, and returned.

    Args:
        model_image_path: Path to the image of the person (the "Model" input).
        garment_image_path: Path to the garment image (the "Garment" input).
        n_steps: Value of the "Steps" slider, sent as ``n_steps``.
        image_scale: Value of the "Guidance scale" slider, sent as ``image_scale``.
        seed: Value of the "Seed" slider, sent as ``seed``.

    Returns:
        A list of PIL images, or an empty list on any failure:
        - missing SERVER_URL or input image;
        - a non-200 response;
        - an ``error`` field in the response;
        - a malformed payload;
        - a network error or any other exception.
        Failures are reported with ``print`` rather than raised, so the UI
        gallery is simply left empty.
    """
    api_url = os.environ.get("SERVER_URL")
    print(f"Using API URL: {api_url}")
    if not api_url:
        # Without this check the request would target "None/viton" and fail
        # with an opaque error from requests.
        print("Error: SERVER_URL environment variable is not set")
        return []
    # Strip a trailing slash so "http://host/" does not yield "http://host//viton".
    endpoint = f"{api_url.rstrip('/')}/viton"

    # An empty gr.Image input arrives as None; fail early with a clear message.
    if not model_image_path or not garment_image_path:
        print("Error: both a model image and a garment image are required")
        return []

    try:
        request_data = {
            "model_image_base64": image_to_base64(model_image_path),
            "garment_image_base64": image_to_base64(garment_image_path),
            "n_samples": 1,
            # The Steps and Seed sliders use step=1 but may deliver floats
            # (e.g. 20.0); send them as whole numbers.
            "n_steps": int(n_steps),
            "image_scale": float(image_scale),
            "seed": int(seed),
        }

        response = requests.post(endpoint, json=request_data, timeout=300)
        print(f"Request sent to {endpoint}")
        print(f"Response status code: {response.status_code}")

        if response.status_code != 200:
            # Include the start of the body: servers usually explain the failure there.
            print(f"Request failed with status code: {response.status_code}: "
                  f"{response.text[:500]}")
            return []

        result = response.json()
        if not isinstance(result, dict):
            print(f"Error: unexpected response payload of type {type(result).__name__}")
            return []
        if result.get("error"):
            print(f"Error: {result['error']}")
            return []

        # NOTE(review): output filenames are fixed, so requests handled
        # concurrently would overwrite each other's saved files.
        generated_images = [
            base64_to_image(img_b64, f"ootd_output_{i}.png")
            for i, img_b64 in enumerate(result.get("images_base64") or [])
        ]
        print(f"Successfully generated {len(generated_images)} images")
        return generated_images

    except requests.RequestException as e:
        # Connection errors, timeouts, and invalid JSON bodies.
        print(f"Request to {endpoint} failed: {e}")
        return []
    except Exception as e:
        # UI callback boundary: report the failure and show an empty gallery
        # instead of propagating a traceback to the user.
        print(f"Exception occurred: {type(e).__name__}: {e}")
        return []
|
|
|
block = gr.Blocks().queue()

# Bundled example inputs. The first entry of each list is also pre-loaded as
# the default value of the matching image input.
model_examples = [
    os.path.join(example_path, 'model', filename)
    for filename in ('model_8.png', 'model_2.png', 'model_7.png',
                     'model_4.png', 'model_5.png')
]
garment_examples = [
    os.path.join(example_path, 'garment', filename)
    for filename in ('00055_00.jpg', '07764_00.jpg', '03032_00.jpg',
                     '048554_1.jpg', '049805_1.jpg')
]
default_model = model_examples[0]
default_garment = garment_examples[0]

with block:
    with gr.Row():
        gr.Markdown("# Virtual Try-On")
    with gr.Row():
        with gr.Column():
            # type="filepath" hands run_viton a path, which it reads and base64-encodes.
            vton_img = gr.Image(label="Model", sources=['upload', 'webcam'], type="filepath",
                                height=384, value=default_model)
            # gr.Examples attaches itself to the enclosing layout; no handle is kept.
            gr.Examples(inputs=vton_img, examples_per_page=5, examples=model_examples)
        with gr.Column():
            garm_img = gr.Image(label="Garment", sources=['upload', 'webcam'], type="filepath",
                                height=384, value=default_garment)
            gr.Examples(inputs=garm_img, examples_per_page=5, examples=garment_examples)
        with gr.Column():
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery",
                                        preview=True, scale=1)
        with gr.Column():
            run_button = gr.Button(value="Run")
            n_steps = gr.Slider(label="Steps", minimum=20, maximum=40, value=20, step=1)
            image_scale = gr.Slider(label="Guidance scale", minimum=1.0, maximum=5.0,
                                    value=2.0, step=0.1)
            # -1 is forwarded to the server unchanged; presumably it requests a
            # random seed there — confirm against the server implementation.
            seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=-1)

    # Order must match run_viton's positional parameters.
    ips = [vton_img, garm_img, n_steps, image_scale, seed]
    run_button.click(fn=run_viton, inputs=ips, outputs=result_gallery)
|
|
|
# Start the server only when executed as a script, so importing this module
# (e.g. from tests or another app) does not launch the UI as a side effect.
if __name__ == "__main__":
    # Listen on all interfaces on port 7865; mcp_server=True additionally
    # exposes the app through Gradio's MCP server support.
    block.launch(server_name='0.0.0.0', server_port=7865, mcp_server=True)
|
|