File size: 1,778 Bytes
0fd4d4e
9d1a90e
0fd4d4e
 
 
 
bef7a72
 
0fd4d4e
bef7a72
030a144
 
 
 
 
9d1a90e
030a144
9d1a90e
0fd4d4e
 
 
 
9d1a90e
0fd4d4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9bd6b9e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import os
import time
import torch
from models import UNet
from test_functions import process_image
from PIL import Image
import gradio as gr

from huggingface_hub import hf_hub_download

# Local location of the pretrained UNet weights. The directory and file name used
# for the Hugging Face Hub download are derived from this single path so the
# download target and the load location (torch.load below) cannot drift apart.
MODEL_PATH = "model/best_unet_model.pth"
MODEL_DIR = os.path.dirname(MODEL_PATH)

os.makedirs(MODEL_DIR, exist_ok=True)

if not os.path.exists(MODEL_PATH):
    print("Starting model download at", time.strftime("%Y-%m-%d %H:%M:%S"))
    # With local_dir set, hf_hub_download places the file at <local_dir>/<filename>,
    # i.e. exactly MODEL_PATH, so the returned path is not needed.
    hf_hub_download(
        repo_id="Robys01/face-aging",
        filename=os.path.basename(MODEL_PATH),
        local_dir=MODEL_DIR,
        cache_dir=MODEL_DIR,
    )
    print("Model downloaded to", MODEL_PATH, "at", time.strftime("%Y-%m-%d %H:%M:%S"))

# Instantiate the network and populate it with the pretrained weights, with all
# tensors mapped onto the CPU regardless of the device they were saved from.
model = UNet()
model.load_state_dict(
    torch.load(
        MODEL_PATH,
        map_location=torch.device("cpu"),
        # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
        # objects from the downloaded checkpoint (code execution if the file is
        # tampered with). Consider weights_only=True if the checkpoint is a plain
        # state dict of tensors — confirm before switching.
        weights_only=False,
    )
)
# Inference-only app: put the module in evaluation mode (affects layers such as
# dropout / batch-norm, if UNet contains any).
model.eval()
print("Model loaded at", time.strftime("%Y-%m-%d %H:%M:%S"))

def age_image(image: Image.Image, source_age: int, target_age: int) -> Image.Image:
    """Transform the face in *image* from *source_age* to *target_age*.

    Gradio callback for the demo interface defined in this module.

    Args:
        image: Uploaded picture. Gradio passes ``None`` when the user submits
            without uploading anything.
        source_age: Approximate current age of the subject (slider value).
        target_age: Desired age of the subject (slider value).

    Returns:
        The result of ``process_image`` for the loaded model and the given ages.

    Raises:
        gr.Error: If no image was provided. Gradio displays its message to the user.
    """
    if image is None:
        # Without this guard, ``image.mode`` below fails with an AttributeError,
        # which Gradio reports as a generic error instead of a readable message.
        raise gr.Error("Please upload an image first.")

    # RGB and grayscale ("L") images are passed through unchanged; any other mode
    # (e.g. RGBA, P, CMYK) is converted to RGB before processing.
    if image.mode not in ("RGB", "L"):
        print(f"Converting image from {image.mode} to RGB")
        image = image.convert("RGB")

    return process_image(model, image, source_age, target_age)

# Input widgets, in the order age_image expects its arguments:
# the photo, the subject's current age, and the desired age (both 10-90).
_age_inputs = [
    gr.Image(type="pil", label="Input Image"),
    gr.Slider(10, 90, value=20, step=1, label="Current age", info="Choose the current age"),
    gr.Slider(10, 90, value=70, step=1, label="Target age", info="Choose the desired age"),
]

# Single-function Gradio UI wrapping age_image; the output is the transformed photo.
iface = gr.Interface(
    fn=age_image,
    inputs=_age_inputs,
    outputs=gr.Image(type="pil", label="Aged Image"),
    title="Face Aging Demo",
    description=(
        "Upload an image along with a source age approximation "
        "and a target age to generate an aged version of the face."
    ),
)

if __name__ == "__main__":
    # Explicit launch() arguments take precedence over Gradio's own
    # GRADIO_SERVER_NAME / GRADIO_SERVER_PORT environment variables, so read them
    # here and fall back to the original defaults: listen on all interfaces
    # ("0.0.0.0") on port 7000.
    iface.launch(
        server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.environ.get("GRADIO_SERVER_PORT", "7000")),
    )