Spaces:
Running
Running
File size: 2,354 Bytes
1dd6eaf 0d5c920 1dd6eaf 0d5c920 1dd6eaf 30463cc 1dd6eaf 0d5c920 1dd6eaf 0d5c920 1dd6eaf 0d5c920 1dd6eaf 0d5c920 1dd6eaf 0d5c920 1dd6eaf 0d5c920 1dd6eaf 0d5c920 1dd6eaf 0d5c920 1dd6eaf 0d5c920 1dd6eaf 609997a 024b978 1dd6eaf 024b978 1dd6eaf 0d5c920 609997a 024b978 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import gradio as gr
# -------- CONFIG --------
# Dataset root; not referenced anywhere else in this script (presumably a
# leftover from training — confirm before relying on it).
data_dir = "D:/Dataset/face_age"
# Trained weights loaded at startup; relative paths resolve against the CWD.
checkpoint_path = "age_prediction_model2.pth"
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
# -------- SIMPLE CNN MODEL --------
class SimpleCNN(nn.Module):
    """Small convolutional regressor mapping a 3x128x128 image to one scalar (age).

    Three conv -> ReLU -> 2x max-pool stages halve the spatial size each time
    (128 -> 64 -> 32 -> 16) while widening channels 3 -> 32 -> 64 -> 128. A
    two-layer fully connected head then produces the single regression output.
    """

    def __init__(self):
        super().__init__()
        stages = []
        in_channels = 3
        for out_channels in (32, 64, 128):
            stages.extend([
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ])
            in_channels = out_channels
        # Sequential indices (conv layers at 0, 3, 6) are the same as an
        # explicit listing, so saved state_dict keys still line up.
        self.features = nn.Sequential(*stages)
        # 128 channels x 16 x 16 positions: this size assumes a 128x128 input.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 16 * 16, 256),
            nn.ReLU(),
            nn.Linear(256, 1),  # one output: predicted age (regression)
        )

    def forward(self, x):
        """Return a (batch, 1) tensor of predictions for a batch of images."""
        return self.classifier(self.features(x))
# -------- LOAD MODEL --------
model = SimpleCNN().to(device)
# Check if checkpoint exists before loading
if os.path.exists(checkpoint_path):
    # map_location remaps stored tensors onto the active device, so a checkpoint
    # saved on a GPU machine also loads on a CPU-only host (and vice versa).
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    print(f"Model loaded from {checkpoint_path}")
else:
    # NOTE(review): execution continues with the freshly constructed (randomly
    # initialised) weights, so predictions are meaningless until the file exists.
    print(f"Error: Checkpoint file not found at {checkpoint_path}. Please check the path.")
# Inference only: put the model in evaluation mode on both branches.
model.eval()
# -------- PREPROCESSING --------
# Resize to the 128x128 resolution SimpleCNN's classifier is sized for, convert
# to a float tensor in [0, 1], then map each RGB channel to [-1, 1] via
# (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ]
)
# -------- PREDICTION FUNCTION --------
def predict_age(image):
    """Predict the age of the person in *image* and return it as display text.

    Args:
        image: The uploaded picture, either as a PIL image or as an array that
            ``PIL.Image.fromarray`` accepts (what the UI passes depends on how
            the ``gr.Image`` input is configured). May be ``None`` when no
            image is present.

    Returns:
        ``"Predicted Age: <age>"`` formatted with two decimals, or a prompt to
        upload an image when *image* is ``None``.
    """
    # The interface runs live, so it may call us with no image at all; that
    # must not reach the transform.
    if image is None:
        return "Please upload an image."
    # transforms.Resize expects a PIL image or tensor, not a raw array.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Force 3 channels so RGBA/grayscale uploads match the 3-channel Normalize
    # and the model's first Conv2d.
    batch = transform(image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(batch)
    age = output.item()  # one image, one output -> plain Python float
    return f"Predicted Age: {age:.2f}"
# -------- GRADIO INTERFACE --------
# type="pil" hands predict_age a PIL image that the torchvision transform can
# consume directly. `image_size` is not a gr.Image parameter and `source=` was
# removed in Gradio 4 (TypeError there). Upload is the default source in
# Gradio 3, and the transform's Resize already handles sizing.
iface = gr.Interface(
    fn=predict_age,
    inputs=gr.Image(type="pil", image_mode="RGB"),
    outputs="text",
    title="Age Prediction Model",
    description="Upload an image to predict the age.",
    live=True,
)
iface.launch()
|