# multitask-demo / app.py
# Committed by Abdu07 in 5d0efae ("Create app.py", verified), 2.66 kB.
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import requests
from huggingface_hub import hf_hub_download
########################
# 1) Download & Load Model
########################
# Hugging Face Hub repository and file holding the trained model.
# Replace with your actual model repo on HF.
repo_id = "Abdu07/multitask-model"
filename = "multitask_model.pth"

# Download the model file from the Hub; hf_hub_download returns the local
# (cached) path of the downloaded file.
model_path = hf_hub_download(repo_id=repo_id, filename=filename)

# The checkpoint is a fully pickled nn.Module (the loaded object is used
# directly as a model below), not a state_dict. Since PyTorch 2.6,
# torch.load defaults to weights_only=True, which refuses to unpickle
# arbitrary objects, so the flag has to be set explicitly.
# SECURITY: weights_only=False runs pickle on a file downloaded from the
# network, which can execute arbitrary code. Only load checkpoints from a repo
# you trust; prefer saving a state_dict (or safetensors) and rebuilding the
# module in code. Unpickling also requires the model's class to be importable.
# map_location="cpu": if you switch to "cuda", predict_image must also move
# its input tensor to the same device.
model = torch.load(model_path, map_location="cpu", weights_only=False)
# Put the model in inference mode (e.g. dropout off, batch-norm running stats).
model.eval()
########################
# 2) Define Label Mappings
########################
# Object-class names in class-index order: list position i is the name for
# object-head output i. Extend the list with every category the model was
# trained on.
idx_to_obj_label = dict(enumerate([
    "cat",
    "dog",
    "car",
    # ... fill in all your categories ...
]))

# Authenticity-head class names by index: 0 -> "AI-Generated", 1 -> "Real".
# Reorder if the model was trained with the opposite encoding.
bin_label_names = ["AI-Generated", "Real"]
########################
# 3) Define Transforms
########################
# Preprocessing pipeline; it must match the transforms used during validation.
# Steps: resize so the shorter side is 256 px, take the central 224x224 crop,
# convert to a float tensor in [0, 1], then normalize each RGB channel with
# the commonly used ImageNet mean / std.
val_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.485, 0.456, 0.406],  # per-channel mean
            [0.229, 0.224, 0.225],  # per-channel std
        ),
    ]
)
########################
# 4) Define the Inference Function
########################
def predict_image(img: Image.Image) -> str:
    """Classify an image with both heads of the multi-task model.

    Args:
        img: Image supplied by the ``gr.Image(type="pil")`` input. Gradio
            passes ``None`` when the user submits without uploading anything.

    Returns:
        ``"Object: <name> | Authenticity: <label>"``, where ``<name>`` is looked
        up in ``idx_to_obj_label`` and ``<label>`` in ``bin_label_names``
        (``"Unknown"`` when a predicted index has no entry), or a prompt to
        upload an image when ``img`` is ``None``.
    """
    # No upload: answer with a message instead of crashing on None.convert().
    if img is None:
        return "Please upload an image."

    # Uploads can arrive in other modes (e.g. RGBA or grayscale); the
    # normalization in val_transforms expects 3 channels.
    img = img.convert("RGB")

    # Preprocess and add a leading batch dimension of size 1.
    img_t = val_transforms(img).unsqueeze(0)

    with torch.no_grad():
        # The model returns two logit tensors: (object head, binary head).
        obj_logits, bin_logits = model(img_t)

    # Assumes logits are shaped (batch, classes), so dim=1 is the class axis.
    obj_pred = torch.argmax(obj_logits, dim=1).item()
    bin_pred = torch.argmax(bin_logits, dim=1).item()

    # Map predictions to labels. The binary lookup mirrors the object lookup
    # and falls back to "Unknown" instead of raising IndexError when the head
    # emits an index outside bin_label_names.
    obj_name = idx_to_obj_label.get(obj_pred, "Unknown")
    if 0 <= bin_pred < len(bin_label_names):
        bin_name = bin_label_names[bin_pred]
    else:
        bin_name = "Unknown"
    return f"Object: {obj_name} | Authenticity: {bin_name}"
########################
# 5) Build Gradio UI
########################
# One PIL image in, one text summary out; predict_image does the work.
demo = gr.Interface(
    predict_image,
    gr.Image(type="pil"),
    "text",
    description=(
        "Upload an image to get two predictions: 1) The primary object "
        "(from pseudo-labeling), 2) Whether the image is AI-generated or real."
    ),
    title="Multi-Task Image Classifier",
)
########################
# 6) Launch the App
########################
def main():
    """Start the Gradio server for ``demo``.

    Binds to 0.0.0.0 so the app is reachable from outside its container.
    ``launch(enable_queue=...)`` was deprecated in Gradio 3.x and removed in
    4.0, where passing it raises TypeError. The request queue is therefore
    enabled via ``demo.queue()``, which returns the same app and works on
    both major versions.
    """
    demo.queue().launch(server_name="0.0.0.0")


if __name__ == "__main__":
    main()