Spaces:
Sleeping
Sleeping
File size: 2,382 Bytes
1dd6eaf 0d5c920 1dd6eaf 0d5c920 1dd6eaf 0de8628 1dd6eaf 0d5c920 1dd6eaf 0d5c920 1dd6eaf 0d5c920 1dd6eaf 0d5c920 1dd6eaf 0d5c920 0de8628 1dd6eaf 0de8628 1dd6eaf 0d5c920 1dd6eaf 0d5c920 1dd6eaf 0d5c920 1dd6eaf 0d5c920 1dd6eaf 0de8628 1dd6eaf cc15c14 1dd6eaf 0d5c920 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import os
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import gradio as gr
# -------- CONFIG --------
# On Hugging Face Spaces the weights sit next to this script, so a bare file
# name (a path relative to the working directory) is sufficient.
checkpoint_path = "age_prediction_model2.pth"

# Run on the GPU when CUDA reports one as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
# -------- SIMPLE CNN MODEL --------
class SimpleCNN(nn.Module):
    """Small convolutional network that regresses a single age value.

    The feature extractor is three Conv-ReLU-MaxPool stages. Every pooling
    layer halves the spatial size, so a 128x128 input leaves the extractor as
    128 channels of 16x16. The first ``Linear`` layer of the classifier is
    sized for exactly that, which means inputs must be 3x128x128.

    The position of each layer inside the two ``nn.Sequential`` containers
    determines the ``state_dict`` keys (``features.0.weight``,
    ``classifier.1.weight``, ...). Keep the layer order unchanged so that
    existing checkpoints continue to load.
    """

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),  # 128x128 -> 64x64
            nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),  # 64x64 -> 32x32
            nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),  # 32x32 -> 16x16
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 16 * 16, 256), nn.ReLU(),
            nn.Linear(256, 1),  # Output: age (regression)
        )

    def forward(self, x):
        """Map a ``(batch, 3, 128, 128)`` tensor to ``(batch, 1)`` ages."""
        x = self.features(x)
        x = self.classifier(x)
        return x
# -------- LOAD MODEL --------
model = SimpleCNN().to(device)
# This script only does inference, so put the model in evaluation mode
# regardless of whether a checkpoint is found below.
model.eval()

# Check if the checkpoint file exists and load
if os.path.exists(checkpoint_path):
    # map_location remaps the stored tensors onto the selected device, so a
    # checkpoint saved on a GPU machine can still be loaded on a CPU-only host.
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    print(f"Model loaded from {checkpoint_path}")
else:
    # Best effort: the app still starts, but the model keeps its randomly
    # initialised weights, so its predictions are meaningless.
    print(f"Error: Checkpoint file not found at {checkpoint_path}. Please check the path.")
# -------- PREPROCESSING --------
transform = transforms.Compose(
    [
        # Bring every image to the 128x128 size the model's first Linear
        # layer (128 * 16 * 16 inputs after three 2x poolings) is built for.
        transforms.Resize(size=(128, 128)),
        # Convert the image to a tensor.
        transforms.ToTensor(),
        # Per channel: (x - 0.5) / 0.5.
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)
# -------- PREDICTION FUNCTION --------
def predict_age(image):
    """Predict an age for one image and format it for display.

    Args:
        image: A PIL image, an ``H x W x C`` array-like accepted by
            ``PIL.Image.fromarray`` (Gradio's ``Image`` component yields
            NumPy arrays unless configured otherwise), or ``None`` when no
            image is present.

    Returns:
        ``"Predicted Age: <value>"`` with two decimals, or a short prompt
        when no image was supplied.
    """
    # A live interface may invoke the callback with None (e.g. after the
    # user clears the input); answer with a prompt instead of raising.
    if image is None:
        return "Please upload an image."

    # The preprocessing pipeline starts with Resize, which needs a PIL image
    # (or tensor), so wrap raw arrays first.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Force three channels: the Normalize step and the first Conv2d layer
    # both expect RGB, while uploads may be grayscale or carry alpha.
    image = image.convert("RGB")

    batch = transform(image).unsqueeze(0).to(device)  # add batch dimension
    with torch.no_grad():
        output = model(batch)
    age = output.item()  # Convert to a single scalar
    return f"Predicted Age: {age:.2f}"
# -------- GRADIO INTERFACE --------
iface = gr.Interface(
    fn=predict_age,
    # `image_size` is not a gr.Image parameter and `source` only exists in
    # Gradio 3.x, so neither is passed; resizing to 128x128 is already done by
    # the preprocessing transform. type="pil" hands the callback a PIL image,
    # which is what the torchvision preprocessing pipeline operates on.
    inputs=gr.Image(type="pil", image_mode="RGB"),
    outputs="text",
    title="Age Prediction Model",
    description="Upload an image to predict the age.",
    live=True,  # re-run the prediction whenever the input changes
)
iface.launch()
|