|
import gradio as gr |
|
import os |
|
from PIL import Image |
|
import requests |
|
import base64 |
|
import io |
|
from dotenv import load_dotenv |
|
|
|
# Load variables from a .env file (if present) into the process environment,
# so settings such as SERVER_URL can be read via os.environ below.
load_dotenv()

# Folder with the bundled example images offered in the UI; built from the
# directory containing this script rather than the current working directory.
example_path = os.path.join(os.path.dirname(__file__), "examples")
|
|
|
def image_to_base64(image_path):
    """Return the raw bytes of the file at *image_path* as a base64 text string.

    The file is read in binary mode and its contents are not inspected, so any
    file type is accepted.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode()
|
|
|
def base64_to_image(base64_str, output_path):
    """Decode a base64-encoded image, write it to *output_path*, and return it.

    Args:
        base64_str: Base64 text of an encoded image file.
        output_path: Destination file path passed to PIL's ``Image.save``.

    Returns:
        The PIL image opened from the decoded bytes.
    """
    decoded_buffer = io.BytesIO(base64.b64decode(base64_str))
    picture = Image.open(decoded_buffer)
    picture.save(output_path)
    return picture
|
|
|
def download_image_from_url(url, output_path):
    """Download an image from *url* and save it to *output_path*.

    This is a best-effort helper. Any failure is printed and reported by
    returning ``None`` instead of raising. Failures include a network error,
    a non-2xx status, a write error, or content that PIL cannot identify as
    an image.

    Args:
        url: Address of the image to fetch.
        output_path: Local file path the downloaded bytes are written to.

    Returns:
        ``output_path`` on success, otherwise ``None``.
    """
    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()

        with open(output_path, 'wb') as f:
            f.write(response.content)

        # Check that the saved bytes are a readable image. Image.open opens the
        # file lazily and keeps it open. The previous code bound the result to an
        # unused variable and leaked the handle. The context manager closes it.
        with Image.open(output_path) as img:
            img.verify()

        return output_path

    except Exception as e:
        # Deliberate best-effort boundary: callers test for None.
        print(f"Error downloading image from {url}: {str(e)}")
        return None
|
|
|
def url_to_base64(url):
    """Fetch *url* and return the response body as a base64 text string.

    Returns ``None`` (after printing the error) if the request fails or the
    server answers with an HTTP error status. The body is not checked to be an
    image.
    """
    # NOTE(review): the URL comes straight from the UI, so this process will
    # fetch arbitrary user-supplied addresses; consider an allow-list if the
    # app is publicly hosted.
    try:
        reply = requests.get(url, timeout=30)
        reply.raise_for_status()
        body = reply.content
    except Exception as err:
        print(f"Error converting URL to base64: {str(err)}")
        return None
    return base64.b64encode(body).decode()
|
|
|
|
|
def run_viton(model_image_path: str = None,
              garment_image_path: str = None,
              model_url: str = None,
              garment_url: str = None,
              n_steps=20,
              image_scale=2.0,
              seed=-1
              ):
    """
    Run the Virtual Try-On model with provided images path or URLs.

    For each image, a non-blank URL takes precedence over the uploaded file.
    Both images are base64-encoded and POSTed as JSON to
    ``{SERVER_URL}/viton``. Each base64 image in the JSON reply is decoded,
    written to ``ootd_output_<i>.png`` in the working directory, and returned.

    Args:
        model_image_path: Local file path of the person (model) image.
        garment_image_path: Local file path of the garment image.
        model_url: URL of the person image; used instead of model_image_path when non-blank.
        garment_url: URL of the garment image; used instead of garment_image_path when non-blank.
        n_steps: Number of sampling steps forwarded to the server.
        image_scale: Guidance scale forwarded to the server.
        seed: Seed forwarded to the server (the UI slider allows -1).

    Returns:
        List of PIL images decoded from the server response.

    Raises:
        gr.Error: If an input is missing, SERVER_URL is not set, a URL cannot be
            fetched, the server reports an error or non-200 status, no images are
            returned, or any other exception occurs.
    """
    # Normalise blank/whitespace-only URLs to "" up front so the checks below
    # report the real problem. Previously "   " passed validation and failed
    # later with a generic message.
    model_url = model_url.strip() if model_url else ""
    garment_url = garment_url.strip() if garment_url else ""

    if not model_image_path and not model_url:
        raise gr.Error("β Please provide either a model image file or URL")
    if not garment_image_path and not garment_url:
        raise gr.Error("β Please provide either a garment image file or URL")

    try:
        # Strip a trailing "/" so the endpoint is never "...//viton".
        api_url = (os.environ.get("SERVER_URL") or "").strip().rstrip("/")
        if not api_url:
            raise gr.Error("β SERVER_URL not configured in environment variables")

        print(f"Using API URL: {api_url}")

        # After validation, a blank URL implies the file path is set.
        if model_url:
            print(f"Using model URL: {model_url}")
            model_b64 = url_to_base64(model_url)
            if not model_b64:
                raise gr.Error("β Failed to load model image from URL. Please check the URL is valid.")
        else:
            print(f"Using model file: {model_image_path}")
            model_b64 = image_to_base64(model_image_path)

        if garment_url:
            print(f"Using garment URL: {garment_url}")
            garment_b64 = url_to_base64(garment_url)
            if not garment_b64:
                raise gr.Error("β Failed to load garment image from URL. Please check the URL is valid.")
        else:
            print(f"Using garment file: {garment_image_path}")
            garment_b64 = image_to_base64(garment_image_path)

        # Catches e.g. an empty uploaded file, which encodes to "".
        if not model_b64 or not garment_b64:
            raise gr.Error("β Failed to process images. Please try again.")

        request_data = {
            "model_image_base64": model_b64,
            "garment_image_base64": garment_b64,
            "n_samples": 1,
            # UI sliders / MCP clients may deliver numbers as floats; the step and
            # seed sliders use step=1, so send those as integers.
            "n_steps": int(n_steps),
            "image_scale": float(image_scale),
            "seed": int(seed),
        }

        response = requests.post(f"{api_url}/viton",
                                 json=request_data,
                                 timeout=300)

        print(f"Request sent to {api_url}/viton")
        print(f"Response status code: {response.status_code}")

        if response.status_code != 200:
            # Log a bounded slice of the body to help diagnose server failures.
            print(f"Response body: {response.text[:500]}")
            raise gr.Error(f"β Request failed with status code: {response.status_code}")

        result = response.json()
        if result.get("error"):
            raise gr.Error(f"β Server error: {result['error']}")

        # "or []" also covers a server replying with "images_base64": null.
        generated_images = [
            base64_to_image(img_b64, f"ootd_output_{i}.png")
            for i, img_b64 in enumerate(result.get("images_base64") or [])
        ]

        if not generated_images:
            raise gr.Error("β No images were generated. Please try again.")

        print(f"Successfully generated {len(generated_images)} images")
        return generated_images

    except gr.Error:
        # Already user-facing; propagate unchanged.
        raise
    except Exception as e:
        print(f"Exception occurred: {str(e)}")
        raise gr.Error(f"β An unexpected error occurred: {str(e)}") from e
|
|
|
def run_new_garment(model_image_path: str = None,
                    garment_prompt: str = None,
                    model_url: str = None,
                    n_steps=20,
                    image_scale=2.0,
                    seed=-1
                    ):
    """
    Run the Virtual Try-On model with provided model image and garment prompt.

    A non-blank model URL takes precedence over the uploaded file. The model
    image (base64) and the stripped garment description are POSTed as JSON to
    ``{SERVER_URL}/new-garment``. Each base64 image in the JSON reply is
    decoded, written to ``flux_output_<i>.png`` in the working directory, and
    returned.

    Args:
        model_image_path: Local file path of the person (model) image.
        garment_prompt: Text description of the garment to put on the model.
        model_url: URL of the person image; used instead of model_image_path when non-blank.
        n_steps: Number of sampling steps forwarded to the server.
        image_scale: Guidance scale forwarded to the server.
        seed: Seed forwarded to the server (the UI slider allows -1).

    Returns:
        List of PIL images decoded from the server response.

    Raises:
        gr.Error: If an input is missing, SERVER_URL is not set, the URL cannot be
            fetched, the server reports an error or non-200 status, no images are
            returned, or any other exception occurs.
    """
    # Normalise blank/whitespace-only inputs once so validation is accurate.
    model_url = model_url.strip() if model_url else ""
    garment_prompt = garment_prompt.strip() if garment_prompt else ""

    if not model_image_path and not model_url:
        raise gr.Error("β Please provide either a model image file or URL")
    if not garment_prompt:
        raise gr.Error("β Please provide a garment description")

    try:
        # Strip a trailing "/" so the endpoint is never "...//new-garment".
        api_url = (os.environ.get("SERVER_URL") or "").strip().rstrip("/")
        if not api_url:
            raise gr.Error("β SERVER_URL not configured in environment variables")

        print(f"Using API URL: {api_url}")

        # After validation, a blank URL implies the file path is set.
        if model_url:
            print(f"Using model URL: {model_url}")
            model_b64 = url_to_base64(model_url)
            if not model_b64:
                raise gr.Error("β Failed to load model image from URL. Please check the URL is valid.")
        else:
            print(f"Using model file: {model_image_path}")
            model_b64 = image_to_base64(model_image_path)

        # Catches e.g. an empty uploaded file, which encodes to "".
        if not model_b64:
            raise gr.Error("β Failed to process model image. Please try again.")

        request_data = {
            "model_image_base64": model_b64,
            "garment_prompt": garment_prompt,
            "n_samples": 1,
            # UI sliders / MCP clients may deliver numbers as floats; the step and
            # seed sliders use step=1, so send those as integers.
            "n_steps": int(n_steps),
            "image_scale": float(image_scale),
            "seed": int(seed),
        }

        response = requests.post(f"{api_url}/new-garment",
                                 json=request_data,
                                 timeout=300)

        print(f"Request sent to {api_url}/new-garment")
        print(f"Response status code: {response.status_code}")

        if response.status_code != 200:
            # Log a bounded slice of the body to help diagnose server failures.
            print(f"Response body: {response.text[:500]}")
            raise gr.Error(f"β Request failed with status code: {response.status_code}")

        result = response.json()
        if result.get("error"):
            raise gr.Error(f"β Server error: {result['error']}")

        # "or []" also covers a server replying with "images_base64": null.
        generated_images = [
            base64_to_image(img_b64, f"flux_output_{i}.png")
            for i, img_b64 in enumerate(result.get("images_base64") or [])
        ]

        if not generated_images:
            raise gr.Error("β No images were generated. Please try again.")

        print(f"Successfully generated {len(generated_images)} images")
        return generated_images

    except gr.Error:
        # Already user-facing; propagate unchanged.
        raise
    except Exception as e:
        print(f"Exception occurred: {str(e)}")
        raise gr.Error(f"β An unexpected error occurred: {str(e)}") from e
|
|
|
# Build the Gradio UI. Queueing is enabled on the Blocks instance.
block = gr.Blocks().queue()

with block:
    with gr.Row():
        gr.Markdown("# Virtual Try-On")

    # Top row: model input, garment input, and result gallery side by side.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Provide image or URL of upper body photo")
            model_url = gr.Textbox(
                label="Enter Model Image URL",
            )
            vton_img = gr.Image(label="Model", sources=['upload', 'webcam'], type="filepath", height=384)
            # The Examples component registers itself with the layout; no
            # reference to it is needed afterwards.
            gr.Examples(
                inputs=vton_img,
                examples_per_page=4,
                examples=[
                    os.path.join(example_path, 'model/model_2.png'),
                    os.path.join(example_path, 'model/model_7.png'),
                    os.path.join(example_path, 'model/model_4.png'),
                    os.path.join(example_path, 'model/model_5.png'),
                ])

        with gr.Column():
            gr.Markdown("### Provide image, URL or description of a garment")
            garment_url = gr.Textbox(
                label="Enter Garment Image URL",
            )
            garment_prompt = gr.Textbox(
                label="Describe Garment",
            )
            garm_img = gr.Image(label="Garment", sources=['upload', 'webcam'], type="filepath", height=384)
            gr.Examples(
                inputs=garm_img,
                examples_per_page=4,
                examples=[
                    os.path.join(example_path, 'garment/07764_00.jpg'),
                    os.path.join(example_path, 'garment/03032_00.jpg'),
                    os.path.join(example_path, 'garment/048554_1.jpg'),
                    os.path.join(example_path, 'garment/049805_1.jpg'),
                ])

        with gr.Column():
            gr.Markdown("### 2D Result")
            result_gallery = gr.Gallery(label='Output 2D', show_label=False, elem_id="gallery", preview=True, scale=1)

    # Controls shared by both try-on modes.
    with gr.Column():
        run_button = gr.Button(value="Try On with your garment")
        run_button2 = gr.Button(value="Try On with AI generated garment")
        n_steps = gr.Slider(label="Steps", minimum=20, maximum=40, value=20, step=1)
        image_scale = gr.Slider(label="Guidance scale", minimum=1.0, maximum=5.0, value=2.0, step=0.1)
        seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=-1)

    # Input order must match each handler's positional parameters.
    ips1 = [vton_img, garm_img, model_url, garment_url, n_steps, image_scale, seed]
    run_button.click(fn=run_viton, inputs=ips1, outputs=result_gallery)

    ips2 = [vton_img, garment_prompt, model_url, n_steps, image_scale, seed]
    run_button2.click(fn=run_new_garment, inputs=ips2, outputs=result_gallery)

# mcp_server=True additionally exposes the event handlers as MCP tools.
block.launch(mcp_server=True)
|
|