# Spaces: Sleeping
import os | |
import torch | |
from torchvision import transforms | |
from PIL import Image | |
from transformers import CLIPModel, CLIPProcessor | |
import gradio as gr | |
# Step 1: Ensure Fine-Tuned Model is Available
# Directory that must hold the fine-tuned CLIP weights and processor files.
fine_tuned_model_path = "fine-tuned-model"

# Fail fast with an actionable message instead of a less obvious loader error.
# The model is loaded from a directory, so a plain file that merely has this
# name is rejected as well.
if not os.path.isdir(fine_tuned_model_path):
    raise FileNotFoundError(
        f"The fine-tuned model is missing. Ensure that the fine-tuned model files are available in the '{fine_tuned_model_path}' directory."
    )
# Step 2: Load Fine-Tuned Model
# The CLIP weights and the matching processor are both read from
# fine_tuned_model_path; progress is reported on stdout around the two loads.
print("Loading fine-tuned model...")
model = CLIPModel.from_pretrained(fine_tuned_model_path)
processor = CLIPProcessor.from_pretrained(fine_tuned_model_path)
print("Fine-tuned model loaded successfully.")
# Step 3: Define Gradio Inference Function
def classify_image(image, class_names):
    """Score an image against user-supplied class names with the CLIP model.

    Args:
        image: Image to classify; it is forwarded unchanged to the CLIP
            processor. ``None`` yields an error entry instead of a crash.
        class_names: Comma-separated candidate labels, e.g. ``"safe, unsafe"``.
            Surrounding whitespace and blank entries are ignored; ``None`` is
            treated like an empty string.

    Returns:
        A dict mapping each distinct label to its softmax probability,
        ordered from most to least likely. If no usable label or no image is
        supplied, a single-entry dict ``{"Error": <message>}`` is returned.
    """
    # Split on commas and drop blank entries, then de-duplicate while keeping
    # first-seen order: a repeated label would take part in the softmax twice
    # (diluting every probability) while the result dict kept only one copy.
    stripped = (label.strip() for label in (class_names or "").split(","))
    labels = list(dict.fromkeys(label for label in stripped if label))
    if not labels:
        return {"Error": "Please enter at least one valid class name."}

    # Nothing to classify (e.g. the form was submitted without an image).
    if image is None:
        return {"Error": "Please upload an image."}

    # Tokenize the labels and preprocess the image into model-ready tensors.
    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)

    # Inference only: disable autograd so no computation graph is recorded
    # and kept alive for each request.
    with torch.no_grad():
        outputs = model(**inputs)

    # Softmax over dim=1 normalises across the candidate labels; row 0 is the
    # single image that was passed in.
    scores = outputs.logits_per_image.softmax(dim=1)[0].tolist()

    result = dict(zip(labels, scores))
    return dict(sorted(result.items(), key=lambda item: item[1], reverse=True))
# Step 4: Set Up Gradio Interface
# Input widgets: the picture to classify (handed to the callback as a PIL
# image) and a free-text box for the candidate class names.
image_input = gr.Image(type="pil")
class_names_input = gr.Textbox(
    label="Possible class names (comma-separated)",
    placeholder="e.g., safe, unsafe",
)

# Output widget: a label view limited to the two highest-scoring classes.
prediction_output = gr.Label(num_top_classes=2)

iface = gr.Interface(
    fn=classify_image,
    inputs=[image_input, class_names_input],
    outputs=prediction_output,
    title="Content Safety Classification",
    description="Classify images as 'safe' or 'unsafe' using a fine-tuned CLIP model.",
)

# Step 5: Launch Gradio Interface
# Start the web UI only when this file is run as a script, not on import.
if __name__ == "__main__":
    iface.launch()