import os

import gradio as gr
import torch
from transformers import CLIPModel, CLIPProcessor

# Step 1: Ensure Fine-Tuned Model is Available
fine_tuned_model_path = "fine-tuned-model"
if not os.path.exists(fine_tuned_model_path):
    raise FileNotFoundError(
        f"The fine-tuned model is missing. Ensure that the fine-tuned model files are available in the '{fine_tuned_model_path}' directory."
    )
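# The directory should contain the files produced by save_pretrained(), typically
# config.json, a weights file (model.safetensors or pytorch_model.bin),
# preprocessor_config.json, and the tokenizer files.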
# Step 2: Load Fine-Tuned Model
print("Loading fine-tuned model...")
model = CLIPModel.from_pretrained(fine_tuned_model_path)
processor = CLIPProcessor.from_pretrained(fine_tuned_model_path)
print("Fine-tuned model loaded successfully.")
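# Note: inference runs on the CPU by default. To use a GPU you would also need to
# move the model and the processor outputs onto the same device, e.g.
# model.to("cuda") here and inputs = inputs.to("cuda") inside classify_image.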
# Step 3: Define Gradio Inference Function
def classify_image(image, class_names):
    # Split class names from the comma-separated input
    labels = [label.strip() for label in class_names.split(",") if label.strip()]
    if not labels:
        # gr.Label accepts a plain string, so report the problem as a simple message
        return "Please enter at least one valid class name."

    # Process the image and labels, then run CLIP without tracking gradients
    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image
    probs = logits_per_image.softmax(dim=1)  # Convert logits to probabilities

    # Map each label to its probability, sorted from most to least likely
    result = {label: probs[0][i].item() for i, label in enumerate(labels)}
    return dict(sorted(result.items(), key=lambda item: item[1], reverse=True))
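# Minimal sketch of calling classify_image directly, outside Gradio, assuming a
# local image at "example.jpg" (hypothetical path):
#
#     from PIL import Image
#     scores = classify_image(Image.open("example.jpg"), "safe, unsafe")
#     print(scores)  # e.g. {'safe': 0.97, 'unsafe': 0.03}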
# Step 4: Set Up Gradio Interface
iface = gr.Interface(
    fn=classify_image,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(
            label="Possible class names (comma-separated)",
            placeholder="e.g., safe, unsafe",
        ),
    ],
    outputs=gr.Label(num_top_classes=2),
    title="Content Safety Classification",
    description="Classify images as 'safe' or 'unsafe' using a fine-tuned CLIP model.",
)
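# gr.Label(num_top_classes=2) only displays the two highest-scoring labels; raise
# this value if more than two class names are expected in the textbox.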
# Step 5: Launch Gradio Interface
if __name__ == "__main__":
    iface.launch()
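    # To expose a temporary public link when running locally, launch() can be
    # called with share=True instead, e.g. iface.launch(share=True).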