import gradio as gr
from transformers import CLIPModel, CLIPProcessor

# Step 1: Load the fine-tuned model from the Hugging Face Model Hub
model_name = "quadranttechnologies/retail-content-safety-clip-finetuned"

print("Loading the fine-tuned model from the Hugging Face Model Hub...")
model = CLIPModel.from_pretrained(model_name, trust_remote_code=True)
processor = CLIPProcessor.from_pretrained(model_name)
print("Model loaded successfully.")

# Step 2: Define the inference function
def classify_image(image):
    """
    Classify an image as 'safe' or 'unsafe' and score its subcategories.

    Args:
        image (PIL.Image.Image): The input image.

    Returns:
        dict: The predicted main category, the probabilities for 'safe' and
        'unsafe', and the probabilities for the corresponding subcategories.
    """
    # Define the predefined categories
    main_categories = ["safe", "unsafe"]
    safe_subcategories = ["retail product", "other safe content"]
    unsafe_subcategories = ["harmful", "violent", "sexual", "self harm"]

    # Score the image against the main categories
    main_inputs = processor(text=main_categories, images=image, return_tensors="pt", padding=True)
    main_outputs = model(**main_inputs)
    logits_per_image = main_outputs.logits_per_image  # Image-text similarity scores
    main_probs = logits_per_image.softmax(dim=1)  # Convert logits to probabilities

    # Determine the main category
    main_result = {main_categories[i]: main_probs[0][i].item() for i in range(len(main_categories))}
    main_category = max(main_result, key=main_result.get)  # Either "safe" or "unsafe"

    # Score the image against the subcategories of the predicted main category
    subcategories = safe_subcategories if main_category == "safe" else unsafe_subcategories
    sub_inputs = processor(text=subcategories, images=image, return_tensors="pt", padding=True)
    sub_outputs = model(**sub_inputs)
    sub_logits = sub_outputs.logits_per_image
    sub_probs = sub_logits.softmax(dim=1)  # Convert logits to probabilities

    # Assemble a structured result
    result = {
        "Main Category": main_category,
        "Main Probabilities": main_result,
        "Subcategory Probabilities": {
            subcategories[i]: sub_probs[0][i].item() for i in range(len(subcategories))
        },
    }
    return result
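
# Example usage outside the Gradio UI (hypothetical; assumes a local image
# file named "example.jpg" exists next to this script):
#   from PIL import Image
#   print(classify_image(Image.open("example.jpg")))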

# Step 3: Set up the Gradio interface
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs="json",
    title="Enhanced Content Safety Classification",
    description=(
        "Classify images as 'safe' or 'unsafe' using a fine-tuned CLIP model. "
        "For 'safe' images, identify subcategories such as 'retail product'. "
        "For 'unsafe' images, identify subcategories such as 'harmful', 'violent', 'sexual', or 'self harm'."
    ),
)

# Step 4: Launch the Gradio interface
if __name__ == "__main__":
    iface.launch()
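
# Note: on Hugging Face Spaces, launch() with no arguments is sufficient.
# When running locally, Gradio can also expose a temporary public URL via
# iface.launch(share=True).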