Face-Aging / app.py
Robys01's picture
Pre-load the example images as PIL objects for the face-aging demo to solve the Gradio move issue.
3b679af
import os
import time
import torch
from models import UNet
from test_functions import process_image
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
# Local path of the pretrained UNet weights. The download directory and
# filename are derived from this single constant so the file that is
# downloaded and the file that is loaded can never drift apart.
MODEL_PATH = "model/best_unet_model.pth"
MODEL_DIR = os.path.dirname(MODEL_PATH) or "."
MODEL_FILENAME = os.path.basename(MODEL_PATH)

os.makedirs(MODEL_DIR, exist_ok=True)

# Fetch the checkpoint from the Hugging Face Hub only when it is not already
# on disk; later starts reuse the local copy.
if not os.path.exists(MODEL_PATH):
    print("Starting model download at", time.strftime("%Y-%m-%d %H:%M:%S"))
    path = hf_hub_download(
        repo_id="Robys01/face-aging",
        filename=MODEL_FILENAME,
        local_dir=MODEL_DIR,
        cache_dir=MODEL_DIR,
    )
    print(f"Model downloaded to {path}")

# Build the network, load the weights onto the CPU, and switch the module to
# evaluation mode for inference.
model = UNet()
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, i.e. execute code embedded in the downloaded file. If this
# checkpoint is a plain state dict, weights_only=True should load it and is
# safer — confirm the checkpoint contents before switching.
model.load_state_dict(
    torch.load(MODEL_PATH, map_location=torch.device("cpu"), weights_only=False)
)
model.eval()
def age_image(image: Image.Image, source_age: int, target_age: int) -> Image.Image:
    """Run the face-aging model on a single uploaded image.

    Args:
        image: Image from the Gradio ``Image(type="pil")`` input. Gradio
            passes ``None`` when the user submits without uploading one.
        source_age: Approximate current age of the person in the photo.
        target_age: Age the face should be rendered at.

    Returns:
        The image produced by ``process_image`` for the given ages.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable
            message instead of an ``AttributeError`` on ``None.mode``.
    """
    if image is None:
        raise gr.Error("Please upload an image before submitting.")

    # RGB and grayscale ("L") images are passed through unchanged; any other
    # mode (e.g. RGBA, P, CMYK) is converted to RGB before processing.
    if image.mode not in ("RGB", "L"):
        print(f"Converting image from {image.mode} to RGB")
        image = image.convert("RGB")

    return process_image(model, image, source_age, target_age)
def _load_example(path: str) -> Image.Image:
    """Read an example image fully into memory and release its file handle.

    ``Image.open`` is lazy and keeps the underlying file open until the pixel
    data is read. ``copy()`` loads the pixels and returns an image that no
    longer depends on the file, so the ``with`` block can close the handle.
    """
    with Image.open(path) as img:
        return img.copy()


# Pre-load the example images as in-memory PIL objects.
example1 = _load_example("examples/girl.jpg")
example2 = _load_example("examples/trump.jpg")

iface = gr.Interface(
    fn=age_image,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Slider(10, 90, value=20, step=1, label="Current age", info="Choose the current age"),
        gr.Slider(10, 90, value=70, step=1, label="Target age", info="Choose the desired age"),
    ],
    outputs=gr.Image(type="pil", label="Aged Image"),
    # Each row: [input image, current age, target age].
    examples=[
        [example1, 14, 50],
        [example2, 74, 30],
    ],
    title="Face Aging Demo",
    description="Upload an image along with a source age approximation and a target age to generate an aged version of the face.",
)

if __name__ == "__main__":
    # Listen on all network interfaces, port 7000.
    iface.launch(server_name="0.0.0.0", server_port=7000)