"""Gradio demo serving a small CNN that regresses a person's age from an image.

The script builds ``SimpleCNN``, loads its weights from ``checkpoint_path``,
and exposes ``predict_age`` through a Gradio interface.
"""
import os

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import gradio as gr

# -------- CONFIG --------
data_dir = "D:/Dataset/face_age"
checkpoint_path = "D:/Dataset/age_prediction_model2.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


# -------- SIMPLE CNN MODEL --------
class SimpleCNN(nn.Module):
    """Three conv/pool stages followed by a two-layer regression head.

    The first linear layer is sized for 128 * 16 * 16 features, i.e. a
    3x128x128 input halved by each of the three 2x2 max-pools.

    NOTE: the attribute names and the layer order inside each ``nn.Sequential``
    determine the state_dict keys, so they must stay as they are for existing
    checkpoints to load.
    """

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 64x64
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 32x32
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 16x16
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 16 * 16, 256),
            nn.ReLU(),
            nn.Linear(256, 1),  # Output: age (regression)
        )

    def forward(self, x):
        """Return one regression value per sample for a batch of images."""
        x = self.features(x)
        x = self.classifier(x)
        return x


# -------- LOAD MODEL --------
model = SimpleCNN().to(device)
# This script only runs inference, so evaluation mode is set unconditionally.
model.eval()

# True only once the trained weights have been loaded; predict_age refuses to
# answer otherwise, because the model would still hold random initial weights.
model_loaded = False

# Check if checkpoint exists before loading
if os.path.exists(checkpoint_path):
    # map_location remaps tensors saved on another device (e.g. a CUDA
    # checkpoint opened on a CPU-only machine) onto the device in use.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model_loaded = True
    print(f"Model loaded from {checkpoint_path}")
else:
    print(f"Error: Checkpoint file not found at {checkpoint_path}. Please check the path.")

# -------- PREPROCESSING --------
# Resize to the 128x128 input the classifier head is sized for, convert to a
# float tensor in [0, 1], then shift/scale each channel to [-1, 1].
# NOTE(review): presumably mirrors the preprocessing used in training - confirm.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])


# -------- PREDICTION FUNCTION --------
def predict_age(image):
    """Predict the age for a single image and return it as display text.

    Args:
        image: A PIL image, an array accepted by ``PIL.Image.fromarray``
            (assumed to be an HxWxC uint8 array when it comes from Gradio),
            or ``None`` when no image is currently provided.

    Returns:
        ``"Predicted Age: <value>"`` with two decimals, or a short message
        when there is no image or the model weights were never loaded.
    """
    # The callback may fire with no image (e.g. after the input is cleared).
    if image is None:
        return "Please upload an image."
    if not model_loaded:
        return f"Error: Checkpoint file not found at {checkpoint_path}. Please check the path."

    # transforms.Resize does not accept NumPy arrays, so normalise the input
    # to a PIL image first.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # The first convolution expects exactly three channels.
    image = image.convert("RGB")

    batch = transform(image).unsqueeze(0).to(device)  # add the batch dimension
    with torch.no_grad():
        output = model(batch)
    age = output.item()  # Convert to a single scalar
    return f"Predicted Age: {age:.2f}"


# -------- GRADIO INTERFACE --------
# type="pil" makes Gradio hand predict_age a PIL image rather than the default
# NumPy array.
# NOTE: the gr.inputs namespace is deprecated in Gradio 3.x and removed in 4.x;
# switch to gr.Image(type="pil", ...) when upgrading.
iface = gr.Interface(
    fn=predict_age,
    inputs=gr.inputs.Image(
        shape=(128, 128),
        image_mode='RGB',
        source='upload',
        type='pil',
    ),
    outputs="text",
    title="Age Prediction Model",
    description="Upload an image to predict the age.",
    live=True,
)

if __name__ == "__main__":
    iface.launch()