|
import gradio as gr |
|
import os |
|
from PIL import Image |
|
import requests |
|
import base64 |
|
import io |
|
from dotenv import load_dotenv |
|
|
|
# Load variables from a local .env file into os.environ so that settings such
# as SERVER_URL (read in run_viton) can be supplied without exporting them.
load_dotenv()

# Directory holding the bundled example images (examples/model/...,
# examples/garment/...). abspath() anchors it to this file's location, so it
# does not depend on the current working directory even when __file__ is a
# relative path (e.g. on older interpreters).
example_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'examples')
|
|
|
def image_to_base64(image_path):
    """Read the file at *image_path* and return its bytes base64-encoded.

    Args:
        image_path: Path to a local file (typically an image upload).

    Returns:
        The base64 encoding of the raw file bytes, as an ASCII ``str``.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode()
|
|
|
def base64_to_image(base64_str, output_path):
    """Decode a base64-encoded image, save it to disk, and return it.

    Args:
        base64_str: Base64 text of an encoded image file.
        output_path: Destination path; Pillow picks the output format from
            the file extension.

    Returns:
        The decoded ``PIL.Image.Image``.
    """
    buffer = io.BytesIO(base64.b64decode(base64_str))
    picture = Image.open(buffer)
    picture.save(output_path)
    return picture
|
|
|
def download_image_from_url(url, output_path):
    """Download an image from *url*, save it to *output_path*, and validate it.

    Args:
        url: HTTP(S) URL of the image.
        output_path: Local path the downloaded bytes are written to.

    Returns:
        ``output_path`` on success, or ``None`` if the download, the write, or
        the image validation fails (the error is printed).
    """
    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()

        with open(output_path, 'wb') as f:
            f.write(response.content)

        # Confirm the saved bytes are a readable image. The context manager
        # closes the file handle that Image.open() keeps open; previously the
        # opened image was assigned to an unused local and never closed.
        with Image.open(output_path) as img:
            img.verify()

        return output_path

    except Exception as e:
        # Best-effort helper: report the problem and signal failure with None.
        print(f"Error downloading image from {url}: {str(e)}")
        return None
|
|
|
def url_to_base64(url):
    """Fetch *url* and return the response body base64-encoded.

    Args:
        url: HTTP(S) URL of an image.

    Returns:
        The base64 text of the downloaded bytes, or ``None`` if the request
        fails or returns an HTTP error status (the error is printed).
    """
    try:
        reply = requests.get(url, timeout=30)
        reply.raise_for_status()
        body = reply.content
        return base64.b64encode(body).decode()
    except Exception as exc:
        print(f"Error converting URL to base64: {str(exc)}")
        return None
|
|
|
def _resolve_image_b64(kind, url, image_path):
    """Return one input image as base64, preferring *url* over *image_path*.

    *kind* ("model" or "garment") is used only in log messages. Returns None
    when neither source is provided, or when the URL download fails
    (url_to_base64 prints the error and returns None in that case).
    """
    if url and url.strip():
        print(f"Using {kind} URL: {url}")
        return url_to_base64(url.strip())
    if image_path:
        print(f"Using {kind} file: {image_path}")
        return image_to_base64(image_path)
    return None


def run_viton(model_image_path, garment_image_path, model_url, garment_url,
              n_steps=20, image_scale=2.0, seed=-1):
    """Generate virtual try-on images of a model wearing a garment.

    Each image can be given either as a URL or as a local file path. When
    both are given for the same image, the URL takes priority.

    Args:
        model_image_path: Local file path of the model (person) image, or None.
        garment_image_path: Local file path of the garment image, or None.
        model_url: URL of the model image; takes priority over model_image_path.
        garment_url: URL of the garment image; takes priority over garment_image_path.
        n_steps: Number of sampling steps forwarded to the backend.
        image_scale: Guidance scale forwarded to the backend.
        seed: Random seed forwarded to the backend (the UI default is -1).

    Returns:
        A list of generated PIL images, or an empty list if anything fails
        (missing SERVER_URL, missing inputs, request/HTTP/backend errors).
    """
    try:
        api_url = os.environ.get("SERVER_URL")
        print(f"Using API URL: {api_url}")
        if not api_url:
            # Without this check the request would target "None/viton" and
            # fail with an opaque requests error.
            print("Error: SERVER_URL environment variable is not set")
            return []
        # Avoid producing "...//viton" when SERVER_URL ends with a slash.
        api_url = api_url.rstrip("/")

        model_b64 = _resolve_image_b64("model", model_url, model_image_path)
        garment_b64 = _resolve_image_b64("garment", garment_url, garment_image_path)
        if not model_b64 or not garment_b64:
            print("Error: Missing model or garment image")
            return []

        request_data = {
            "model_image_base64": model_b64,
            "garment_image_base64": garment_b64,
            "n_samples": 1,
            # Slider values and MCP tool arguments may arrive as floats or
            # strings; send integral parameters as ints.
            "n_steps": int(n_steps),
            "image_scale": float(image_scale),
            "seed": int(seed),
        }

        response = requests.post(f"{api_url}/viton",
                                 json=request_data,
                                 timeout=300)
        print(f"Request sent to {api_url}/viton")
        print(f"Response status code: {response.status_code}")

        if response.status_code != 200:
            print(f"Request failed with status code: {response.status_code}")
            return []

        result = response.json()
        if result.get("error"):
            print(f"Error: {result['error']}")
            return []

        generated_images = []
        for i, img_b64 in enumerate(result.get("images_base64", [])):
            # Each result is also written to the working directory.
            # NOTE(review): fixed filenames are overwritten by later runs.
            output_path = f"ootd_output_{i}.png"
            generated_images.append(base64_to_image(img_b64, output_path))

        print(f"Successfully generated {len(generated_images)} images")
        return generated_images

    except Exception as e:
        # UI event handler: report the failure and show an empty gallery
        # rather than raising into Gradio.
        print(f"Exception occurred: {str(e)}")
        return []
|
|
|
# Build the Gradio UI. Components are created in order inside nested
# Row/Column context managers, so statement order defines the page layout.
# .queue() enables Gradio's request queue for this Blocks app.
block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("# Virtual Try-On")
    with gr.Row():
        gr.Markdown("**Instructions:** You can either upload images using the file upload interface or provide direct URLs to images. URL inputs will take priority over uploaded files.")
    # One row with four columns: model input, garment input, output, controls.
    with gr.Row():
        # Model (person) image: URL textbox plus upload/webcam input.
        with gr.Column():
            model_url = gr.Textbox(
                label="Enter Model Image URL",
            )
            # type="filepath" means run_viton receives a local file path.
            vton_img = gr.Image(label="Model", sources=['upload', 'webcam'], type="filepath", height=384)
            # Clickable example images that fill vton_img. The `example`
            # variable is reassigned below and never read.
            example = gr.Examples(
                inputs=vton_img,
                examples_per_page=5,
                examples=[
                    os.path.join(example_path, 'model/model_8.png'),
                    os.path.join(example_path, 'model/model_2.png'),
                    os.path.join(example_path, 'model/model_7.png'),
                    os.path.join(example_path, 'model/model_4.png'),
                    os.path.join(example_path, 'model/model_5.png'),
                ])
        # Garment image: URL textbox plus upload/webcam input.
        with gr.Column():
            garment_url = gr.Textbox(
                label="Enter Garment Image URL",
            )
            garm_img = gr.Image(label="Garment", sources=['upload', 'webcam'], type="filepath", height=384)
            # Clickable example garments that fill garm_img.
            example = gr.Examples(
                inputs=garm_img,
                examples_per_page=5,
                examples=[
                    os.path.join(example_path, 'garment/00055_00.jpg'),
                    os.path.join(example_path, 'garment/07764_00.jpg'),
                    os.path.join(example_path, 'garment/03032_00.jpg'),
                    os.path.join(example_path, 'garment/048554_1.jpg'),
                    os.path.join(example_path, 'garment/049805_1.jpg'),
                ])
        # Output gallery that displays the list returned by run_viton.
        with gr.Column():
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", preview=True, scale=1)
        # Run button and generation parameters passed to run_viton.
        with gr.Column():
            run_button = gr.Button(value="Run")
            n_steps = gr.Slider(label="Steps", minimum=20, maximum=40, value=20, step=1)
            image_scale = gr.Slider(label="Guidance scale", minimum=1.0, maximum=5.0, value=2.0, step=0.1)
            seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=-1)

    # The order must match run_viton's positional parameters: (model_image_path,
    # garment_image_path, model_url, garment_url, n_steps, image_scale, seed).
    ips = [vton_img, garm_img, model_url, garment_url, n_steps, image_scale, seed]
    run_button.click(fn=run_viton, inputs=ips, outputs=result_gallery)

# Start the server when this module runs. mcp_server=True additionally exposes
# the event functions as MCP tools.
# NOTE(review): requires a Gradio version/extra with MCP support; confirm.
block.launch(mcp_server=True)
|
|