"""Gradio demo that ages (or de-ages) a face image with a pretrained UNet.

On import, the checkpoint is downloaded from the Hugging Face Hub if it is not
already present at ``MODEL_PATH``. It is then loaded onto the CPU and exposed
through a ``gr.Interface``. Running the module as a script launches the web
server.
"""
import os
import time

import torch
import gradio as gr
from huggingface_hub import hf_hub_download
from PIL import Image

from models import UNet
from test_functions import process_image

MODEL_PATH = "model/best_unet_model.pth"
# Derive the directory from MODEL_PATH so the two cannot drift apart.
MODEL_DIR = os.path.dirname(MODEL_PATH)

os.makedirs(MODEL_DIR, exist_ok=True)

if not os.path.exists(MODEL_PATH):
    print("Starting model download at", time.strftime("%Y-%m-%d %H:%M:%S"))
    # With local_dir set, the file ends up at MODEL_DIR/<filename>, which is
    # MODEL_PATH. The filename below must therefore match MODEL_PATH's basename.
    path = hf_hub_download(
        repo_id="Robys01/face-aging",
        filename="best_unet_model.pth",
        local_dir=MODEL_DIR,
        cache_dir=MODEL_DIR,
    )
    print(f"Model downloaded to {path}")

model = UNet()
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, so a tampered checkpoint downloaded above could execute code.
# If the file is a plain state_dict, weights_only=True should load it and is
# safer. Confirm the checkpoint contents before switching.
model.load_state_dict(
    torch.load(MODEL_PATH, map_location=torch.device("cpu"), weights_only=False)
)
model.eval()


def age_image(image: Image.Image, source_age: int, target_age: int) -> Image.Image:
    """Run the aging model on *image*, from ``source_age`` to ``target_age``.

    Args:
        image: Input face photo from the Gradio image component. Gradio passes
            ``None`` when the user submits without uploading an image.
        source_age: Approximate current age of the subject.
        target_age: Age the output should depict.

    Returns:
        The result of ``process_image`` for the (possibly converted) image.

    Raises:
        gr.Error: If no image was provided. Gradio shows this message to the
            user instead of an AttributeError traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    # RGB and grayscale ("L") images pass through unchanged. Any other mode
    # (e.g. RGBA, P, CMYK) is converted to RGB first.
    if image.mode not in ("RGB", "L"):
        print(f"Converting image from {image.mode} to RGB")
        image = image.convert("RGB")
    # Inference only: disable autograd tracking to save memory and time.
    # Assumes process_image does not need gradients (TODO confirm).
    with torch.no_grad():
        return process_image(model, image, source_age, target_age)


# Pre-load the example images as PIL objects
example1 = Image.open("examples/girl.jpg")
example2 = Image.open("examples/trump.jpg")

iface = gr.Interface(
    fn=age_image,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Slider(10, 90, value=20, step=1, label="Current age", info="Choose the current age"),
        gr.Slider(10, 90, value=70, step=1, label="Target age", info="Choose the desired age"),
    ],
    outputs=gr.Image(type="pil", label="Aged Image"),
    examples=[
        [example1, 14, 50],
        [example2, 74, 30],
    ],
    title="Face Aging Demo",
    description="Upload an image along with a source age approximation and a target age to generate an aged version of the face.",
)

if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7000)