# FaceAgePredict / app.py
# Author: MoinulwithAI
# Last commit: "Update app.py" (c09e4cb, verified)
import os
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import gradio as gr
# -------- CONFIG --------
# Resolve the checkpoint relative to this file rather than the current working
# directory, so the app finds its weights no matter where it is launched from.
checkpoint_path = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "age_prediction_model2.pth"
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# -------- SIMPLE CNN MODEL --------
class SimpleCNN(nn.Module):
    """Small convolutional regressor: 3x128x128 image -> one scalar (age).

    The layer layout is load-bearing: the ``state_dict`` keys it produces
    (``features.0/3/6`` and ``classifier.1/3``) must match the saved
    checkpoint that is loaded with ``load_state_dict``.
    """

    def __init__(self):
        super().__init__()
        conv_stages = []
        # Three conv -> ReLU -> 2x max-pool stages; spatial size goes
        # 128 -> 64 -> 32 -> 16.
        for in_ch, out_ch in ((3, 32), (32, 64), (64, 128)):
            conv_stages.extend(
                (
                    nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                    nn.ReLU(),
                    nn.MaxPool2d(2),
                )
            )
        self.features = nn.Sequential(*conv_stages)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 16 * 16, 256),  # 128 channels x 16 x 16 pixels
            nn.ReLU(),
            nn.Linear(256, 1),  # single regression output (age)
        )

    def forward(self, x):
        return self.classifier(self.features(x))
# -------- LOAD MODEL --------
model = SimpleCNN().to(device)
# Fail fast: without trained weights the network is randomly initialised, and
# serving its outputs would return meaningless ages with no visible error.
if not os.path.isfile(checkpoint_path):
    raise FileNotFoundError(
        f"Checkpoint file not found at {checkpoint_path}. Please check the path."
    )
# The checkpoint is a plain state_dict (it is passed straight to
# load_state_dict), so restrict unpickling to tensors/containers rather than
# allowing arbitrary pickled objects.
state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # evaluation mode before serving predictions
print(f"Model loaded from {checkpoint_path}")
# -------- PREPROCESSING --------
# 128x128 input is what the model's classifier expects: three 2x poolings
# bring it to the 16x16 maps flattened into Linear(128 * 16 * 16, ...).
transform = transforms.Compose(
    [
        transforms.Resize(size=(128, 128)),
        transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
        # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1] per channel.
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
def predict_age(image: Image.Image) -> float:
    """Predict the age of the face in ``image``.

    Args:
        image: Input picture as delivered by the Gradio ``Image`` component
            (``type="pil"``). Gradio passes ``None`` when the user submits
            without uploading anything.

    Returns:
        The model's regression output, rounded to two decimal places.

    Raises:
        gr.Error: If no image was provided; Gradio shows the message in the UI.
    """
    if image is None:
        # Without this guard, transform(None) fails with an opaque traceback
        # instead of a readable message for the user.
        raise gr.Error("Please upload a face image.")
    # The first conv layer takes 3 channels; converting here keeps direct
    # calls with RGBA/grayscale images (outside the UI) working too.
    image_tensor = transform(image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(image_tensor)
    return round(output.item(), 2)
import gradio as gr
# -------- UI --------
# Bound to `demo`, the name Gradio tooling (e.g. reload mode) looks for by default.
demo = gr.Interface(
    fn=predict_age,
    inputs=gr.Image(type="pil", image_mode="RGB", label="Face image"),
    outputs=gr.Number(label="Predicted age"),
    title="Age Prediction from Face",
    description="Upload a face image and get the predicted age.",
)

# Start the server only when run as a script (`python app.py`), so importing
# this module does not block on launch().
if __name__ == "__main__":
    demo.launch()