MoinulwithAI commited on
Commit
1dd6eaf
·
verified ·
1 Parent(s): 247c831

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -37
app.py CHANGED
@@ -1,57 +1,74 @@
 
1
  import torch
2
  import torch.nn as nn
3
- import torch.nn.functional as F
4
- import torchvision.transforms as transforms
5
  from PIL import Image
6
  import gradio as gr
7
 
8
- # ------------------- Model Definition -------------------
 
 
 
 
 
 
9
  class SimpleCNN(nn.Module):
10
- def __init__(self, num_classes=1):
11
  super(SimpleCNN, self).__init__()
12
- self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
13
- self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
14
- self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
15
- self.pool = nn.MaxPool2d(2, 2)
16
- self.fc1 = nn.Linear(128 * 28 * 28, 512)
17
- self.fc2 = nn.Linear(512, num_classes)
 
 
 
 
 
 
 
18
 
19
  def forward(self, x):
20
- x = self.pool(F.relu(self.conv1(x)))
21
- x = self.pool(F.relu(self.conv2(x)))
22
- x = self.pool(F.relu(self.conv3(x)))
23
- x = x.view(-1, 128 * 28 * 28)
24
- x = F.relu(self.fc1(x))
25
- x = self.fc2(x)
26
  return x
27
 
28
- # ------------------- Load Model -------------------
29
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
- model = SimpleCNN()
31
- model.load_state_dict(torch.load("age_prediction_model1.pth", map_location=device))
32
- model.to(device)
33
- model.eval()
34
 
35
- # ------------------- Transform -------------------
 
 
 
 
 
 
 
 
36
  transform = transforms.Compose([
37
- transforms.Resize((224, 224)),
38
  transforms.ToTensor(),
39
- transforms.Normalize([0.485, 0.456, 0.406],
40
- [0.229, 0.224, 0.225])
41
  ])
42
 
43
- # ------------------- Prediction Function -------------------
44
- def predict(image):
45
  image = transform(image).unsqueeze(0).to(device)
46
  with torch.no_grad():
47
- output = model(image).squeeze().item()
48
- return f"Predicted Age: {round(output, 2)} years"
49
-
50
- # ------------------- Gradio Interface -------------------
51
- iface = gr.Interface(fn=predict,
52
- inputs=gr.Image(type="pil"),
53
- outputs="text",
54
- title="Face Age Prediction",
55
- description="Upload a face image and get a predicted age")
 
 
 
 
56
 
57
  iface.launch()
 
1
+ import os
2
  import torch
3
  import torch.nn as nn
4
+ from torch.utils.data import Dataset, DataLoader, random_split
5
+ from torchvision import transforms
6
  from PIL import Image
7
  import gradio as gr
8
 
9
+ # -------- CONFIG --------
10
+ data_dir = "D:/Dataset/face_age"
11
+ checkpoint_path = "D:/Dataset/age_prediction_model2.pth"
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ print(f"Using device: {device}")
14
+
15
+ # -------- SIMPLE CNN MODEL --------
16
  class SimpleCNN(nn.Module):
17
+ def __init__(self):
18
  super(SimpleCNN, self).__init__()
19
+ self.features = nn.Sequential(
20
+ nn.Conv2d(3, 32, kernel_size=3, padding=1), nn.ReLU(),
21
+ nn.MaxPool2d(2), # 64x64
22
+ nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(),
23
+ nn.MaxPool2d(2), # 32x32
24
+ nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(),
25
+ nn.MaxPool2d(2), # 16x16
26
+ )
27
+ self.classifier = nn.Sequential(
28
+ nn.Flatten(),
29
+ nn.Linear(128 * 16 * 16, 256), nn.ReLU(),
30
+ nn.Linear(256, 1) # Output: age (regression)
31
+ )
32
 
33
  def forward(self, x):
34
+ x = self.features(x)
35
+ x = self.classifier(x)
 
 
 
 
36
  return x
37
 
38
+ # -------- LOAD MODEL --------
39
+ model = SimpleCNN().to(device)
 
 
 
 
40
 
41
+ # Check if checkpoint exists before loading
42
+ if os.path.exists(checkpoint_path):
43
+ model.load_state_dict(torch.load(checkpoint_path))
44
+ model.eval() # Set the model to evaluation mode
45
+ print(f"Model loaded from {checkpoint_path}")
46
+ else:
47
+ print(f"Error: Checkpoint file not found at {checkpoint_path}. Please check the path.")
48
+
49
+ # -------- PREPROCESSING --------
50
  transform = transforms.Compose([
51
+ transforms.Resize((128, 128)),
52
  transforms.ToTensor(),
53
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
 
54
  ])
55
 
56
+ # -------- PREDICTION FUNCTION --------
57
+ def predict_age(image):
58
  image = transform(image).unsqueeze(0).to(device)
59
  with torch.no_grad():
60
+ output = model(image)
61
+ age = output.item() # Convert to a single scalar
62
+ return f"Predicted Age: {age:.2f}"
63
+
64
+ # -------- GRADIO INTERFACE --------
65
+ iface = gr.Interface(
66
+ fn=predict_age,
67
+ inputs=gr.inputs.Image(shape=(128, 128), image_mode='RGB', source='upload'),
68
+ outputs="text",
69
+ title="Age Prediction Model",
70
+ description="Upload an image to predict the age.",
71
+ live=True
72
+ )
73
 
74
  iface.launch()