# Source: Hugging Face Spaces file view (status: Sleeping; file size: 7,168 bytes; revision a4e0d82).
# app.py (for a Hugging Face Space using Gradio)
import gradio as gr
import torch
import pytorch_lightning as pl
from timm import create_model
import torch.nn as nn
from box import Box
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import cv2
import pickle
from PIL import Image
import numpy as np
import os
import requests # For fetching funny cat GIFs
from huggingface_hub import hf_hub_download
# --- Re-use your model definition and loading functions ---
# (This part would be similar to your inference.py)
# Hub coordinates of the repository that stores the checkpoint and label encoder.
HF_USERNAME = "Hajorda"  # or the username of the model owner
HF_MODEL_NAME = "keduClasifier"
REPO_ID = "/".join((HF_USERNAME, HF_MODEL_NAME))

# Inference-time settings; the model-related values have to match training.
cfg_dict_for_inference = dict(
    model_name="swin_tiny_patch4_window7_224",  # match training
    dropout_backbone=0.1,  # match training
    dropout_fc=0.2,  # match training
    img_size=(224, 224),
    num_classes=37,  # IMPORTANT: must be correct for the trained model
)
cfg_inference = Box(cfg_dict_for_inference)
class PetBreedModel(pl.LightningModule):
    """Breed classifier: a timm backbone followed by a two-layer MLP head.

    Args:
        cfg: Box exposing ``model_name``, ``dropout_backbone``, ``dropout_fc``,
            ``img_size`` (height, width) and ``num_classes``.
    """

    def __init__(self, cfg: Box):
        super().__init__()
        self.cfg = cfg

        # The backbone is built with num_classes=0; its output is used as the
        # feature vector that feeds the custom head below.
        self.backbone = create_model(
            self.cfg.model_name,
            pretrained=False,
            num_classes=0,
            in_chans=3,
            drop_rate=self.cfg.dropout_backbone,
        )

        # Push one random image through the backbone to discover the size of
        # its last output dimension instead of hard-coding it.
        height, width = self.cfg.img_size
        probe = torch.randn(1, 3, height, width)
        with torch.no_grad():
            feature_dim = self.backbone(probe).shape[-1]

        hidden_dim = feature_dim // 2
        self.fc = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(self.cfg.dropout_fc),
            nn.Linear(hidden_dim, self.cfg.num_classes),
        )

    def forward(self, x):
        """Run the backbone and the head on ``x`` and return the head output."""
        return self.fc(self.backbone(x))
def load_model_from_hf_for_space(repo_id=REPO_ID, ckpt_filename="pytorch_model.ckpt"):
    """Fetch the checkpoint from the Hub and return the model in eval mode.

    Args:
        repo_id: Hub repository that stores the checkpoint.
        ckpt_filename: Name of the checkpoint file inside that repository.

    Returns:
        A ``PetBreedModel`` restored from the checkpoint with ``cfg_inference``.

    Raises:
        ValueError: If ``cfg_inference.num_classes`` is ``None``.
    """
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename=ckpt_filename)

    # The classification head cannot be rebuilt without the class count.
    if cfg_inference.num_classes is None:
        raise ValueError("num_classes must be set in cfg_inference to load the model for Gradio.")

    # strict=False: the checkpoint is loaded non-strictly, as in the original setup.
    restored = PetBreedModel.load_from_checkpoint(
        checkpoint_path,
        cfg=cfg_inference,
        strict=False,
    )
    restored.eval()
    return restored
def load_label_encoder_from_hf_for_space(repo_id=REPO_ID, le_filename="label_encoder.pkl"):
    """Fetch the pickled label encoder from the Hub and return it.

    Args:
        repo_id: Hub repository that stores the encoder.
        le_filename: Name of the pickle file inside that repository.

    Returns:
        The object stored in the pickle file.
    """
    encoder_path = hf_hub_download(repo_id=repo_id, filename=le_filename)
    # NOTE(review): unpickling executes arbitrary code from the file; only
    # point this at repositories you trust.
    with open(encoder_path, 'rb') as handle:
        return pickle.load(handle)
# Load the model and the label encoder a single time at start-up so that every
# request served by the app reuses them.
model = load_model_from_hf_for_space()
label_encoder = load_label_encoder_from_hf_for_space()

# Run on the GPU when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)
# --- Funny elements ---
# Search phrases intended for an optional Giphy lookup.
funny_cat_keywords = ["funny cat", "silly cat", "cat meme", "derp cat"]
# Optional Giphy API key. It is read from the environment so that a real key
# never has to be committed; the placeholder is kept as the default value.
GIPHY_API_KEY = os.environ.get("GIPHY_API_KEY", "YOUR_GIPHY_API_KEY")
# Hand-picked GIFs keyed by normalised breed name (lower case, underscores).
# "default" is the fallback for every breed without its own entry.
# NOTE: a Giphy search (using GIPHY_API_KEY) could replace this table for more
# variety; it is intentionally not wired in here.
_PREDEFINED_GIFS = {
    "abyssinian": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExaWN4bDNzNWVzM2VqNHE4Ym5zN2ZzZHF0Zzh0bGRqZzRjMnhsZW5pZCZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/3oriO0OEd9QIDdllqo/giphy.gif",
    "siamese": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExa3g0dHZtZmRncWN0cnZkNnVnMGRtYjN2ajZ2d3o1cHZtaW50ZHQ5ayZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/ICOgUNjpvO0PC/giphy.gif",
    "default": "https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExNWMwNnU4NW9nZTV5c3Z0eThsOHhsOWN0Nnh0a3VzbjFxeGU0bjFuNiZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/BzyTuYCmvSORqs1ABM/giphy.gif",
}


def get_funny_cat_gif(breed_name):
    """Return the URL of a GIF for ``breed_name``.

    Args:
        breed_name: Breed label; matched case-insensitively, with spaces
            treated as underscores. Anything that is not a string falls back
            to the default GIF.

    Returns:
        The breed-specific GIF URL when one is defined, else the default URL.
    """
    fallback = _PREDEFINED_GIFS["default"]
    # Best effort: a missing or odd label must never break the prediction flow.
    if not isinstance(breed_name, str):
        return fallback
    lookup_key = breed_name.lower().replace(" ", "_")
    return _PREDEFINED_GIFS.get(lookup_key, fallback)
# --- Gradio Interface Function ---
def classify_cat_breed(image_input):
    """Predict the breed shown in an uploaded picture.

    Args:
        image_input: Image delivered by the ``gr.Image(type="numpy")`` input
            component, or ``None`` when the user submitted without uploading.

    Returns:
        A tuple ``(prediction_text, funny_message, gif_url)`` in the order of
        the three output components of the interface. ``gif_url`` is ``None``
        when no image was supplied.
    """
    # Submitting the form without a picture passes None; answer politely
    # instead of crashing inside the preprocessing.
    if image_input is None:
        return (
            "No image uploaded.",
            "I need a cat picture before I can guess anything!",
            None,
        )

    # Gradio's numpy images are already in RGB channel order. The previous
    # cv2.COLOR_BGR2RGB conversion swapped red and blue before inference.
    img_rgb = image_input

    h, w = cfg_inference.img_size
    transforms_gradio = A.Compose([
        A.Resize(height=h, width=w),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])
    # unsqueeze(0) adds the batch dimension expected by the model.
    input_tensor = transforms_gradio(image=img_rgb)['image'].unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(input_tensor)
        probabilities = torch.softmax(logits, dim=1)

    # Single best prediction and its probability.
    confidence, predicted_idx = torch.max(probabilities, dim=1)
    predicted_breed_id = predicted_idx.item()
    predicted_breed_name = label_encoder.inverse_transform([predicted_breed_id])[0]
    conf_score = confidence.item()

    # Funny message and GIF
    funny_message = f"I'm {conf_score*100:.1f}% sure this adorable furball is a {predicted_breed_name}! What a purrfect specimen!"
    if conf_score < 0.7:
        funny_message += " ...Or maybe it's a new, super-rare breed only I can see. π"
    gif_url = get_funny_cat_gif(predicted_breed_name)

    # The tuple is positional: one item per output component.
    return (
        f"{predicted_breed_name} (Confidence: {conf_score*100:.2f}%)",
        funny_message,
        gif_url,
    )
# --- Define the Gradio Interface ---
# NOTE(review): the symbols in the titles/labels below look mis-encoded
# (probably emoji); they are left untouched because the originals are unknown.
title = "πΈ Purrfect Breed Guesser 3000 πΌ"
description = "Upload a picture of a cat, and I'll (hilariously) try to guess its breed! Powered by AI and a bit of cat-titude."
# The Model Card link is derived from REPO_ID so it points at the repository
# the model is actually downloaded from, not at a placeholder.
article = (
    "<p style='text-align: center'>Model based on Swin Transformer, fine-tuned on the Oxford-IIIT Pet Dataset. "
    f"<a href='https://huggingface.co/{REPO_ID}' target='_blank'>Model Card</a></p>"
)
# Only advertise example images that are really present in the Space repo, so
# the interface never references files that do not exist.
_EXAMPLE_IMAGE_PATHS = ["example_cat1.jpg", "example_cat2.jpg"]
_available_examples = [[path] for path in _EXAMPLE_IMAGE_PATHS if os.path.exists(path)]
iface = gr.Interface(
    fn=classify_cat_breed,
    inputs=gr.Image(type="numpy", label="Upload Cat Pic! πΈ"),
    outputs=[
        gr.Textbox(label="π§ My Guess Is..."),
        gr.Textbox(label="π¬ My Deep Thoughts..."),
        gr.Image(type="filepath", label="π Celebration GIF! π"),  # receives the GIF URL
    ],
    title=title,
    description=description,
    article=article,
    examples=_available_examples or None,  # None: no examples section at all
    theme=gr.themes.Soft(),  # Or try other themes!
)
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()