File size: 2,363 Bytes
1dd6eaf
0d5c920
 
1dd6eaf
 
0d5c920
 
 
1dd6eaf
 
 
 
 
 
 
0d5c920
1dd6eaf
0d5c920
1dd6eaf
 
 
 
 
 
 
 
 
 
 
 
 
0d5c920
 
1dd6eaf
 
0d5c920
 
1dd6eaf
 
0d5c920
1dd6eaf
 
 
 
 
 
 
 
 
0d5c920
1dd6eaf
0d5c920
1dd6eaf
0d5c920
 
1dd6eaf
 
0d5c920
 
1dd6eaf
 
 
 
 
 
 
 
 
 
 
 
 
0d5c920
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import gradio as gr

# -------- CONFIG --------
# Dataset directory and the file holding the trained weights.
data_dir = "D:/Dataset/face_age"
checkpoint_path = "D:/Dataset/age_prediction_model2.pth"

# Run on the GPU when CUDA is usable, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# -------- SIMPLE CNN MODEL --------
class SimpleCNN(nn.Module):
    """Small CNN that maps an RGB image batch to one regression value each.

    Three ``Conv2d -> ReLU -> MaxPool2d(2)`` stages are followed by a
    two-layer fully connected head. The head's first layer expects
    ``128 * 16 * 16`` features, i.e. a 128x128 input halved three times.
    """

    def __init__(self):
        super().__init__()

        # Build the stages in a fixed order: the position of each layer in
        # the Sequential determines its parameter names in state_dict().
        conv_stack = []
        for in_channels, out_channels in ((3, 32), (32, 64), (64, 128)):
            conv_stack.append(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
            )
            conv_stack.append(nn.ReLU())
            conv_stack.append(nn.MaxPool2d(2))  # halves height and width
        self.features = nn.Sequential(*conv_stack)

        flat_features = 128 * 16 * 16
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features, 256),
            nn.ReLU(),
            nn.Linear(256, 1),  # single output: age (regression)
        )

    def forward(self, x):
        """Run the convolutional stages, then the fully connected head."""
        return self.classifier(self.features(x))

# -------- LOAD MODEL --------
model = SimpleCNN().to(device)

# Load the trained weights when the checkpoint file is present.
# map_location remaps the stored tensors onto the selected device, so a
# checkpoint saved on a GPU machine also loads on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
if os.path.exists(checkpoint_path):
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    print(f"Model loaded from {checkpoint_path}")
else:
    # Best effort: report the problem and continue with untrained weights.
    print(f"Error: Checkpoint file not found at {checkpoint_path}. Please check the path.")

# This script only runs inference, so use evaluation mode in both cases.
model.eval()

# -------- PREPROCESSING --------
# Resize to 128x128, convert to a tensor, then normalize each of the three
# channels with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

# -------- PREDICTION FUNCTION --------
def predict_age(image):
    """Predict an age for one image and format it for display.

    Args:
        image: The image handed over by the UI. A PIL image is used as is;
            any other object (e.g. a NumPy array) is converted with
            ``Image.fromarray``. ``None`` means no image is available.

    Returns:
        ``"Predicted Age: <value>"`` with the model output formatted to two
        decimals, or a short prompt when ``image`` is ``None``.
    """
    # The callback may be invoked without an image (e.g. after the input is
    # cleared); answer with a prompt instead of raising.
    if image is None:
        return "Please upload an image."

    # The torchvision pipeline starts with Resize, which does not accept
    # NumPy arrays, so make sure a PIL image goes in.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Force three channels (grayscale / RGBA inputs would otherwise break
    # the 3-channel normalization and the model's first convolution).
    image = image.convert("RGB")

    batch = transform(image).unsqueeze(0).to(device)  # add batch dimension
    with torch.no_grad():
        age = model(batch).item()  # single scalar from the (1, 1) output
    return f"Predicted Age: {age:.2f}"

# -------- GRADIO INTERFACE --------
# Single-image upload UI that shows the prediction as text. live=True makes
# the prediction run on input changes instead of waiting for a submit click.
iface = gr.Interface(
    fn=predict_age,
    # type='pil' asks the component for a PIL image, which is what the
    # torchvision preprocessing expects; without it the legacy component
    # delivers a NumPy array.
    inputs=gr.inputs.Image(
        shape=(128, 128),
        image_mode='RGB',
        source='upload',
        type='pil',
    ),
    outputs="text",
    title="Age Prediction Model",
    description="Upload an image to predict the age.",
    live=True
)

iface.launch()