MoinulwithAI commited on
Commit
0de8628
·
verified ·
1 Parent(s): 024b978

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -11
app.py CHANGED
@@ -1,14 +1,12 @@
1
  import os
2
  import torch
3
  import torch.nn as nn
4
- from torch.utils.data import Dataset, DataLoader, random_split
5
  from torchvision import transforms
6
  from PIL import Image
7
  import gradio as gr
8
 
9
  # -------- CONFIG --------
10
- data_dir = "D:/Dataset/face_age"
11
- checkpoint_path = "age_prediction_model2.pth"
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
  print(f"Using device: {device}")
14
 
@@ -38,9 +36,9 @@ class SimpleCNN(nn.Module):
38
  # -------- LOAD MODEL --------
39
  model = SimpleCNN().to(device)
40
 
41
- # Check if checkpoint exists before loading
42
  if os.path.exists(checkpoint_path):
43
- model.load_state_dict(torch.load(checkpoint_path))
44
  model.eval() # Set the model to evaluation mode
45
  print(f"Model loaded from {checkpoint_path}")
46
  else:
@@ -61,12 +59,10 @@ def predict_age(image):
61
  age = output.item() # Convert to a single scalar
62
  return f"Predicted Age: {age:.2f}"
63
 
64
-
65
-
66
- # Update the gr.Image initialization
67
  iface = gr.Interface(
68
  fn=predict_age,
69
- inputs=gr.Image(image_size=(128, 128), image_mode='RGB', source='upload'),
70
  outputs="text",
71
  title="Age Prediction Model",
72
  description="Upload an image to predict the age.",
@@ -74,5 +70,3 @@ iface = gr.Interface(
74
  )
75
 
76
  iface.launch()
77
-
78
-
 
1
  import os
2
  import torch
3
  import torch.nn as nn
 
4
  from torchvision import transforms
5
  from PIL import Image
6
  import gradio as gr
7
 
8
  # -------- CONFIG --------
9
+ checkpoint_path = "age_prediction_model2.pth" # Just the model file name for Hugging Face Spaces
 
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
  print(f"Using device: {device}")
12
 
 
36
  # -------- LOAD MODEL --------
37
  model = SimpleCNN().to(device)
38
 
39
+ # Check if the checkpoint file exists and load
40
  if os.path.exists(checkpoint_path):
41
+ model.load_state_dict(torch.load(checkpoint_path, map_location=device)) # Load to the correct device
42
  model.eval() # Set the model to evaluation mode
43
  print(f"Model loaded from {checkpoint_path}")
44
  else:
 
59
  age = output.item() # Convert to a single scalar
60
  return f"Predicted Age: {age:.2f}"
61
 
62
+ # -------- GRADIO INTERFACE --------
 
 
63
  iface = gr.Interface(
64
  fn=predict_age,
65
+ inputs=gr.Image(shape=(128, 128), image_mode='RGB', source='upload'), # Updated input format
66
  outputs="text",
67
  title="Age Prediction Model",
68
  description="Upload an image to predict the age.",
 
70
  )
71
 
72
  iface.launch()