File size: 10,790 Bytes
f4a7c6b
a4e0d82
 
 
 
 
 
 
 
 
 
 
 
 
f4a7c6b
 
a4e0d82
 
f4a7c6b
 
 
 
a4e0d82
 
f4a7c6b
a4e0d82
f4a7c6b
 
 
a4e0d82
f4a7c6b
a4e0d82
 
 
f4a7c6b
 
a4e0d82
 
 
 
 
 
 
f4a7c6b
 
a4e0d82
f4a7c6b
 
a4e0d82
 
 
 
 
 
f4a7c6b
 
a4e0d82
 
f4a7c6b
a4e0d82
 
f4a7c6b
a4e0d82
f4a7c6b
a4e0d82
 
 
 
 
 
f4a7c6b
 
a4e0d82
 
f4a7c6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4e0d82
 
f4a7c6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4e0d82
 
f4a7c6b
 
 
 
 
 
a4e0d82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4a7c6b
 
 
 
 
 
 
a4e0d82
 
 
 
f4a7c6b
a4e0d82
f4a7c6b
a4e0d82
 
 
f4a7c6b
 
 
 
 
 
 
 
 
 
 
 
 
1da39d3
 
 
f4a7c6b
 
 
 
a4e0d82
 
 
f4a7c6b
a4e0d82
f4a7c6b
 
 
a4e0d82
 
 
 
f4a7c6b
 
 
a4e0d82
 
 
f4a7c6b
 
a4e0d82
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
# app.py (for your Hugging Face Space/Model Repo: Hajorda/keduClassifier)
import gradio as gr
import torch
import pytorch_lightning as pl
from timm import create_model
import torch.nn as nn
from box import Box
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import cv2
import pickle
from PIL import Image
import numpy as np
import os
# import requests # Commenting out as Giphy API key is not used by default
import random # For random choice of keywords if you enable Giphy later
from huggingface_hub import hf_hub_download

# --- Model and Repository Configuration ---
# Hugging Face model repository that hosts the checkpoint and label encoder.
HF_USERNAME = "Hajorda"
HF_MODEL_NAME = "keduClassifier"
REPO_ID = "/".join((HF_USERNAME, HF_MODEL_NAME))

# --- Inference Configuration ---
# Architecture settings used to rebuild the network before the checkpoint is
# loaded into it; they should match the values used during training.
cfg_dict_for_inference = dict(
    model_name="swin_tiny_patch4_window7_224",  # timm backbone name
    dropout_backbone=0.1,
    dropout_fc=0.2,
    img_size=(224, 224),  # (height, width) used for resizing and the probe input
    num_classes=37,  # must equal the number of classes the model was trained on
)
cfg_inference = Box(cfg_dict_for_inference)

# --- PyTorch Lightning Model Definition ---
class PetBreedModel(pl.LightningModule):
    """timm backbone followed by a two-layer MLP classification head.

    Args:
        cfg: Configuration object read for ``model_name``,
            ``dropout_backbone``, ``dropout_fc``, ``img_size`` and
            ``num_classes``.
    """

    def __init__(self, cfg: Box):
        super().__init__()
        self.cfg = cfg
        # num_classes=0 removes timm's classifier so the backbone emits
        # feature vectors that the custom head below consumes.
        self.backbone = create_model(
            self.cfg.model_name, pretrained=False, num_classes=0,
            in_chans=3, drop_rate=self.cfg.dropout_backbone
        )
        h, w = self._resolve_img_size(self.cfg.img_size)
        # Probe the backbone once to learn its feature width, so the head can
        # be sized without hard-coding per-architecture dimensions.
        dummy_input = torch.randn(1, 3, h, w)
        with torch.no_grad():
            num_features = self.backbone(dummy_input).shape[-1]
        self.fc = nn.Sequential(
            nn.Linear(num_features, num_features // 2), nn.ReLU(),
            nn.Dropout(self.cfg.dropout_fc),
            nn.Linear(num_features // 2, self.cfg.num_classes)
        )

    @staticmethod
    def _resolve_img_size(img_size):
        """Return ``(height, width)`` for the probe input.

        An int means a square image. A tuple or list (including list
        subclasses such as a Box ``BoxList``, which configs loaded from
        YAML/hparams produce) is unpacked as ``(height, width)``. Anything
        else falls back to 224x224, as before.
        """
        if isinstance(img_size, int) and not isinstance(img_size, bool):
            return img_size, img_size
        if isinstance(img_size, (tuple, list)):
            h, w = img_size  # raises ValueError for non-2-element sequences
            return h, w
        return 224, 224

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for images ``x``."""
        features = self.backbone(x)
        output = self.fc(features)
        return output

# --- Helper Functions to Load Assets from Hugging Face Hub ---
def load_model_from_hf_for_space(repo_id=REPO_ID, ckpt_filename="pytorch_model.ckpt"):
    """Download a Lightning checkpoint from the Hub and rebuild the model.

    Args:
        repo_id: Hugging Face repository containing the checkpoint.
        ckpt_filename: Checkpoint file name inside that repository.

    Returns:
        A ``PetBreedModel`` in eval mode with its weights on the CPU; the
        caller is responsible for moving it to the desired device.

    Raises:
        ValueError: If ``cfg_inference.num_classes`` is not set.
    """
    # Validate the configuration before paying for the download.
    if cfg_inference.num_classes is None:  # Should be set by cfg_dict_for_inference
        raise ValueError("num_classes must be set in cfg_inference to load the model for Gradio.")
    model_path = hf_hub_download(repo_id=repo_id, filename=ckpt_filename)
    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only
    # Space. `cfg` is passed explicitly because the model's constructor needs
    # it to build the network structure.
    # NOTE(review): strict=False tolerates missing/unexpected state-dict keys,
    # which can leave layers randomly initialised without any error; consider
    # strict=True once the checkpoint's keys are confirmed to match.
    loaded_model = PetBreedModel.load_from_checkpoint(
        model_path, map_location="cpu", cfg=cfg_inference, strict=False
    )
    loaded_model.eval()
    return loaded_model

def load_label_encoder_from_hf_for_space(repo_id=REPO_ID, le_filename="label_encoder.pkl"):
    """Download and unpickle the label encoder used to map class ids to names.

    Args:
        repo_id: Hugging Face repository containing the pickled encoder.
        le_filename: File name of the pickle inside that repository.

    Returns:
        Whatever object the pickle file contains.
    """
    # NOTE(review): pickle.load can execute arbitrary code from the file.
    # This is only acceptable because the file comes from the app's own model
    # repository; never point repo_id at an untrusted repository.
    encoder_path = hf_hub_download(repo_id=repo_id, filename=le_filename)
    with open(encoder_path, 'rb') as encoder_file:
        return pickle.load(encoder_file)

# --- Load Model and Label Encoder (once at app startup) ---
# Runs once, at import time. On any failure the globals are reset to None so
# the interface function can report the problem instead of the app crashing.
print(f"Loading model and label encoder from repository: {REPO_ID}")
try:
    model = load_model_from_hf_for_space()
    label_encoder = load_label_encoder_from_hf_for_space()
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    model.to(device)
    print(f"Model and label encoder loaded successfully. Using device: {device}")
except Exception as load_error:
    # Deliberately broad: a startup failure should degrade to an in-app error
    # message rather than stop the Space from booting.
    print(f"Error loading model or label encoder: {load_error}")
    model, label_encoder, device = None, None, "cpu"


# --- Funny GIF Logic ---
# funny_cat_keywords = ["funny cat", "silly cat", "cat meme", "derp cat"]
# GIPHY_API_KEY = "YOUR_GIPHY_API_KEY" # Optional

# Breed -> GIF URL. Keys are in normalized form: lowercase, words separated
# by single spaces (see get_funny_cat_gif). "default" is the fallback.
_PREDEFINED_GIFS = {
    "abyssinian": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExaWN4bDNzNWVzM2VqNHE4Ym5zN2ZzZHF0Zzh0bGRqZzRjMnhsZW5pZCZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/3oriO0OEd9QIDdllqo/giphy.gif",
    "american bulldog": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExbHgzYXB6N3g5NThnaXU2eWR2aHljOXg3NjMzbGJwNmF6NmxkdXU2ayZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/1simplexLKhMTqI/giphy.gif",  # Example for a dog breed
    # NOTE(review): this URL contains a space in the media id and is almost
    # certainly broken; replace it with a valid GIPHY id.
    "bengal": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExbnl0Z2J6cWtub29qdjFlajQ4ZXZ6czY2ZDY0cW53b3I2amI0OHhoYSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/BK1 SANT0sqq1q/giphy.gif",
    "birman": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExZ3Q4NXZmMjQ1azE2dHZ2czZnNnBoNThkZ3FkY2Z0c3hqNjVqMTdhaSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/catdogcessing/giphy.gif",
    "bombay": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExc3N5b2c3MmgwN3JzbjRkYmdocjdhcDc3ejExZGZqZmZtbDBxdXRrcSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/q1MeAPDDMb43K/giphy.gif",
    "british shorthair": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExYTY3NG96bTc0bnFyOGNkaXBwcTYwdGZzZ3JwY2pscGNmbmZydG05eSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/Lq0h93752f6J9tij39/giphy.gif",
    "egyptian mau": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExbjZ6dmJvaDhsb3N4ZXdkOXNrbzRkYnJmMHo3MnE2bWJocjU0Mm5jayZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/3o7ZeLambpFh3TS2ZO/giphy.gif",
    "maine coon": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExd3F6NWoyanFmY2xmcHBtMHRhMXAzaXZrYnJia3UxcDRtcXFsYjE2NSZlcD12MV9pbnRlcm5hbF_naWZfYnlfaWQmY3Q9Zw/MDrmyLuUh9A1a/giphy.gif",
    "persian": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExYW12cDRuc3ZtZ2ZpN2Q2cjdwMHBmb2F3MzJ5d295dGRscG9hdmFpNiZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/uE4gVmbjaZmmY/giphy.gif",
    "ragdoll": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExczZqNWs2ZWU1ZTVobXVxdTZrN2hzcGZoaDVrYnNpZGF4a3FpM3N4aCZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/ObTT5h01Xo43C/giphy.gif",
    "russian blue": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExc3NqcHgzcnVldjA2MnQxc3oyZnp5a2R1eXZxY21hZTN4NHAwd2NyNyZlcD12MV9pbnRlcm5hbF_naWZfYnlfaWQmY3Q9Zw/114ZzmjHizvdsY/giphy.gif",
    "siamese": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExa3g0dHZtZmRncWN0cnZkNnVnMGRtYjN2ajZ2d3o1cHZtaW50ZHQ5ayZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/ICOgUNjpvO0PC/giphy.gif",
    "sphynx": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExcXZjdzFybXh0ZW53OHI4ZWQxazNtb3N4dDNzOGJrdmZrdXFzbnUyZSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/mlvseq9yvZhba/giphy.gif",
    "default": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExNWMwNnU4NW9nZTV5c3Z0eThsOHhsOWN0Nnh0a3VzbjFxeGU0bjFuNiZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/BzyTuYCmvSORqs1ABM/giphy.gif",
}


def get_funny_cat_gif(breed_name):
    """Return a GIF URL for ``breed_name``, or the default GIF if none matches.

    Lookup is case-insensitive and treats underscores, hyphens and runs of
    whitespace as a single space. So ``"British_Shorthair"``,
    ``"british-shorthair"`` and ``"british shorthair"`` all resolve to the
    same entry.

    Args:
        breed_name: Predicted breed label (any object convertible to ``str``).

    Returns:
        The URL string for the breed, or the ``"default"`` URL.
    """
    # Uses a predefined mapping to avoid needing a GIPHY API key.
    # Normalize to the key format: the previous normalization produced
    # underscores while the multi-word keys use spaces, so those breeds could
    # never match.
    normalized_breed_name = " ".join(
        str(breed_name).lower().replace("_", " ").replace("-", " ").split()
    )
    return _PREDEFINED_GIFS.get(normalized_breed_name, _PREDEFINED_GIFS["default"])

# --- Gradio Interface Function ---
def _ensure_three_channels(image):
    """Return ``image`` as a contiguous ``(H, W, 3)`` array.

    A 2-D or single-channel image is replicated across three channels and a
    fourth (alpha) channel is dropped; 3-channel input passes through as is.
    """
    arr = np.asarray(image)
    if arr.ndim == 2:
        arr = arr[..., np.newaxis]
    if arr.shape[-1] == 1:
        arr = np.repeat(arr, 3, axis=-1)
    elif arr.shape[-1] == 4:
        arr = arr[..., :3]
    return np.ascontiguousarray(arr)


# Preprocessing pipeline, built once at import rather than on every request:
# resize to cfg_inference.img_size, ImageNet mean/std normalization, to tensor.
_INFERENCE_TRANSFORMS = A.Compose([
    A.Resize(height=cfg_inference.img_size[0], width=cfg_inference.img_size[1]),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])


def classify_cat_breed(image_input_bgr):
    """Predict the pet breed in an uploaded image and build the UI outputs.

    Args:
        image_input_bgr: Image from ``gr.Image(type="numpy")``. Despite the
            parameter name (kept for backward compatibility), Gradio delivers
            this array in RGB channel order. It is ``None`` when the user
            submits without uploading anything.

    Returns:
        A 3-tuple ``(verdict, funny_message, gif_url)`` matching the three
        output components. ``gif_url`` is ``None`` on the error paths so the
        image output is simply left empty.
    """
    if model is None or label_encoder is None:
        return ("Model not loaded. Please check logs.", "Error: Model components failed to load.", None)
    if image_input_bgr is None:
        return ("No image received.", "Please upload a picture of your pet first!", None)

    # Gradio already provides RGB, the channel order the ImageNet mean/std
    # below are defined for. Converting "BGR->RGB" here would actually swap
    # the channels into BGR.
    img_rgb = _ensure_three_channels(image_input_bgr)
    input_tensor = _INFERENCE_TRANSFORMS(image=img_rgb)['image'].unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(input_tensor)
        probabilities = torch.softmax(logits, dim=1)
        confidence, predicted_idx = torch.max(probabilities, dim=1)

    predicted_breed_id = predicted_idx.item()
    predicted_breed_name = str(label_encoder.inverse_transform([predicted_breed_id])[0])
    conf_score = confidence.item()

    funny_message = f"My AI brain (all {conf_score*100:.1f}% of it that's sure) says this purrfect creature is a **{predicted_breed_name}**!"
    if conf_score < 0.5:
        funny_message += " ...Though, to be honest, it could also be a very fluffy potato. My circuits are confused! 🥔"
    elif conf_score < 0.8:
        funny_message += " Pretty confident, but if it starts barking, don't blame me! 😜"
    else:
        funny_message += " Absolutely magnificent! A textbook example, if cats read textbooks. 🧐"

    gif_url = get_funny_cat_gif(predicted_breed_name)

    return (
        f"**{predicted_breed_name.title()}** (Confidence: {conf_score*100:.2f}%)",
        funny_message,
        gif_url
    )

# --- Define the Gradio Interface ---
title = "😼 KEDU's Kompletely Kooky Kat (and K9?) Klassifier! 🐶"
description = (
    "Upload a pic of your furry overlord (cat OR dog from the Oxford-IIIT set!), and I'll "
    "attempt a hilariously 'accurate' breed guess. Powered by Swin Transformers and an "
    "unhealthy obsession with pets. Results may vary, giggles guaranteed!"
)
# Footer link points at the model repository on the Hub.
article_link_href = f"https://huggingface.co/{REPO_ID}"
article = f"<p style='text-align: center'>Model based on Swin Transformer, fine-tuned on the Oxford-IIIT Pet Dataset. <a href='{article_link_href}' target='_blank'>Model Card & Files</a></p>"

# Example inputs. These files must be uploaded to the Space repository; they
# are only used once `examples=` below is uncommented.
example_images = [
    ["cat1.webp"],
    ["cat2.webp"],
    ["cat3.webp"],
]

iface = gr.Interface(
    fn=classify_cat_breed,
    inputs=gr.Image(type="numpy", label="Upload Your Pet's Most Glamorous Shot! 📸"),
    outputs=[
        gr.Textbox(label="🧐 The AI's Verdict Is... (Breed & Confidence)"),
        gr.Markdown(label="💬 AI's Deep (and Silly) Thoughts..."),  # Markdown renders the **bold** breed name
        gr.Image(type="filepath", label="🎉 Celebration/Confusion GIF! 🎉")
    ],
    title=title,
    description=description,
    article=article,
    # examples=example_images,  # Uncomment once the example images are in the repo
    theme=gr.themes.Monochrome(),
    allow_flagging='never'
)

if __name__ == "__main__":
    # Running `python app.py` locally starts the server here; on Hugging Face
    # Spaces the platform launches the app itself.
    iface.launch()