# Hugging Face Space (status: Running)
import os

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
# -------- CONFIG --------
# Just the model file name for Hugging Face Spaces (resolved relative to the
# working directory of the app).
checkpoint_path = "age_prediction_model2.pth"

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# -------- SIMPLE CNN MODEL --------
class SimpleCNN(nn.Module):
    """Small convolutional regressor that maps a face image to a single value.

    The network is three ``Conv2d -> ReLU -> MaxPool2d(2)`` stages followed by
    a two-layer fully connected head with one output unit (the predicted age).

    The first linear layer is sized for ``128 * 16 * 16`` features, i.e. a
    3-channel ``128x128`` input that has been halved three times by pooling.

    NOTE: the order of the layers inside each ``nn.Sequential`` determines the
    ``state_dict`` key names (``features.0``, ``features.3``, ...), so it must
    not change or previously saved checkpoints will no longer load.
    """

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),  # 128x128 -> 64x64
            nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),  # 64x64 -> 32x32
            nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),  # 32x32 -> 16x16
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 16 * 16, 256), nn.ReLU(),
            nn.Linear(256, 1),  # Output: age (regression)
        )

    def forward(self, x):
        """Run the feature extractor and the regression head.

        Args:
            x: Batch of images shaped ``(N, 3, 128, 128)``.

        Returns:
            Tensor of shape ``(N, 1)`` with one prediction per image.
        """
        x = self.features(x)
        x = self.classifier(x)
        return x
# -------- LOAD MODEL --------
model = SimpleCNN().to(device)

# Load the trained weights when the checkpoint is present. A missing file is
# reported but deliberately not fatal (best-effort start-up); in that case the
# model keeps its freshly initialised weights.
if os.path.exists(checkpoint_path):
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source (consider weights_only=True on recent PyTorch).
    # map_location keeps a GPU-saved checkpoint loadable on a CPU-only host.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    print(f"Model loaded from {checkpoint_path}")
else:
    print(f"Error: Checkpoint file not found at {checkpoint_path}. Please check the path.")

# Always put the model in evaluation mode, regardless of whether the
# checkpoint was found, since this script only ever runs inference.
model.eval()
# -------- PREPROCESSING --------
# Pipeline applied to every input image before it reaches the model:
#   1. resize to the 128x128 resolution the network's linear head is sized for,
#   2. convert the PIL image to a float tensor,
#   3. normalise each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
def predict_age(image: Image.Image) -> float:
    """Predict the age for a single face image.

    Args:
        image: PIL image supplied by the Gradio image component. Gradio passes
            ``None`` when the user submits without uploading anything.

    Returns:
        The model's scalar output rounded to two decimal places.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable message
            instead of an opaque failure from inside the transform pipeline.
    """
    if image is None:
        raise gr.Error("Please upload an image before requesting a prediction.")

    # The normalisation step and the first convolution both expect exactly
    # three channels, so force RGB in case a grayscale/RGBA image slips through.
    rgb_image = image.convert("RGB")

    # Add the batch dimension the model expects and move to the model's device.
    image_tensor = transform(rgb_image).unsqueeze(0).to(device)

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        output = model(image_tensor)

    age = output.item()
    return round(age, 2)
# -------- GRADIO APP --------
# Web UI: one image input, one numeric output produced by predict_age.
demo = gr.Interface(
    fn=predict_age,
    inputs=gr.Image(type="pil", image_mode="RGB"),
    outputs="number",  # or "text" if your output is text-based
    title="Age Prediction from Face",
    description="Upload a face image and get the predicted age.",
)

# Only start the server when executed as a script (as Hugging Face Spaces
# does), so importing this module does not launch a web server.
if __name__ == "__main__":
    demo.launch()