import base64
import io
import os

import gradio as gr
import requests
from dotenv import load_dotenv
from PIL import Image

load_dotenv()

example_path = os.path.join(os.path.dirname(__file__), 'examples')


def image_to_base64(image_path):
    """Read an image file from disk and return its bytes as a base64 string."""
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode()


def base64_to_image(base64_str, output_path):
    """Decode a base64 image, save it to ``output_path`` and return the PIL image."""
    image_data = base64.b64decode(base64_str)
    image = Image.open(io.BytesIO(image_data))
    image.save(output_path)
    return image


def download_image_from_url(url, output_path):
    """Download an image from ``url`` and save it to ``output_path``.

    Returns:
        ``output_path`` on success, or ``None`` if the download failed or the
        saved file is not a readable image (the error is printed, not raised).
    """
    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()

        with open(output_path, 'wb') as f:
            f.write(response.content)

        # Image.open alone is lazy (it only parses the header); verify() checks
        # the file for corruption, and the context manager closes the handle.
        with Image.open(output_path) as image:
            image.verify()
        return output_path
    except Exception as e:
        print(f"Error downloading image from {url}: {str(e)}")
        return None


def url_to_base64(url):
    """Fetch ``url`` and return the response body as a base64 string.

    Returns:
        The base64 string, or ``None`` if the request failed (the error is
        printed, not raised).
    """
    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        return base64.b64encode(response.content).decode()
    except requests.RequestException as e:
        print(f"Error converting URL to base64: {str(e)}")
        return None


def _is_blank(text):
    """Return True if ``text`` is None/empty or contains only whitespace."""
    return not (text and text.strip())


def _get_api_url():
    """Return the inference server base URL from the SERVER_URL env variable.

    Raises:
        gr.Error: If SERVER_URL is unset or empty.
    """
    api_url = os.environ.get("SERVER_URL")
    if not api_url:
        raise gr.Error("❌ SERVER_URL not configured in environment variables")
    print(f"Using API URL: {api_url}")
    # Avoid a double slash when endpoint paths are appended below.
    return api_url.rstrip("/")


def _resolve_image_b64(label, image_path=None, url=None):
    """Return a base64 image taken from ``url`` (preferred) or ``image_path``.

    Args:
        label: Human-readable name used in log and error messages
            (e.g. "model" or "garment").
        image_path: Local image file path, used only when ``url`` is blank.
        url: Image URL; surrounding whitespace is ignored.

    Returns:
        The base64 string, or ``None`` when neither source was provided.

    Raises:
        gr.Error: If a URL was given but could not be fetched.
    """
    if not _is_blank(url):
        url = url.strip()
        print(f"Using {label} URL: {url}")
        encoded = url_to_base64(url)
        if not encoded:
            raise gr.Error(f"❌ Failed to load {label} image from URL. Please check the URL is valid.")
        return encoded
    if image_path:
        print(f"Using {label} file: {image_path}")
        return image_to_base64(image_path)
    return None


def _request_images(api_url, endpoint, request_data, output_prefix):
    """POST ``request_data`` to ``{api_url}/{endpoint}`` and decode the returned images.

    Each returned image is also saved as ``{output_prefix}_{i}.png`` in the
    current working directory.

    Returns:
        A non-empty list of PIL images.

    Raises:
        gr.Error: On a non-200 status, a server-reported error, or an empty result.
    """
    url = f"{api_url}/{endpoint}"
    response = requests.post(url, json=request_data, timeout=300)
    print(f"Request sent to {url}")
    print(f"Response status code: {response.status_code}")

    if response.status_code != 200:
        # Keep the server's explanation in the logs; the UI only shows the code.
        print(f"Response body: {response.text[:500]}")
        raise gr.Error(f"❌ Request failed with status code: {response.status_code}")

    result = response.json()
    if result.get("error"):
        raise gr.Error(f"❌ Server error: {result['error']}")

    # `or []` also covers an explicit null value for "images_base64".
    generated_images = [
        base64_to_image(img_b64, f"{output_prefix}_{i}.png")
        for i, img_b64 in enumerate(result.get("images_base64") or [])
    ]
    if not generated_images:
        raise gr.Error("❌ No images were generated. Please try again.")

    print(f"Successfully generated {len(generated_images)} images")
    return generated_images


def run_viton(model_image_path: str = None,
              garment_image_path: str = None,
              model_url: str = None,
              garment_url: str = None,
              n_steps=20,
              image_scale=2.0,
              seed=-1
              ):
    """
    Run the Virtual Try-On model on a person (model) image and a garment image.

    Each image may be given as a local file path or as a URL. When both are
    given for the same image, the URL is used.

    Args:
        model_image_path: Local file path of the person/model image.
        garment_image_path: Local file path of the garment image.
        model_url: URL of the person/model image; takes precedence over model_image_path.
        garment_url: URL of the garment image; takes precedence over garment_image_path.
        n_steps: Number of sampling steps forwarded to the server.
        image_scale: Guidance scale forwarded to the server.
        seed: Seed forwarded to the server (the UI default is -1).

    Returns:
        A list of PIL images generated by the server.
    """
    if not model_image_path and _is_blank(model_url):
        raise gr.Error("❌ Please provide either a model image file or URL")
    if not garment_image_path and _is_blank(garment_url):
        raise gr.Error("❌ Please provide either a garment image file or URL")

    try:
        api_url = _get_api_url()

        model_b64 = _resolve_image_b64("model", model_image_path, model_url)
        garment_b64 = _resolve_image_b64("garment", garment_image_path, garment_url)
        if not model_b64 or not garment_b64:
            raise gr.Error("❌ Failed to process images. Please try again.")

        request_data = {
            "model_image_base64": model_b64,
            "garment_image_base64": garment_b64,
            "n_samples": 1,
            # Cast explicitly: MCP/API clients may send these as floats.
            "n_steps": int(n_steps),
            "image_scale": float(image_scale),
            "seed": int(seed),
        }
        return _request_images(api_url, "viton", request_data, "ootd_output")

    except gr.Error:
        raise  # Re-raise Gradio errors unchanged
    except Exception as e:
        print(f"Exception occurred: {str(e)}")
        raise gr.Error(f"❌ An unexpected error occurred: {str(e)}") from e


def run_new_garment(model_image_path: str = None,
                    garment_prompt: str = None,
                    model_url: str = None,
                    n_steps=20,
                    image_scale=2.0,
                    seed=-1
                    ):
    """
    Run the Virtual Try-On model with a model image and a text description of the garment.

    The model image may be given as a local file path or as a URL. When both
    are given, the URL is used.

    Args:
        model_image_path: Local file path of the person/model image.
        garment_prompt: Text description of the garment to generate.
        model_url: URL of the person/model image; takes precedence over model_image_path.
        n_steps: Number of sampling steps forwarded to the server.
        image_scale: Guidance scale forwarded to the server.
        seed: Seed forwarded to the server (the UI default is -1).

    Returns:
        A list of PIL images generated by the server.
    """
    if not model_image_path and _is_blank(model_url):
        raise gr.Error("❌ Please provide either a model image file or URL")
    if _is_blank(garment_prompt):
        raise gr.Error("❌ Please provide a garment description")

    try:
        api_url = _get_api_url()

        model_b64 = _resolve_image_b64("model", model_image_path, model_url)
        if not model_b64:
            raise gr.Error("❌ Failed to process model image. Please try again.")

        request_data = {
            "model_image_base64": model_b64,
            "garment_prompt": garment_prompt.strip(),
            "n_samples": 1,
            # Cast explicitly: MCP/API clients may send these as floats.
            "n_steps": int(n_steps),
            "image_scale": float(image_scale),
            "seed": int(seed),
        }
        return _request_images(api_url, "new-garment", request_data, "flux_output")

    except gr.Error:
        raise  # Re-raise Gradio errors unchanged
    except Exception as e:
        print(f"Exception occurred: {str(e)}")
        raise gr.Error(f"❌ An unexpected error occurred: {str(e)}") from e


block = gr.Blocks().queue()

with block:
    with gr.Row():
        gr.Markdown("# Virtual Try-On")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Provide image or URL of upper body photo")
            model_url = gr.Textbox(
                label="Enter Model Image URL",
            )
            vton_img = gr.Image(label="Model", sources=['upload', 'webcam'], type="filepath", height=384)
            gr.Examples(
                inputs=vton_img,
                examples_per_page=4,
                examples=[
                    os.path.join(example_path, 'model/model_2.png'),
                    os.path.join(example_path, 'model/model_7.png'),
                    os.path.join(example_path, 'model/model_4.png'),
                    os.path.join(example_path, 'model/model_5.png'),
                ],
            )

        with gr.Column():
            gr.Markdown("### Provide image, URL or description of a garment")
            garment_url = gr.Textbox(
                label="Enter Garment Image URL",
            )
            garment_prompt = gr.Textbox(
                label="Describe Garment",
            )
            garm_img = gr.Image(label="Garment", sources=['upload', 'webcam'], type="filepath", height=384)
            gr.Examples(
                inputs=garm_img,
                examples_per_page=4,
                examples=[
                    os.path.join(example_path, 'garment/07764_00.jpg'),
                    os.path.join(example_path, 'garment/03032_00.jpg'),
                    os.path.join(example_path, 'garment/048554_1.jpg'),
                    os.path.join(example_path, 'garment/049805_1.jpg'),
                ],
            )

        with gr.Column():
            gr.Markdown("### 2D Result")
            result_gallery = gr.Gallery(label='Output 2D', show_label=False, elem_id="gallery", preview=True, scale=1)

    with gr.Column():
        run_button = gr.Button(value="Try On with your garment")
        run_button2 = gr.Button(value="Try On with AI generated garment")
        n_steps = gr.Slider(label="Steps", minimum=20, maximum=40, value=20, step=1)
        image_scale = gr.Slider(label="Guidance scale", minimum=1.0, maximum=5.0, value=2.0, step=0.1)
        seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=-1)

    ips1 = [vton_img, garm_img, model_url, garment_url, n_steps, image_scale, seed]
    run_button.click(fn=run_viton, inputs=ips1, outputs=result_gallery)

    ips2 = [vton_img, garment_prompt, model_url, n_steps, image_scale, seed]
    run_button2.click(fn=run_new_garment, inputs=ips2, outputs=result_gallery)


if __name__ == "__main__":
    block.launch(mcp_server=True)