# app.py (for your Hugging Face Space/Model Repo: Hajorda/keduClassifier)
import gradio as gr
import torch
import pytorch_lightning as pl
from timm import create_model
import torch.nn as nn
from box import Box
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import cv2
import pickle
from PIL import Image
import numpy as np
import os
# import requests  # only needed if the optional Giphy lookup is enabled
import random  # for picking a random GIF/keyword if Giphy is enabled later
from huggingface_hub import hf_hub_download
# --- Model and Repository Configuration ---
# This should exactly match your model repository on Hugging Face
HF_USERNAME = "Hajorda"
HF_MODEL_NAME = "keduClassifier"  # must match the model repo name
REPO_ID = f"{HF_USERNAME}/{HF_MODEL_NAME}"
# --- Inference Configuration ---
cfg_dict_for_inference = {
    'model_name': 'swin_tiny_patch4_window7_224',  # should match your trained model
    'dropout_backbone': 0.1,  # should match your trained model
    'dropout_fc': 0.2,  # should match your trained model
    'img_size': (224, 224),
    'num_classes': 37,  # MUST match the number of classes your model was trained on
}
cfg_inference = Box(cfg_dict_for_inference)
# --- PyTorch Lightning Model Definition ---
class PetBreedModel(pl.LightningModule):
    def __init__(self, cfg: Box):
        super().__init__()
        self.cfg = cfg
        self.backbone = create_model(
            self.cfg.model_name, pretrained=False, num_classes=0,
            in_chans=3, drop_rate=self.cfg.dropout_backbone
        )
        # Ensure img_size is a tuple before unpacking
        h, w = self.cfg.img_size if isinstance(self.cfg.img_size, tuple) else (224, 224)
        dummy_input = torch.randn(1, 3, h, w)
        with torch.no_grad():
            num_features = self.backbone(dummy_input).shape[-1]
        self.fc = nn.Sequential(
            nn.Linear(num_features, num_features // 2), nn.ReLU(),
            nn.Dropout(self.cfg.dropout_fc),
            nn.Linear(num_features // 2, self.cfg.num_classes)
        )

    def forward(self, x):
        features = self.backbone(x)
        output = self.fc(features)
        return output
# --- Helper Functions to Load Assets from Hugging Face Hub ---
def load_model_from_hf_for_space(repo_id=REPO_ID, ckpt_filename="pytorch_model.ckpt"):
    model_path = hf_hub_download(repo_id=repo_id, filename=ckpt_filename)
    if cfg_inference.num_classes is None:  # should already be set by cfg_dict_for_inference
        raise ValueError("num_classes must be set in cfg_inference to load the model for Gradio.")
    # Pass cfg so load_from_checkpoint can rebuild the model structure
    loaded_model = PetBreedModel.load_from_checkpoint(model_path, cfg=cfg_inference, strict=False)
    loaded_model.eval()
    return loaded_model


def load_label_encoder_from_hf_for_space(repo_id=REPO_ID, le_filename="label_encoder.pkl"):
    le_path = hf_hub_download(repo_id=repo_id, filename=le_filename)
    with open(le_path, 'rb') as f:
        label_encoder = pickle.load(f)
    return label_encoder
# --- Load Model and Label Encoder (once at app startup) ---
print(f"Loading model and label encoder from repository: {REPO_ID}")
try:
    model = load_model_from_hf_for_space()
    label_encoder = load_label_encoder_from_hf_for_space()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    print(f"Model and label encoder loaded successfully. Using device: {device}")
except Exception as e:
    print(f"Error loading model or label encoder: {e}")
    # If loading fails, the app cannot classify anything;
    # classify_cat_breed() below returns an error message in that case.
    model = None
    label_encoder = None
    device = "cpu"
# --- Funny GIF Logic ---
# funny_cat_keywords = ["funny cat", "silly cat", "cat meme", "derp cat"]
# GIPHY_API_KEY = "YOUR_GIPHY_API_KEY" # Optional
def get_funny_cat_gif(breed_name):
    # Use a predefined list for simplicity, so no Giphy API key is required
    predefined_gifs = {
        "abyssinian": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExaWN4bDNzNWVzM2VqNHE4Ym5zN2ZzZHF0Zzh0bGRqZzRjMnhsZW5pZCZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/3oriO0OEd9QIDdllqo/giphy.gif",
        "american bulldog": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExbHgzYXB6N3g5NThnaXU2eWR2aHljOXg3NjMzbGJwNmF6NmxkdXU2ayZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/1simplexLKhMTqI/giphy.gif",  # example for a dog breed
        "bengal": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExbnl0Z2J6cWtub29qdjFlajQ4ZXZ6czY2ZDY0cW53b3I2amI0OHhoYSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/BK1SANT0sqq1q/giphy.gif",
        "birman": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExZ3Q4NXZmMjQ1azE2dHZ2czZnNnBoNThkZ3FkY2Z0c3hqNjVqMTdhaSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/catdogcessing/giphy.gif",
        "bombay": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExc3N5b2c3MmgwN3JzbjRkYmdocjdhcDc3ejExZGZqZmZtbDBxdXRrcSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/q1MeAPDDMb43K/giphy.gif",
        "british shorthair": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExYTY3NG96bTc0bnFyOGNkaXBwcTYwdGZzZ3JwY2pscGNmbmZydG05eSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/Lq0h93752f6J9tij39/giphy.gif",
        "egyptian mau": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExbjZ6dmJvaDhsb3N4ZXdkOXNrbzRkYnJmMHo3MnE2bWJocjU0Mm5jayZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/3o7ZeLambpFh3TS2ZO/giphy.gif",
        "maine coon": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExd3F6NWoyanFmY2xmcHBtMHRhMXAzaXZrYnJia3UxcDRtcXFsYjE2NSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/MDrmyLuUh9A1a/giphy.gif",
        "persian": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExYW12cDRuc3ZtZ2ZpN2Q2cjdwMHBmb2F3MzJ5d295dGRscG9hdmFpNiZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/uE4gVmbjaZmmY/giphy.gif",
        "ragdoll": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExczZqNWs2ZWU1ZTVobXVxdTZrN2hzcGZoaDVrYnNpZGF4a3FpM3N4aCZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/ObTT5h01Xo43C/giphy.gif",
        "russian blue": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExc3NqcHgzcnVldjA2MnQxc3oyZnp5a2R1eXZxY21hZTN4NHAwd2NyNyZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/114ZzmjHizvdsY/giphy.gif",
        "siamese": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExa3g0dHZtZmRncWN0cnZkNnVnMGRtYjN2ajZ2d3o1cHZtaW50ZHQ5ayZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/ICOgUNjpvO0PC/giphy.gif",
        "sphynx": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExcXZjdzFybXh0ZW53OHI4ZWQxazNtb3N4dDNzOGJrdmZrdXFzbnUyZSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/mlvseq9yvZhba/giphy.gif",
        "default": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExNWMwNnU4NW9nZTV5c3Z0eThsOHhsOWN0Nnh0a3VzbjFxeGU0bjFuNiZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/BzyTuYCmvSORqs1ABM/giphy.gif"
    }
    # Normalize the breed name (lowercase, space-separated) so it matches the keys above;
    # Oxford-IIIT labels typically use underscores, e.g. "british_shorthair"
    normalized_breed_name = breed_name.lower().replace("_", " ").replace("-", " ")
    return predefined_gifs.get(normalized_breed_name, predefined_gifs["default"])
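# Hedged sketch (not part of the original app): if you later enable Giphy, a helper like
# this could replace the predefined list. The endpoint and response fields follow Giphy's
# public search API; verify them against the current docs before relying on them.
def get_giphy_gif(query, api_key):
    import requests  # only needed if this optional helper is actually used
    resp = requests.get(
        "https://api.giphy.com/v1/gifs/search",
        params={"api_key": api_key, "q": query, "limit": 10, "rating": "g"},
        timeout=10,
    )
    data = resp.json().get("data", [])
    return random.choice(data)["images"]["original"]["url"] if data else None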
# --- Gradio Interface Function ---
def classify_cat_breed(image_input_rgb):  # gr.Image(type="numpy") delivers an RGB array
    if model is None or label_encoder is None:
        return ("Model not loaded. Please check logs.", "Error: Model components failed to load.", "")
    img_rgb = image_input_rgb  # already RGB, so no BGR conversion is needed
    h, w = cfg_inference.img_size
    transforms_gradio = A.Compose([
        A.Resize(height=h, width=w),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])
    input_tensor = transforms_gradio(image=img_rgb)['image'].unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(input_tensor)
        probabilities = torch.softmax(logits, dim=1)
        confidence, predicted_idx = torch.max(probabilities, dim=1)
    predicted_breed_id = predicted_idx.item()
    predicted_breed_name = label_encoder.inverse_transform([predicted_breed_id])[0]
    conf_score = confidence.item()
    funny_message = (
        f"My AI brain (all {conf_score*100:.1f}% of it that's sure) says this purrfect "
        f"creature is a **{predicted_breed_name}**!"
    )
    if conf_score < 0.5:
        funny_message += " ...Though, to be honest, it could also be a very fluffy potato. My circuits are confused! 🥔"
    elif conf_score < 0.8:
        funny_message += " Pretty confident, but if it starts barking, don't blame me! 🐶"
    else:
        funny_message += " Absolutely magnificent! A textbook example, if cats read textbooks. 🧐"
    gif_url = get_funny_cat_gif(predicted_breed_name)
    return (
        f"**{predicted_breed_name.title()}** (Confidence: {conf_score*100:.2f}%)",
        funny_message,
        gif_url
    )
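# Hedged sketch (not part of the original app): a tiny local smoke test of the whole
# pipeline with a random image, handy when debugging outside the Space. It is not called
# anywhere; run it manually from a Python shell if needed.
def _smoke_test():
    dummy_image = np.random.randint(0, 255, size=(300, 300, 3), dtype=np.uint8)
    print(classify_cat_breed(dummy_image))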
# --- Define the Gradio Interface ---
title = "😼 KEDU's Kompletely Kooky Kat (and K9?) Klassifier! 🐶"
description = (
    "Upload a pic of your furry overlord (cat OR dog from the Oxford-IIIT set!), and I'll "
    "attempt a hilariously 'accurate' breed guess. Powered by Swin Transformers and an "
    "unhealthy obsession with pets. Results may vary, giggles guaranteed!"
)
# Footer link back to the model repository
article_link_href = f"https://huggingface.co/{REPO_ID}"
article = f"<p style='text-align: center'>Model based on Swin Transformer, fine-tuned on the Oxford-IIIT Pet Dataset. <a href='{article_link_href}' target='_blank'>Model Card & Files</a></p>"
# Example images referenced here must be uploaded to the Space repo alongside app.py
# (e.g. cat1.webp, cat2.webp, cat3.webp).
example_images = [
    ["cat1.webp"],
    ["cat2.webp"],
    ["cat3.webp"],
]
# Gradio resolves local example paths on its own as long as the files exist in the repo;
# the sketch below filters out any that are missing.
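# Hedged sketch (not part of the original app): keep only the example files that are
# actually present next to app.py so the interface never points at a missing path.
example_images = [pair for pair in example_images if os.path.exists(pair[0])]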
iface = gr.Interface(
    fn=classify_cat_breed,
    inputs=gr.Image(type="numpy", label="Upload Your Pet's Most Glamorous Shot! 📸"),
    outputs=[
        gr.Textbox(label="🧠 The AI's Verdict Is... (Breed & Confidence)"),
        gr.Markdown(label="💬 AI's Deep (and Silly) Thoughts..."),  # Markdown so bold text renders
        gr.Image(type="filepath", label="🎉 Celebration/Confusion GIF! 🎉")
    ],
    title=title,
    description=description,
    article=article,
    # examples=example_images,  # uncomment once the example images are in the repo
    theme=gr.themes.Monochrome(),
    allow_flagging='never'
)
if __name__ == "__main__":
    # When running locally (e.g. `python app.py`), this launches the server.
    # On Hugging Face Spaces, the Space runtime handles the launch itself.
    iface.launch()