Hajorda commited on
Commit
f4a7c6b
Β·
verified Β·
1 Parent(s): 7c7c004

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +108 -70
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # app.py (for a Hugging Face Space using Gradio)
2
  import gradio as gr
3
  import torch
4
  import pytorch_lightning as pl
@@ -12,26 +12,28 @@ import pickle
12
  from PIL import Image
13
  import numpy as np
14
  import os
15
- import requests # For fetching funny cat GIFs
 
16
  from huggingface_hub import hf_hub_download
17
 
18
- # --- Re-use your model definition and loading functions ---
19
- # (This part would be similar to your inference.py)
20
-
21
- HF_USERNAME = "Hajorda" # Or the username of the model owner
22
- HF_MODEL_NAME = "keduClasifier"
23
  REPO_ID = f"{HF_USERNAME}/{HF_MODEL_NAME}"
24
 
 
25
  cfg_dict_for_inference = {
26
- 'model_name': 'swin_tiny_patch4_window7_224', # Match training
27
- 'dropout_backbone': 0.1, # Match training
28
- 'dropout_fc': 0.2, # Match training
29
  'img_size': (224, 224),
30
- 'num_classes': 37, # IMPORTANT: This must be correct for your trained model
31
  }
32
  cfg_inference = Box(cfg_dict_for_inference)
33
 
34
- class PetBreedModel(pl.LightningModule): # Paste your PetBreedModel class here
 
35
  def __init__(self, cfg: Box):
36
  super().__init__()
37
  self.cfg = cfg
@@ -39,67 +41,87 @@ class PetBreedModel(pl.LightningModule): # Paste your PetBreedModel class here
39
  self.cfg.model_name, pretrained=False, num_classes=0,
40
  in_chans=3, drop_rate=self.cfg.dropout_backbone
41
  )
42
- h, w = self.cfg.img_size
 
43
  dummy_input = torch.randn(1, 3, h, w)
44
- with torch.no_grad(): num_features = self.backbone(dummy_input).shape[-1]
 
45
  self.fc = nn.Sequential(
46
  nn.Linear(num_features, num_features // 2), nn.ReLU(),
47
  nn.Dropout(self.cfg.dropout_fc),
48
  nn.Linear(num_features // 2, self.cfg.num_classes)
49
  )
50
  def forward(self, x):
51
- features = self.backbone(x); output = self.fc(features)
 
52
  return output
53
 
 
54
  def load_model_from_hf_for_space(repo_id=REPO_ID, ckpt_filename="pytorch_model.ckpt"):
55
  model_path = hf_hub_download(repo_id=repo_id, filename=ckpt_filename)
56
- # Important: Ensure cfg_inference is correctly defined with num_classes
57
- if cfg_inference.num_classes is None:
58
  raise ValueError("num_classes must be set in cfg_inference to load the model for Gradio.")
 
59
  loaded_model = PetBreedModel.load_from_checkpoint(model_path, cfg=cfg_inference, strict=False)
60
  loaded_model.eval()
61
  return loaded_model
62
 
63
  def load_label_encoder_from_hf_for_space(repo_id=REPO_ID, le_filename="label_encoder.pkl"):
64
  le_path = hf_hub_download(repo_id=repo_id, filename=le_filename)
65
- with open(le_path, 'rb') as f: label_encoder = pickle.load(f)
 
66
  return label_encoder
67
 
68
- # Load model and encoder once when the app starts
69
- model = load_model_from_hf_for_space()
70
- label_encoder = load_label_encoder_from_hf_for_space()
71
- device = "cuda" if torch.cuda.is_available() else "cpu"
72
- model.to(device)
73
-
74
- # --- Funny elements ---
75
- funny_cat_keywords = ["funny cat", "silly cat", "cat meme", "derp cat"]
76
- GIPHY_API_KEY = "YOUR_GIPHY_API_KEY" # Optional: For more variety, get a Giphy API key
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  def get_funny_cat_gif(breed_name):
79
- try:
80
- # Use a public API if you don't have a Giphy key, or a simpler source
81
- # For example, a predefined list of GIFs
82
- predefined_gifs = {
83
- "abyssinian": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExaWN4bDNzNWVzM2VqNHE4Ym5zN2ZzZHF0Zzh0bGRqZzRjMnhsZW5pZCZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/3oriO0OEd9QIDdllqo/giphy.gif",
84
- "siamese": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExa3g0dHZtZmRncWN0cnZkNnVnMGRtYjN2ajZ2d3o1cHZtaW50ZHQ5ayZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/ICOgUNjpvO0PC/giphy.gif",
85
- "default": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExNWMwNnU4NW9nZTV5c3Z0eThsOHhsOWN0Nnh0a3VzbjFxeGU0bjFuNiZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/BzyTuYCmvSORqs1ABM/giphy.gif"
86
- }
87
- return predefined_gifs.get(breed_name.lower().replace(" ", "_"), predefined_gifs["default"])
88
-
89
- # If using Giphy API:
90
- # search_term = f"{breed_name} {random.choice(funny_cat_keywords)}"
91
- # params = {'api_key': GIPHY_API_KEY, 'q': search_term, 'limit': 1, 'rating': 'g'}
92
- # response = requests.get("http://api.giphy.com/v1/gifs/search", params=params)
93
- # response.raise_for_status()
94
- # return response.json()['data'][0]['images']['original']['url']
95
- except Exception as e:
96
- print(f"Error fetching GIF: {e}")
97
- return predefined_gifs["default"] # Fallback
 
98
 
99
  # --- Gradio Interface Function ---
100
- def classify_cat_breed(image_input):
101
- # Gradio provides image as a NumPy array
102
- img_rgb = cv2.cvtColor(image_input, cv2.COLOR_BGR2RGB) # Ensure it's RGB if needed
 
 
 
103
 
104
  h, w = cfg_inference.img_size
105
  transforms_gradio = A.Compose([
@@ -112,50 +134,66 @@ def classify_cat_breed(image_input):
112
  with torch.no_grad():
113
  logits = model(input_tensor)
114
  probabilities = torch.softmax(logits, dim=1)
115
- # Get top N predictions if you want
116
- # top_probs, top_indices = torch.topk(probabilities, 3, dim=1)
117
-
118
- # For single prediction:
119
  confidence, predicted_idx = torch.max(probabilities, dim=1)
120
 
121
  predicted_breed_id = predicted_idx.item()
122
  predicted_breed_name = label_encoder.inverse_transform([predicted_breed_id])[0]
123
  conf_score = confidence.item()
124
 
125
- # Funny message and GIF
126
- funny_message = f"I'm {conf_score*100:.1f}% sure this adorable furball is a {predicted_breed_name}! What a purrfect specimen!"
127
- if conf_score < 0.7:
128
- funny_message += " ...Or maybe it's a new, super-rare breed only I can see. πŸ˜‰"
 
 
 
129
 
130
  gif_url = get_funny_cat_gif(predicted_breed_name)
131
 
132
- # Gradio expects a dictionary for multiple outputs if you name them
133
- # Or a tuple if you don't name them in gr.Interface outputs
134
  return (
135
- f"{predicted_breed_name} (Confidence: {conf_score*100:.2f}%)",
136
  funny_message,
137
- gif_url # Gradio can display images/GIFs from URLs
138
  )
139
 
140
  # --- Define the Gradio Interface ---
141
- title = "😸 Purrfect Breed Guesser 3000 😼"
142
- description = "Upload a picture of a cat, and I'll (hilariously) try to guess its breed! Powered by AI and a bit of cat-titude."
143
- article = "<p style='text-align: center'>Model based on Swin Transformer, fine-tuned on the Oxford-IIIT Pet Dataset. <a href='https://huggingface.co/YOUR_HF_USERNAME/my-pet-breed-classifier-swin-tiny' target='_blank'>Model Card</a></p>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
  iface = gr.Interface(
146
  fn=classify_cat_breed,
147
- inputs=gr.Image(type="numpy", label="Upload Cat Pic! πŸ“Έ"),
148
  outputs=[
149
- gr.Textbox(label="🧐 My Guess Is..."),
150
- gr.Textbox(label="πŸ’¬ My Deep Thoughts..."),
151
- gr.Image(type="filepath", label="πŸŽ‰ Celebration GIF! πŸŽ‰") # 'filepath' for URLs
152
  ],
153
  title=title,
154
  description=description,
155
  article=article,
156
- examples=[["example_cat1.jpg"], ["example_cat2.jpg"]], # Add paths to example images in your Space repo
157
- theme=gr.themes.Soft() # Or try other themes!
 
158
  )
159
 
160
  if __name__ == "__main__":
 
 
161
  iface.launch()
 
1
+ # app.py (for your Hugging Face Space/Model Repo: Hajorda/keduClassifier)
2
  import gradio as gr
3
  import torch
4
  import pytorch_lightning as pl
 
12
  from PIL import Image
13
  import numpy as np
14
  import os
15
+ # import requests # Commenting out as Giphy API key is not used by default
16
+ import random # For random choice of keywords if you enable Giphy later
17
  from huggingface_hub import hf_hub_download
18
 
19
+ # --- Model and Repository Configuration ---
20
+ # This should exactly match your model repository on Hugging Face
21
+ HF_USERNAME = "Hajorda"
22
+ HF_MODEL_NAME = "keduClassifier" # CORRECTED: Matches your repo name
 
23
  REPO_ID = f"{HF_USERNAME}/{HF_MODEL_NAME}"
24
 
25
+ # --- Inference Configuration ---
26
  cfg_dict_for_inference = {
27
+ 'model_name': 'swin_tiny_patch4_window7_224', # Should match your trained model
28
+ 'dropout_backbone': 0.1, # Should match your trained model
29
+ 'dropout_fc': 0.2, # Should match your trained model
30
  'img_size': (224, 224),
31
+ 'num_classes': 37, # This MUST match the number of classes your model was trained on
32
  }
33
  cfg_inference = Box(cfg_dict_for_inference)
34
 
35
+ # --- PyTorch Lightning Model Definition ---
36
+ class PetBreedModel(pl.LightningModule):
37
  def __init__(self, cfg: Box):
38
  super().__init__()
39
  self.cfg = cfg
 
41
  self.cfg.model_name, pretrained=False, num_classes=0,
42
  in_chans=3, drop_rate=self.cfg.dropout_backbone
43
  )
44
+ # Ensure img_size is a tuple for unpacking
45
+ h, w = self.cfg.img_size if isinstance(self.cfg.img_size, tuple) else (224, 224)
46
  dummy_input = torch.randn(1, 3, h, w)
47
+ with torch.no_grad():
48
+ num_features = self.backbone(dummy_input).shape[-1]
49
  self.fc = nn.Sequential(
50
  nn.Linear(num_features, num_features // 2), nn.ReLU(),
51
  nn.Dropout(self.cfg.dropout_fc),
52
  nn.Linear(num_features // 2, self.cfg.num_classes)
53
  )
54
  def forward(self, x):
55
+ features = self.backbone(x)
56
+ output = self.fc(features)
57
  return output
58
 
59
+ # --- Helper Functions to Load Assets from Hugging Face Hub ---
60
  def load_model_from_hf_for_space(repo_id=REPO_ID, ckpt_filename="pytorch_model.ckpt"):
61
  model_path = hf_hub_download(repo_id=repo_id, filename=ckpt_filename)
62
+ if cfg_inference.num_classes is None: # Should be set by cfg_dict_for_inference
 
63
  raise ValueError("num_classes must be set in cfg_inference to load the model for Gradio.")
64
+ # Pass the cfg for the model structure
65
  loaded_model = PetBreedModel.load_from_checkpoint(model_path, cfg=cfg_inference, strict=False)
66
  loaded_model.eval()
67
  return loaded_model
68
 
69
  def load_label_encoder_from_hf_for_space(repo_id=REPO_ID, le_filename="label_encoder.pkl"):
70
  le_path = hf_hub_download(repo_id=repo_id, filename=le_filename)
71
+ with open(le_path, 'rb') as f:
72
+ label_encoder = pickle.load(f)
73
  return label_encoder
74
 
75
+ # --- Load Model and Label Encoder (once at app startup) ---
76
+ print(f"Loading model and label encoder from repository: {REPO_ID}")
77
+ try:
78
+ model = load_model_from_hf_for_space()
79
+ label_encoder = load_label_encoder_from_hf_for_space()
80
+ device = "cuda" if torch.cuda.is_available() else "cpu"
81
+ model.to(device)
82
+ print(f"Model and label encoder loaded successfully. Using device: {device}")
83
+ except Exception as e:
84
+ print(f"Error loading model or label encoder: {e}")
85
+ # If loading fails, the Gradio app might not work.
86
+ # Consider how to handle this, e.g., display an error in the UI.
87
+ model = None
88
+ label_encoder = None
89
+ device = "cpu"
90
+
91
+
92
+ # --- Funny GIF Logic ---
93
+ # funny_cat_keywords = ["funny cat", "silly cat", "cat meme", "derp cat"]
94
+ # GIPHY_API_KEY = "YOUR_GIPHY_API_KEY" # Optional
95
 
96
  def get_funny_cat_gif(breed_name):
97
+ # Using a predefined list for simplicity and to avoid API key requirements
98
+ predefined_gifs = {
99
+ "abyssinian": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExaWN4bDNzNWVzM2VqNHE4Ym5zN2ZzZHF0Zzh0bGRqZzRjMnhsZW5pZCZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/3oriO0OEd9QIDdllqo/giphy.gif",
100
+ "american bulldog": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExbHgzYXB6N3g5NThnaXU2eWR2aHljOXg3NjMzbGJwNmF6NmxkdXU2ayZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/1simplexLKhMTqI/giphy.gif", # Example for a dog breed
101
+ "bengal": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExbnl0Z2J6cWtub29qdjFlajQ4ZXZ6czY2ZDY0cW53b3I2amI0OHhoYSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/BK1 SANT0sqq1q/giphy.gif",
102
+ "birman": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExZ3Q4NXZmMjQ1azE2dHZ2czZnNnBoNThkZ3FkY2Z0c3hqNjVqMTdhaSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/catdogcessing/giphy.gif",
103
+ "bombay": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExc3N5b2c3MmgwN3JzbjRkYmdocjdhcDc3ejExZGZqZmZtbDBxdXRrcSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/q1MeAPDDMb43K/giphy.gif",
104
+ "british shorthair": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExYTY3NG96bTc0bnFyOGNkaXBwcTYwdGZzZ3JwY2pscGNmbmZydG05eSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/Lq0h93752f6J9tij39/giphy.gif",
105
+ "egyptian mau": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExbjZ6dmJvaDhsb3N4ZXdkOXNrbzRkYnJmMHo3MnE2bWJocjU0Mm5jayZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/3o7ZeLambpFh3TS2ZO/giphy.gif",
106
+ "maine coon": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExd3F6NWoyanFmY2xmcHBtMHRhMXAzaXZrYnJia3UxcDRtcXFsYjE2NSZlcD12MV9pbnRlcm5hbF_naWZfYnlfaWQmY3Q9Zw/MDrmyLuUh9A1a/giphy.gif",
107
+ "persian": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExYW12cDRuc3ZtZ2ZpN2Q2cjdwMHBmb2F3MzJ5d295dGRscG9hdmFpNiZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/uE4gVmbjaZmmY/giphy.gif",
108
+ "ragdoll": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExczZqNWs2ZWU1ZTVobXVxdTZrN2hzcGZoaDVrYnNpZGF4a3FpM3N4aCZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/ObTT5h01Xo43C/giphy.gif",
109
+ "russian blue": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExc3NqcHgzcnVldjA2MnQxc3oyZnp5a2R1eXZxY21hZTN4NHAwd2NyNyZlcD12MV9pbnRlcm5hbF_naWZfYnlfaWQmY3Q9Zw/114ZzmjHizvdsY/giphy.gif",
110
+ "siamese": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExa3g0dHZtZmRncWN0cnZkNnVnMGRtYjN2ajZ2d3o1cHZtaW50ZHQ5ayZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/ICOgUNjpvO0PC/giphy.gif",
111
+ "sphynx": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExcXZjdzFybXh0ZW53OHI4ZWQxazNtb3N4dDNzOGJrdmZrdXFzbnUyZSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/mlvseq9yvZhba/giphy.gif",
112
+ "default": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExNWMwNnU4NW9nZTV5c3Z0eThsOHhsOWN0Nnh0a3VzbjFxeGU0bjFuNiZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/BzyTuYCmvSORqs1ABM/giphy.gif"
113
+ }
114
+ # Normalize breed name for lookup
115
+ normalized_breed_name = breed_name.lower().replace(" ", "_").replace("-", "_")
116
+ return predefined_gifs.get(normalized_breed_name, predefined_gifs["default"])
117
 
118
  # --- Gradio Interface Function ---
119
+ def classify_cat_breed(image_input_bgr): # Gradio image is usually BGR numpy array
120
+ if model is None or label_encoder is None:
121
+ return ("Model not loaded. Please check logs.", "Error: Model components failed to load.", "")
122
+
123
+ # Convert BGR to RGB
124
+ img_rgb = cv2.cvtColor(image_input_bgr, cv2.COLOR_BGR2RGB)
125
 
126
  h, w = cfg_inference.img_size
127
  transforms_gradio = A.Compose([
 
134
  with torch.no_grad():
135
  logits = model(input_tensor)
136
  probabilities = torch.softmax(logits, dim=1)
 
 
 
 
137
  confidence, predicted_idx = torch.max(probabilities, dim=1)
138
 
139
  predicted_breed_id = predicted_idx.item()
140
  predicted_breed_name = label_encoder.inverse_transform([predicted_breed_id])[0]
141
  conf_score = confidence.item()
142
 
143
+ funny_message = f"My AI brain (all {conf_score*100:.1f}% of it that's sure) says this purrfect creature is a **{predicted_breed_name}**!"
144
+ if conf_score < 0.5:
145
+ funny_message += " ...Though, to be honest, it could also be a very fluffy potato. My circuits are confused! πŸ₯”"
146
+ elif conf_score < 0.8:
147
+ funny_message += " Pretty confident, but if it starts barking, don't blame me! 😜"
148
+ else:
149
+ funny_message += " Absolutely magnificent! A textbook example, if cats read textbooks. 🧐"
150
 
151
  gif_url = get_funny_cat_gif(predicted_breed_name)
152
 
 
 
153
  return (
154
+ f"**{predicted_breed_name.title()}** (Confidence: {conf_score*100:.2f}%)",
155
  funny_message,
156
+ gif_url
157
  )
158
 
159
  # --- Define the Gradio Interface ---
160
+ title = "😼 KEDU's Kompletely Kooky Kat (and K9?) Klassifier! 🐢"
161
+ description = (
162
+ "Upload a pic of your furry overlord (cat OR dog from the Oxford-IIIT set!), and I'll "
163
+ "attempt a hilariously 'accurate' breed guess. Powered by Swin Transformers and an "
164
+ "unhealthy obsession with pets. Results may vary, giggles guaranteed!"
165
+ )
166
+ # Corrected article link
167
+ article_link_href = f"https://huggingface.co/{REPO_ID}" # Uses the correctly defined REPO_ID
168
+ article = f"<p style='text-align: center'>Model based on Swin Transformer, fine-tuned on the Oxford-IIIT Pet Dataset. <a href='{article_link_href}' target='_blank'>Model Card & Files</a></p>"
169
+
170
+ # Add some example images to your repo and reference them here
171
+ # For example, if you add 'cat_example.jpg' and 'dog_example.jpg' to your HF repo
172
+ example_images = [
173
+ ["cat_example.jpg"], # You'll need to upload this image to your HF repo
174
+ ["dog_example.jpg"] # You'll need to upload this image to your HF repo
175
+ ]
176
+ # Check if example files exist, if not, provide placeholders or skip examples
177
+ # This check would ideally be done by trying to download them if they are remote URLs
178
+ # For local paths in a repo, Gradio handles it if the files are present.
179
 
180
  iface = gr.Interface(
181
  fn=classify_cat_breed,
182
+ inputs=gr.Image(type="numpy", label="Upload Your Pet's Most Glamorous Shot! πŸ“Έ"),
183
  outputs=[
184
+ gr.Textbox(label="🧐 The AI's Verdict Is... (Breed & Confidence)"),
185
+ gr.Markdown(label="πŸ’¬ AI's Deep (and Silly) Thoughts..."), # Markdown for bolding
186
+ gr.Image(type="filepath", label="πŸŽ‰ Celebration/Confusion GIF! πŸŽ‰")
187
  ],
188
  title=title,
189
  description=description,
190
  article=article,
191
+ # examples=example_images, # Uncomment if you add example images to your repo
192
+ theme=gr.themes.Monochrome(), # Trying a different theme
193
+ allow_flagging='never'
194
  )
195
 
196
  if __name__ == "__main__":
197
+ # When running locally (e.g., python app.py), this will launch the server.
198
+ # On Hugging Face Spaces, Spaces handles the launch.
199
  iface.launch()