haizad's picture
Add project files with proper Git LFS tracking for binary files
d2bc90c
raw
history blame
4.85 kB
import gradio as gr
import os
from PIL import Image
import requests
import base64
import io
from dotenv import load_dotenv
# Load variables from a local .env file into the process environment
# (presumably this is where SERVER_URL, read by run_viton, is supplied).
load_dotenv()

# Bundled example images live in an "examples" folder next to this file.
_app_dir = os.path.dirname(__file__)
example_path = os.path.join(_app_dir, 'examples')
def image_to_base64(image_path):
    """Return the raw bytes of the file at *image_path* as a base64 ``str``.

    The file is read in binary mode, so any image format works; the
    contents are not decoded or validated as an image.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode()
def base64_to_image(base64_str, output_path):
    """Decode a base64 image, write it to *output_path*, and return it.

    Args:
        base64_str: Base64-encoded image bytes.
        output_path: Destination file passed to ``Image.save``.

    Returns:
        The ``PIL.Image.Image`` opened from the decoded bytes.
    """
    buffer = io.BytesIO(base64.b64decode(base64_str))
    decoded = Image.open(buffer)
    decoded.save(output_path)
    return decoded
def run_viton(model_image_path, garment_image_path,
              n_steps=20, image_scale=2.0, seed=-1):
    """Send a model/garment image pair to the remote try-on service.

    The service base URL is read from the ``SERVER_URL`` environment
    variable, and the request is POSTed as JSON to ``<SERVER_URL>/viton``.

    Args:
        model_image_path: Path to the image of the person (model).
        garment_image_path: Path to the garment image.
        n_steps: Forwarded to the server as ``n_steps`` (sent as int).
        image_scale: Forwarded to the server as ``image_scale`` (sent as float).
        seed: Forwarded to the server as ``seed`` (sent as int).

    Returns:
        A list of PIL images decoded from the response's ``images_base64``
        field. Each image is also saved as ``ootd_output_<i>.png`` in the
        current working directory. An empty list is returned on any failure
        (the result feeds a ``gr.Gallery``), and the reason is printed.
    """
    # Normalise the URL so a missing variable or a trailing slash cannot
    # produce requests to "None/viton" or "http://host//viton".
    api_url = (os.environ.get("SERVER_URL") or "").strip().rstrip("/")
    if not api_url:
        print("Error: SERVER_URL environment variable is not set")
        return []
    # The Gradio image inputs yield None when the user clears them.
    if not model_image_path or not garment_image_path:
        print("Error: both a model image and a garment image are required")
        return []
    print(f"Using API URL: {api_url}")

    try:
        request_data = {
            "model_image_base64": image_to_base64(model_image_path),
            "garment_image_base64": image_to_base64(garment_image_path),
            "n_samples": 1,
            # UI sliders may deliver floats, so coerce to the expected types.
            "n_steps": int(n_steps),
            "image_scale": float(image_scale),
            "seed": int(seed),
        }
        endpoint = f"{api_url}/viton"
        response = requests.post(endpoint, json=request_data, timeout=300)
        print(f"Request sent to {endpoint}")
        print(f"Response status code: {response.status_code}")

        if response.status_code != 200:
            # Include part of the body so server-side failures are diagnosable.
            print(f"Request failed with status code: {response.status_code}: "
                  f"{response.text[:500]}")
            return []

        result = response.json()
        if result.get("error"):
            print(f"Error: {result['error']}")
            return []

        generated_images = [
            base64_to_image(img_b64, f"ootd_output_{i}.png")
            for i, img_b64 in enumerate(result.get("images_base64", []))
        ]
        print(f"Successfully generated {len(generated_images)} images")
        return generated_images
    except Exception as e:
        # UI-handler boundary: report the problem and show an empty gallery
        # instead of letting the event handler crash.
        print(f"Exception occurred: {e}")
        return []
block = gr.Blocks().queue()

# Images pre-loaded into the two input widgets when the page opens.
default_model = os.path.join(example_path, 'model/model_8.png')
default_garment = os.path.join(example_path, 'garment/00055_00.jpg')

with block:
    with gr.Row():
        gr.Markdown("# Virtual Try-On")

    with gr.Row():
        # Person image input plus its example picker.
        with gr.Column():
            vton_img = gr.Image(
                value=default_model,
                label="Model",
                type="filepath",  # run_viton receives a path, not pixel data
                sources=['upload', 'webcam'],
                height=384,
            )
            gr.Examples(
                examples=[
                    os.path.join(example_path, f'model/{file_name}')
                    for file_name in ('model_8.png', 'model_2.png',
                                      'model_7.png', 'model_4.png',
                                      'model_5.png')
                ],
                inputs=vton_img,
                examples_per_page=5,
            )

        # Garment image input plus its example picker.
        with gr.Column():
            garm_img = gr.Image(
                value=default_garment,
                label="Garment",
                type="filepath",
                sources=['upload', 'webcam'],
                height=384,
            )
            gr.Examples(
                examples=[
                    os.path.join(example_path, f'garment/{file_name}')
                    for file_name in ('00055_00.jpg', '07764_00.jpg',
                                      '03032_00.jpg', '048554_1.jpg',
                                      '049805_1.jpg')
                ],
                inputs=garm_img,
                examples_per_page=5,
            )

        # Generated try-on results.
        with gr.Column():
            result_gallery = gr.Gallery(
                label='Output',
                show_label=False,
                elem_id="gallery",
                preview=True,
                scale=1,
            )

    # Run button and generation parameters.
    with gr.Column():
        run_button = gr.Button(value="Run")
        n_steps = gr.Slider(minimum=20, maximum=40, step=1, value=20,
                            label="Steps")
        image_scale = gr.Slider(minimum=1.0, maximum=5.0, step=0.1, value=2.0,
                                label="Guidance scale")
        seed = gr.Slider(minimum=-1, maximum=2147483647, step=1, value=-1,
                         label="Seed")

    # Gradio passes these positionally, so the order mirrors run_viton's
    # parameters: model, garment, n_steps, image_scale, seed.
    ips = [vton_img, garm_img, n_steps, image_scale, seed]
    run_button.click(fn=run_viton, inputs=ips, outputs=result_gallery)

# Bind on all interfaces at port 7865; mcp_server=True additionally asks
# Gradio to expose the app's handlers over MCP.
block.launch(server_name='0.0.0.0', server_port=7865, mcp_server=True)