# haizad's picture
# add error display
# 588bc14
import gradio as gr
import os
from PIL import Image
import requests
import base64
import io
from dotenv import load_dotenv
# Load variables from a .env file (python-dotenv) into the process environment,
# so SERVER_URL can be configured there; the request functions read it via
# os.environ.get("SERVER_URL").
load_dotenv()
# Folder of bundled example images used by the gr.Examples widgets, located
# next to this script (derived from __file__, not the working directory).
example_path = os.path.join(os.path.dirname(__file__), 'examples')
def image_to_base64(image_path):
    """Read the file at *image_path* and return its bytes as base64 text.

    The file is read in binary mode, so the content is not interpreted; the
    result is the plain base64 string (no data-URI prefix).
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    return base64.b64encode(raw_bytes).decode()
def base64_to_image(base64_str, output_path):
    """Decode base64 image data, save it to *output_path*, and return it.

    The returned value is the PIL image object opened from the decoded bytes,
    i.e. the same object that was just written to disk.
    """
    buffer = io.BytesIO(base64.b64decode(base64_str))
    picture = Image.open(buffer)
    picture.save(output_path)
    return picture
def download_image_from_url(url, output_path):
    """Download an image from *url*, save it to *output_path*, and check it.

    Args:
        url: Address of the image to fetch (30-second request timeout).
        output_path: Local file the downloaded bytes are written to; an
            existing file at that path is overwritten.

    Returns:
        ``output_path`` on success, or None if the download, the write, or the
        image check fails (the error is printed). If this call already wrote
        ``output_path`` before failing, that file is removed again so no
        invalid image is left behind; a pre-existing file is left untouched
        when the failure happens before writing.
    """
    wrote_file = False
    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        with open(output_path, 'wb') as f:
            # Opening in 'wb' has already created/truncated the file, so from
            # here on a failure should clean it up.
            wrote_file = True
            f.write(response.content)
        # Image.open() is lazy and keeps the file open; verify() actually checks
        # the data, and the context manager releases the file handle.
        with Image.open(output_path) as image:
            image.verify()
        return output_path
    except Exception as e:
        print(f"Error downloading image from {url}: {str(e)}")
        if wrote_file:
            try:
                os.remove(output_path)
            except OSError:
                pass  # best-effort cleanup; the None return already signals failure
        return None
def url_to_base64(url):
    """Fetch *url* and return the response body as a base64 text string.

    Any failure while fetching (bad URL, network error, timeout after 30 s,
    HTTP error status) is printed and reported by returning None.
    """
    try:
        resp = requests.get(url, timeout=30)
        resp.raise_for_status()
        payload = resp.content
    except Exception as exc:
        print(f"Error converting URL to base64: {str(exc)}")
        return None
    return base64.b64encode(payload).decode()
def run_viton(model_image_path: str = None,
              garment_image_path: str = None,
              model_url: str = None,
              garment_url: str = None,
              n_steps=20,
              image_scale=2.0,
              seed=-1
              ):
    """
    Run the Virtual Try-On model with provided images path or URLs.

    Sends a person photo and a garment photo to the try-on server's /viton
    endpoint and returns the generated image(s). For each of the two images a
    non-blank URL takes precedence over the local file path.

    Args:
        model_image_path: Local file path of an upper-body photo of the person.
        garment_image_path: Local file path of a photo of the garment.
        model_url: URL of the person photo; used instead of model_image_path when non-blank.
        garment_url: URL of the garment photo; used instead of garment_image_path when non-blank.
        n_steps: Steps value forwarded to the server (the UI offers 20-40).
        image_scale: Guidance scale forwarded to the server (the UI offers 1.0-5.0).
        seed: Seed forwarded to the server; defaults to -1.

    Returns:
        A list of PIL images, each also saved as ootd_output_<i>.png in the
        current working directory.

    Raises:
        gr.Error: If an input is missing, SERVER_URL is not set, an image
            cannot be loaded, the server returns a non-200 status or an error,
            no image comes back, or any other step fails.
    """
    # Whitespace-only URLs count as "not provided", so the checks below and the
    # loading step agree on which source is used.
    model_url = (model_url or "").strip()
    garment_url = (garment_url or "").strip()
    if not model_image_path and not model_url:
        raise gr.Error("❌ Please provide either a model image file or URL")
    if not garment_image_path and not garment_url:
        raise gr.Error("❌ Please provide either a garment image file or URL")

    def _load_b64(kind, url, path):
        """Return the base64 text of one input image, preferring *url* over *path*."""
        if url:
            print(f"Using {kind} URL: {url}")
            encoded = url_to_base64(url)
            if not encoded:
                raise gr.Error(f"❌ Failed to load {kind} image from URL. Please check the URL is valid.")
            return encoded
        # Validation above guarantees a file path when there is no URL.
        print(f"Using {kind} file: {path}")
        return image_to_base64(path)

    try:
        # Drop a trailing slash so the endpoint never becomes ".../​/viton".
        api_url = (os.environ.get("SERVER_URL") or "").strip().rstrip("/")
        if not api_url:
            raise gr.Error("❌ SERVER_URL not configured in environment variables")
        print(f"Using API URL: {api_url}")

        model_b64 = _load_b64("model", model_url, model_image_path)
        garment_b64 = _load_b64("garment", garment_url, garment_image_path)
        if not model_b64 or not garment_b64:
            raise gr.Error("❌ Failed to process images. Please try again.")

        request_data = {
            "model_image_base64": model_b64,
            "garment_image_base64": garment_b64,
            "n_samples": 1,
            # Gradio sliders (and MCP clients) may deliver floats; send whole
            # numbers for the step count and the seed.
            "n_steps": int(n_steps),
            "image_scale": float(image_scale),
            "seed": int(seed),
        }
        endpoint = f"{api_url}/viton"
        response = requests.post(endpoint, json=request_data, timeout=300)
        print(f"Request sent to {endpoint}")
        print(f"Response status code: {response.status_code}")
        if response.status_code != 200:
            # Show a bounded slice of the server's explanation to the user.
            detail = response.text.strip()[:200]
            message = f"❌ Request failed with status code: {response.status_code}"
            raise gr.Error(f"{message}. Server said: {detail}" if detail else message)

        result = response.json()
        if result.get("error"):
            raise gr.Error(f"❌ Server error: {result['error']}")
        # "or []" also tolerates an explicit null for images_base64.
        generated_images = [
            base64_to_image(img_b64, f"ootd_output_{i}.png")
            for i, img_b64 in enumerate(result.get("images_base64") or [])
        ]
        if not generated_images:
            raise gr.Error("❌ No images were generated. Please try again.")
        print(f"Successfully generated {len(generated_images)} images")
        return generated_images
    except gr.Error:
        raise  # already a user-facing message; don't wrap it again
    except Exception as e:
        print(f"Exception occurred: {str(e)}")
        raise gr.Error(f"❌ An unexpected error occurred: {str(e)}") from e
def run_new_garment(model_image_path: str = None,
                    garment_prompt: str = None,
                    model_url: str = None,
                    n_steps=20,
                    image_scale=2.0,
                    seed=-1
                    ):
    """
    Run the Virtual Try-On model with provided model image and garment prompt.

    Sends a person photo plus a text description of a garment (instead of a
    garment photo) to the server's /new-garment endpoint and returns the
    generated image(s). A non-blank URL takes precedence over the local file.

    Args:
        model_image_path: Local file path of an upper-body photo of the person.
        garment_prompt: Text description of the garment to dress the person in.
        model_url: URL of the person photo; used instead of model_image_path when non-blank.
        n_steps: Steps value forwarded to the server (the UI offers 20-40).
        image_scale: Guidance scale forwarded to the server (the UI offers 1.0-5.0).
        seed: Seed forwarded to the server; defaults to -1.

    Returns:
        A list of PIL images, each also saved as flux_output_<i>.png in the
        current working directory.

    Raises:
        gr.Error: If an input is missing, SERVER_URL is not set, the model
            image cannot be loaded, the server returns a non-200 status or an
            error, no image comes back, or any other step fails.
    """
    # Whitespace-only URLs count as "not provided", so the check below and the
    # loading step agree on which source is used.
    model_url = (model_url or "").strip()
    if not model_image_path and not model_url:
        raise gr.Error("❌ Please provide either a model image file or URL")
    if not garment_prompt or not garment_prompt.strip():
        raise gr.Error("❌ Please provide a garment description")
    try:
        # Drop a trailing slash so the endpoint never gets a double slash.
        api_url = (os.environ.get("SERVER_URL") or "").strip().rstrip("/")
        if not api_url:
            raise gr.Error("❌ SERVER_URL not configured in environment variables")
        print(f"Using API URL: {api_url}")

        if model_url:
            print(f"Using model URL: {model_url}")
            model_b64 = url_to_base64(model_url)
            if not model_b64:
                raise gr.Error("❌ Failed to load model image from URL. Please check the URL is valid.")
        else:
            # Validation above guarantees a file path when there is no URL.
            print(f"Using model file: {model_image_path}")
            model_b64 = image_to_base64(model_image_path)
        if not model_b64:
            raise gr.Error("❌ Failed to process model image. Please try again.")

        request_data = {
            "model_image_base64": model_b64,
            "garment_prompt": garment_prompt.strip(),
            "n_samples": 1,
            # Gradio sliders (and MCP clients) may deliver floats; send whole
            # numbers for the step count and the seed.
            "n_steps": int(n_steps),
            "image_scale": float(image_scale),
            "seed": int(seed),
        }
        endpoint = f"{api_url}/new-garment"
        response = requests.post(endpoint, json=request_data, timeout=300)
        print(f"Request sent to {endpoint}")
        print(f"Response status code: {response.status_code}")
        if response.status_code != 200:
            # Show a bounded slice of the server's explanation to the user.
            detail = response.text.strip()[:200]
            message = f"❌ Request failed with status code: {response.status_code}"
            raise gr.Error(f"{message}. Server said: {detail}" if detail else message)

        result = response.json()
        if result.get("error"):
            raise gr.Error(f"❌ Server error: {result['error']}")
        # "or []" also tolerates an explicit null for images_base64.
        generated_images = [
            base64_to_image(img_b64, f"flux_output_{i}.png")
            for i, img_b64 in enumerate(result.get("images_base64") or [])
        ]
        if not generated_images:
            raise gr.Error("❌ No images were generated. Please try again.")
        print(f"Successfully generated {len(generated_images)} images")
        return generated_images
    except gr.Error:
        raise  # already a user-facing message; don't wrap it again
    except Exception as e:
        print(f"Exception occurred: {str(e)}")
        raise gr.Error(f"❌ An unexpected error occurred: {str(e)}") from e
# ---- Gradio UI ---------------------------------------------------------------
# Layout: a title row; a row with three columns (person photo, garment
# photo/URL/description, result gallery); then a column with the two run
# buttons and the generation settings. Both buttons write to the same gallery.
block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("# Virtual Try-On")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Provide image or URL of upper body photo")
            model_url = gr.Textbox(
                label="Enter Model Image URL",
            )
            # type="filepath": the handlers receive a local path, not pixels.
            vton_img = gr.Image(label="Model", sources=['upload', 'webcam'], type="filepath", height=384)
            # Created inside the Blocks context, so it renders without a handle.
            gr.Examples(
                inputs=vton_img,
                examples_per_page=4,
                examples=[
                    os.path.join(example_path, 'model/model_2.png'),
                    os.path.join(example_path, 'model/model_7.png'),
                    os.path.join(example_path, 'model/model_4.png'),
                    os.path.join(example_path, 'model/model_5.png'),
                ])
        with gr.Column():
            gr.Markdown("### Provide image, URL or description of a garment")
            garment_url = gr.Textbox(
                label="Enter Garment Image URL",
            )
            garment_prompt = gr.Textbox(
                label="Describe Garment",
            )
            garm_img = gr.Image(label="Garment", sources=['upload', 'webcam'], type="filepath", height=384)
            gr.Examples(
                inputs=garm_img,
                examples_per_page=4,
                examples=[
                    os.path.join(example_path, 'garment/07764_00.jpg'),
                    os.path.join(example_path, 'garment/03032_00.jpg'),
                    os.path.join(example_path, 'garment/048554_1.jpg'),
                    os.path.join(example_path, 'garment/049805_1.jpg'),
                ])
        with gr.Column():
            gr.Markdown("### 2D Result")
            result_gallery = gr.Gallery(label='Output 2D', show_label=False, elem_id="gallery", preview=True, scale=1)
    with gr.Column():
        run_button = gr.Button(value="Try On with your garment")
        run_button2 = gr.Button(value="Try On with AI generated garment")
        n_steps = gr.Slider(label="Steps", minimum=20, maximum=40, value=20, step=1)
        image_scale = gr.Slider(label="Guidance scale", minimum=1.0, maximum=5.0, value=2.0, step=0.1)
        seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=-1)

    # Input order must match the positional parameters of each handler.
    ips1 = [vton_img, garm_img, model_url, garment_url, n_steps, image_scale, seed]
    run_button.click(fn=run_viton, inputs=ips1, outputs=result_gallery)
    ips2 = [vton_img, garment_prompt, model_url, n_steps, image_scale, seed]
    run_button2.click(fn=run_new_garment, inputs=ips2, outputs=result_gallery)
# mcp_server=True also exposes the event handlers as MCP tools.
block.launch(mcp_server=True)