"""Gradio demo: score an image against user-supplied labels with a fine-tuned CLIP model."""

import os

import torch
from torchvision import transforms
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
import gradio as gr

# Step 1: Ensure Fine-Tuned Model is Available
fine_tuned_model_path = "fine-tuned-model"

if not os.path.exists(fine_tuned_model_path):
    raise FileNotFoundError(
        f"The fine-tuned model is missing. Ensure that the fine-tuned model files are available in the '{fine_tuned_model_path}' directory."
    )

# Step 2: Load Fine-Tuned Model
print("Loading fine-tuned model...")
model = CLIPModel.from_pretrained(fine_tuned_model_path)
processor = CLIPProcessor.from_pretrained(fine_tuned_model_path)
print("Fine-tuned model loaded successfully.")


# Step 3: Define Gradio Inference Function
def _parse_labels(class_names):
    """Split a comma-separated string into unique, non-empty, stripped labels.

    Order of first appearance is preserved. ``None`` is treated as an empty
    string so a cleared textbox does not raise.
    """
    stripped = (part.strip() for part in (class_names or "").split(","))
    # dict.fromkeys de-duplicates while keeping insertion order; duplicates
    # would otherwise split the softmax mass between identical labels.
    return list(dict.fromkeys(label for label in stripped if label))


def classify_image(image, class_names):
    """Score ``image`` against each label in ``class_names``.

    Args:
        image: The image supplied by the ``gr.Image(type="pil")`` input, or
            ``None`` when nothing was uploaded.
        class_names: Comma-separated candidate labels typed by the user.

    Returns:
        A dict mapping each label to its softmax probability (float),
        ordered from most to least probable.

    Raises:
        gr.Error: If no image was provided or no valid label was entered.
            Gradio shows the message to the user instead of a result.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    labels = _parse_labels(class_names)
    if not labels:
        # gr.Label expects label -> float confidences, so report the problem
        # through gr.Error rather than returning a string-valued dict.
        raise gr.Error("Please enter at least one valid class name.")

    # truncation=True keeps over-long labels within the tokenizer's maximum
    # length instead of failing inside the model.
    inputs = processor(
        text=labels,
        images=image,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    # logits_per_image has one row per image; softmax over dim=1 normalizes
    # across the candidate labels.
    probs = outputs.logits_per_image.softmax(dim=1)[0].tolist()

    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
    return dict(ranked)


# Step 4: Set Up Gradio Interface
iface = gr.Interface(
    fn=classify_image,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(
            label="Possible class names (comma-separated)",
            placeholder="e.g., safe, unsafe",
        ),
    ],
    outputs=gr.Label(num_top_classes=2),
    title="Content Safety Classification",
    description="Classify images as 'safe' or 'unsafe' using a fine-tuned CLIP model.",
)

# Step 5: Launch Gradio Interface
if __name__ == "__main__":
    iface.launch()