File size: 7,168 Bytes
a4e0d82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
# app.py (for a Hugging Face Space using Gradio)
import gradio as gr
import torch
import pytorch_lightning as pl
from timm import create_model
import torch.nn as nn
from box import Box
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import cv2
import pickle
from PIL import Image
import numpy as np
import os
import requests # For fetching funny cat GIFs
from huggingface_hub import hf_hub_download

# --- Re-use your model definition and loading functions ---
# (This part would be similar to your inference.py)

# Hugging Face Hub coordinates of the repo that stores the trained artifacts.
HF_USERNAME = "Hajorda"  # Or the username of the model owner
HF_MODEL_NAME = "keduClasifier"
REPO_ID = "/".join((HF_USERNAME, HF_MODEL_NAME))

# Hyper-parameters used to rebuild the network for inference.
# Each value is expected to match the one used during training.
cfg_dict_for_inference = dict(
    model_name='swin_tiny_patch4_window7_224',  # Match training
    dropout_backbone=0.1,                       # Match training
    dropout_fc=0.2,                             # Match training
    img_size=(224, 224),                        # (height, width) fed to the model
    num_classes=37,  # IMPORTANT: This must be correct for your trained model
)
cfg_inference = Box(cfg_dict_for_inference)

class PetBreedModel(pl.LightningModule):
    """Timm backbone followed by a small two-layer classification head.

    The architecture is driven by ``cfg``, which must provide ``model_name``,
    ``dropout_backbone``, ``dropout_fc``, ``img_size`` and ``num_classes``.
    """

    def __init__(self, cfg: Box):
        super().__init__()
        self.cfg = cfg

        # num_classes=0: per timm convention the backbone is created without
        # its own classifier, so it returns features instead of logits.
        self.backbone = create_model(
            self.cfg.model_name,
            pretrained=False,
            num_classes=0,
            in_chans=3,
            drop_rate=self.cfg.dropout_backbone,
        )

        # Run one random image through the backbone to discover the width of
        # its feature vector (the last dimension of the output).
        height, width = self.cfg.img_size
        probe = torch.randn(1, 3, height, width)
        with torch.no_grad():
            feature_dim = self.backbone(probe).shape[-1]

        hidden_dim = feature_dim // 2
        self.fc = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(self.cfg.dropout_fc),
            nn.Linear(hidden_dim, self.cfg.num_classes),
        )

    def forward(self, x):
        """Return the head's output (one score per class) for batch ``x``."""
        return self.fc(self.backbone(x))

def load_model_from_hf_for_space(repo_id=REPO_ID, ckpt_filename="pytorch_model.ckpt"):
    """Download the Lightning checkpoint from the Hub and rebuild the model.

    Args:
        repo_id: Hub repository that hosts the checkpoint.
        ckpt_filename: Name of the checkpoint file inside that repository.

    Returns:
        A ``PetBreedModel`` with the checkpoint weights loaded, in eval mode
        and located on the CPU (the caller moves it to the target device).

    Raises:
        ValueError: If ``cfg_inference.num_classes`` is ``None``.
    """
    # Validate the configuration first so a bad config fails fast, before any
    # network traffic is spent on downloading the checkpoint.
    if cfg_inference.num_classes is None:
        raise ValueError("num_classes must be set in cfg_inference to load the model for Gradio.")

    model_path = hf_hub_download(repo_id=repo_id, filename=ckpt_filename)

    # map_location="cpu" keeps loading independent of the device the
    # checkpoint was saved from (e.g. a GPU checkpoint on a CPU-only Space).
    # strict=False tolerates state-dict keys that do not match the module.
    loaded_model = PetBreedModel.load_from_checkpoint(
        model_path,
        cfg=cfg_inference,
        strict=False,
        map_location="cpu",
    )
    loaded_model.eval()
    return loaded_model

def load_label_encoder_from_hf_for_space(repo_id=REPO_ID, le_filename="label_encoder.pkl"):
    """Download the pickled label encoder from the Hub and unpickle it.

    Args:
        repo_id: Hub repository that hosts the pickle file.
        le_filename: Name of the pickle file inside that repository.

    Returns:
        Whatever object the pickle file contains.
    """
    local_pickle = hf_hub_download(repo_id=repo_id, filename=le_filename)
    # NOTE(security): unpickling can execute arbitrary code, so only point
    # this at a repository whose contents you trust.
    with open(local_pickle, 'rb') as handle:
        return pickle.load(handle)

# One-time start-up work: fetch the model and the label encoder, then place
# the model on the GPU when one is available (CPU otherwise).
model = load_model_from_hf_for_space()
label_encoder = load_label_encoder_from_hf_for_space()
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)

# --- Funny elements ---
# Search phrases intended for the optional Giphy lookup.
funny_cat_keywords = [
    "funny cat",
    "silly cat",
    "cat meme",
    "derp cat",
]
# Optional: For more variety, get a Giphy API key
GIPHY_API_KEY = "YOUR_GIPHY_API_KEY"

# Static breed -> GIF lookup, built once at import time. Keys are lower-case
# breed names with spaces replaced by underscores; "default" is the fallback.
_PREDEFINED_GIFS = {
    "abyssinian": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExaWN4bDNzNWVzM2VqNHE4Ym5zN2ZzZHF0Zzh0bGRqZzRjMnhsZW5pZCZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/3oriO0OEd9QIDdllqo/giphy.gif",
    "siamese": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExa3g0dHZtZmRncWN0cnZkNnVnMGRtYjN2ajZ2d3o1cHZtaW50ZHQ5ayZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/ICOgUNjpvO0PC/giphy.gif",
    "default": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExNWMwNnU4NW9nZTV5c3Z0eThsOHhsOWN0Nnh0a3VzbjFxeGU0bjFuNiZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/BzyTuYCmvSORqs1ABM/giphy.gif",
}


def get_funny_cat_gif(breed_name):
    """Return the URL of a celebratory GIF for ``breed_name``.

    The lookup ignores case and surrounding whitespace and treats spaces as
    underscores. Any breed without a dedicated entry, and any input that is
    not a string (it is stringified first), gets the default GIF, so this
    function never raises.

    Args:
        breed_name: Predicted breed label.

    Returns:
        A GIF URL string.
    """
    # No network access happens here; a Giphy search (using GIPHY_API_KEY and
    # funny_cat_keywords) could be substituted for more variety.
    lookup_key = str(breed_name).strip().lower().replace(" ", "_")
    return _PREDEFINED_GIFS.get(lookup_key, _PREDEFINED_GIFS["default"])

# --- Gradio Interface Function ---
def _as_rgb(image):
    """Return ``image`` as a 3-channel array without reordering RGB input."""
    if image.ndim == 2:  # single-channel grayscale
        return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    if image.ndim == 3 and image.shape[2] == 4:  # RGBA: drop the alpha channel
        return cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    return image


# Pre-processing pipeline, built once instead of on every request.
_h, _w = cfg_inference.img_size
_INFERENCE_TRANSFORMS = A.Compose([
    A.Resize(height=_h, width=_w),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])


def classify_cat_breed(image_input):
    """Predict the breed shown in an uploaded image.

    Args:
        image_input: Image as a NumPy array, as delivered by the
            ``gr.Image(type="numpy")`` input, or ``None`` when the user
            submitted without uploading anything.

    Returns:
        A 3-tuple ``(prediction_text, funny_message, gif_url)`` matching the
        three output components of the interface.

    Raises:
        gr.Error: If no image was provided.
    """
    if image_input is None:
        raise gr.Error("Please upload a cat picture first!")

    # Gradio hands over NumPy images in RGB order already, so no BGR->RGB
    # swap is applied: swapping would feed BGR data into a normalisation
    # that uses RGB channel statistics.
    img_rgb = _as_rgb(np.asarray(image_input))

    input_tensor = _INFERENCE_TRANSFORMS(image=img_rgb)['image'].unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(input_tensor)
        probabilities = torch.softmax(logits, dim=1)
        # Single best prediction; torch.topk(probabilities, k, dim=1) would
        # give the top-k instead.
        confidence, predicted_idx = torch.max(probabilities, dim=1)

    predicted_breed_id = predicted_idx.item()
    predicted_breed_name = label_encoder.inverse_transform([predicted_breed_id])[0]
    conf_score = confidence.item()

    # Funny message and GIF
    funny_message = f"I'm {conf_score*100:.1f}% sure this adorable furball is a {predicted_breed_name}! What a purrfect specimen!"
    if conf_score < 0.7:
        funny_message += " ...Or maybe it's a new, super-rare breed only I can see. 😉"

    gif_url = get_funny_cat_gif(predicted_breed_name)

    # Returned as a tuple, in the same order as the interface's outputs.
    return (
        f"{predicted_breed_name} (Confidence: {conf_score*100:.2f}%)",
        funny_message,
        gif_url,
    )

# --- Define the Gradio Interface ---
title = "😸 Purrfect Breed Guesser 3000 😼"
description = "Upload a picture of a cat, and I'll (hilariously) try to guess its breed! Powered by AI and a bit of cat-titude."
# The model-card link is derived from REPO_ID so it always points at the
# repository the weights are actually downloaded from.
article = (
    "<p style='text-align: center'>Model based on Swin Transformer, fine-tuned on the Oxford-IIIT Pet Dataset. "
    f"<a href='https://huggingface.co/{REPO_ID}' target='_blank'>Model Card</a></p>"
)

iface = gr.Interface(
    fn=classify_cat_breed,
    inputs=gr.Image(type="numpy", label="Upload Cat Pic! 📸"),
    # Output order must match the tuple returned by classify_cat_breed.
    outputs=[
        gr.Textbox(label="🧐 My Guess Is..."),
        gr.Textbox(label="💬 My Deep Thoughts..."),
        gr.Image(type="filepath", label="🎉 Celebration GIF! 🎉"),  # 'filepath' for URLs
    ],
    title=title,
    description=description,
    article=article,
    examples=[["example_cat1.jpg"], ["example_cat2.jpg"]],  # Add paths to example images in your Space repo
    theme=gr.themes.Soft(),  # Or try other themes!
)

# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    iface.launch()