Spaces:
Running
Running
import os | |
import time | |
import torch | |
from models import UNet | |
from test_functions import process_image | |
from PIL import Image | |
import gradio as gr | |
from huggingface_hub import hf_hub_download | |
# Where the UNet checkpoint lives locally. It is fetched from the Hugging Face
# Hub on first start and reused on later starts. MODEL_PATH is derived from the
# directory and filename so the download target and the load path cannot drift.
MODEL_DIR = "model"
MODEL_FILENAME = "best_unet_model.pth"
MODEL_PATH = f"{MODEL_DIR}/{MODEL_FILENAME}"

os.makedirs(MODEL_DIR, exist_ok=True)

# Path that is actually loaded below. If a download happens, it is replaced by
# the path hf_hub_download reports, rather than assuming the file landed at
# MODEL_PATH.
checkpoint_path = MODEL_PATH
if not os.path.exists(MODEL_PATH):
    print("Starting model download at", time.strftime("%Y-%m-%d %H:%M:%S"))
    checkpoint_path = hf_hub_download(
        repo_id="Robys01/face-aging",
        filename=MODEL_FILENAME,
        local_dir=MODEL_DIR,
        cache_dir=MODEL_DIR,
    )
    print(f"Model downloaded to {checkpoint_path}")

model = UNet()
# NOTE(security): weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code embedded in the checkpoint, and this file
# comes from a remote repository. The result goes straight into
# load_state_dict, so weights_only=True should be sufficient.
# TODO: confirm the checkpoint loads with weights_only=True and switch to it.
model.load_state_dict(
    torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=False)
)
# Switch to inference mode (affects layers such as dropout and batch norm).
model.eval()
def age_image(image: Image.Image, source_age: int, target_age: int) -> Image.Image:
    """Transform the face in *image* from *source_age* to *target_age*.

    Images whose mode is neither RGB nor grayscale ("L"), e.g. RGBA or P, are
    converted to RGB first. The actual transformation is delegated to
    ``process_image`` using the module-level ``model``.

    Args:
        image: Input picture supplied by the Gradio image component.
        source_age: Approximate current age of the person in the image.
        target_age: Desired age of the output face.

    Returns:
        Whatever ``process_image`` produces (displayed as the "Aged Image").

    Raises:
        gr.Error: If no image was supplied.
    """
    if image is None:
        # Gradio passes None when the user submits without uploading an image.
        # Report that clearly instead of failing on ``image.mode``.
        raise gr.Error("Please upload an image.")
    if image.mode not in ("RGB", "L"):
        print(f"Converting image from {image.mode} to RGB")
        image = image.convert("RGB")
    return process_image(model, image, source_age, target_age)
# The two example pictures are opened once, at import time, and handed to
# Gradio as PIL image objects alongside their (current age, target age) pairs.
example1 = Image.open("examples/girl.jpg")
example2 = Image.open("examples/trump.jpg")

# Web UI: one image plus two age sliders in, one aged image out.
iface = gr.Interface(
    fn=age_image,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Slider(
            minimum=10,
            maximum=90,
            value=20,
            step=1,
            label="Current age",
            info="Choose the current age",
        ),
        gr.Slider(
            minimum=10,
            maximum=90,
            value=70,
            step=1,
            label="Target age",
            info="Choose the desired age",
        ),
    ],
    outputs=gr.Image(type="pil", label="Aged Image"),
    examples=[
        [example1, 14, 50],
        [example2, 74, 30],
    ],
    title="Face Aging Demo",
    description="Upload an image along with a source age approximation and a target age to generate an aged version of the face.",
)
if __name__ == "__main__":
    # Bind to every network interface on a fixed port when run as a script.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7000}
    iface.launch(**launch_options)