"""Gradio front end for a remote virtual try-on ("viton") service.

The UI collects a model image and a garment image (uploaded file or URL),
posts them as base64 to ``$SERVER_URL/viton`` and shows the returned images
in a gallery.
"""

import base64
import io
import os
from typing import List, Optional

import gradio as gr
import requests
from dotenv import load_dotenv
from PIL import Image

load_dotenv()

example_path = os.path.join(os.path.dirname(__file__), 'examples')


def image_to_base64(image_path: str) -> str:
    """Read the file at ``image_path`` and return its bytes as a base64 string."""
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode()


def base64_to_image(base64_str: str, output_path: str) -> Image.Image:
    """Decode a base64 string into a PIL image and save it to ``output_path``.

    Args:
        base64_str: Base64-encoded image bytes.
        output_path: Destination file path; the format is chosen by PIL from
            the file extension.

    Returns:
        The decoded PIL image (also written to ``output_path``).
    """
    image_data = base64.b64decode(base64_str)
    image = Image.open(io.BytesIO(image_data))
    image.save(output_path)
    return image


def download_image_from_url(url: str, output_path: str) -> Optional[str]:
    """Download an image from ``url`` and save it to ``output_path``.

    Args:
        url: Location of the image to fetch.
        output_path: Local file path the response body is written to.

    Returns:
        ``output_path`` on success, or ``None`` if the download failed or the
        saved file is not a valid image (the error is printed).
    """
    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()

        with open(output_path, 'wb') as f:
            f.write(response.content)

        # Open inside a context manager so the file handle is released, and
        # call verify() so a corrupt payload is rejected rather than only
        # having its header identified.
        with Image.open(output_path) as image:
            image.verify()

        return output_path
    except Exception as e:
        print(f"Error downloading image from {url}: {e}")
        return None


def url_to_base64(url: str) -> Optional[str]:
    """Fetch ``url`` and return the response body as a base64 string.

    Returns:
        The base64-encoded body, or ``None`` if the request failed (the error
        is printed).
    """
    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        return base64.b64encode(response.content).decode()
    except Exception as e:
        print(f"Error converting URL to base64: {e}")
        return None


def _resolve_image_b64(label: str, url: Optional[str], file_path: Optional[str]) -> Optional[str]:
    """Return base64 data for one input image, preferring the URL over the file.

    A non-blank ``url`` always wins; if fetching it fails the result is
    ``None`` and the uploaded file is *not* used as a fallback.
    """
    if url and url.strip():
        print(f"Using {label} URL: {url}")
        return url_to_base64(url.strip())
    if file_path:
        print(f"Using {label} file: {file_path}")
        return image_to_base64(file_path)
    return None


def run_viton(
    model_image_path: Optional[str],
    garment_image_path: Optional[str],
    model_url: Optional[str],
    garment_url: Optional[str],
    n_steps: int = 20,
    image_scale: float = 2.0,
    seed: int = -1,
) -> List[Image.Image]:
    """Generate try-on images by sending a model and a garment image to the server.

    The images are posted as base64 to the ``/viton`` endpoint of the server
    named by the ``SERVER_URL`` environment variable. Each returned image is
    saved as ``ootd_output_<i>.png`` in the current working directory.

    Args:
        model_image_path: Path to an uploaded model image; used only when
            ``model_url`` is blank.
        garment_image_path: Path to an uploaded garment image; used only when
            ``garment_url`` is blank.
        model_url: URL of the model image; takes priority over the file.
        garment_url: URL of the garment image; takes priority over the file.
        n_steps: Forwarded to the server as ``n_steps``.
        image_scale: Forwarded to the server as ``image_scale`` (labelled
            "Guidance scale" in the UI).
        seed: Forwarded to the server as ``seed``.

    Returns:
        A list of PIL images, or an empty list on any failure (the reason is
        printed).
    """
    try:
        api_url = os.environ.get("SERVER_URL")
        print(f"Using API URL: {api_url}")

        # Fail early with a clear message instead of posting to "None/viton".
        if not api_url:
            print("Error: SERVER_URL environment variable is not set")
            return []

        # Tolerate a trailing slash in the configured base URL.
        endpoint = api_url.rstrip("/") + "/viton"

        model_b64 = _resolve_image_b64("model", model_url, model_image_path)
        garment_b64 = _resolve_image_b64("garment", garment_url, garment_image_path)

        if not model_b64 or not garment_b64:
            print("Error: Missing model or garment image")
            return []

        request_data = {
            "model_image_base64": model_b64,
            "garment_image_base64": garment_b64,
            "n_samples": 1,
            "n_steps": n_steps,
            "image_scale": image_scale,
            "seed": seed,
        }

        response = requests.post(endpoint, json=request_data, timeout=300)
        print(f"Request sent to {endpoint}")
        print(f"Response status code: {response.status_code}")

        if response.status_code != 200:
            print(f"Request failed with status code: {response.status_code}")
            return []

        result = response.json()
        if result.get("error"):
            print(f"Error: {result['error']}")
            return []

        generated_images = []
        for i, img_b64 in enumerate(result.get("images_base64", [])):
            output_path = f"ootd_output_{i}.png"
            generated_images.append(base64_to_image(img_b64, output_path))

        print(f"Successfully generated {len(generated_images)} images")
        return generated_images
    except Exception as e:
        # Top-level boundary for the UI callback: the gallery output expects
        # a list, so report the error and return an empty one.
        print(f"Exception occurred: {e}")
        return []


block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("# Virtual Try-On")
    with gr.Row():
        gr.Markdown("**Instructions:** You can either upload images using the file upload interface or provide direct URLs to images. URL inputs will take priority over uploaded files.")
    with gr.Row():
        with gr.Column():
            model_url = gr.Textbox(
                label="Enter Model Image URL",
            )
            vton_img = gr.Image(label="Model", sources=['upload', 'webcam'], type="filepath", height=384)
            example = gr.Examples(
                inputs=vton_img,
                examples_per_page=5,
                examples=[
                    os.path.join(example_path, 'model/model_8.png'),
                    os.path.join(example_path, 'model/model_2.png'),
                    os.path.join(example_path, 'model/model_7.png'),
                    os.path.join(example_path, 'model/model_4.png'),
                    os.path.join(example_path, 'model/model_5.png'),
                ])
        with gr.Column():
            garment_url = gr.Textbox(
                label="Enter Garment Image URL",
            )
            garm_img = gr.Image(label="Garment", sources=['upload', 'webcam'], type="filepath", height=384)
            example = gr.Examples(
                inputs=garm_img,
                examples_per_page=5,
                examples=[
                    os.path.join(example_path, 'garment/00055_00.jpg'),
                    os.path.join(example_path, 'garment/07764_00.jpg'),
                    os.path.join(example_path, 'garment/03032_00.jpg'),
                    os.path.join(example_path, 'garment/048554_1.jpg'),
                    os.path.join(example_path, 'garment/049805_1.jpg'),
                ])
        with gr.Column():
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", preview=True, scale=1)
    with gr.Column():
        run_button = gr.Button(value="Run")
        n_steps = gr.Slider(label="Steps", minimum=20, maximum=40, value=20, step=1)
        image_scale = gr.Slider(label="Guidance scale", minimum=1.0, maximum=5.0, value=2.0, step=0.1)
        seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=-1)

    # Input order must match run_viton's positional parameters.
    ips = [vton_img, garm_img, model_url, garment_url, n_steps, image_scale, seed]
    run_button.click(fn=run_viton, inputs=ips, outputs=result_gallery)

block.launch(mcp_server=True)