File size: 2,868 Bytes
514b8b1
a41b014
4df31f3
514b8b1
4df31f3
 
bbfef86
4df31f3
 
 
 
bbfef86
4df31f3
a41b014
4df31f3
a41b014
4df31f3
 
 
 
 
a41b014
4df31f3
a41b014
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
514b8b1
4df31f3
514b8b1
 
a41b014
 
 
 
 
 
 
 
514b8b1
 
4df31f3
514b8b1
 
 
 
 
 
 
 
 
 
 
 
 
 
bbfef86
4df31f3
a41b014
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import gradio as gr
import gradio as gr
from transformers import CLIPModel, CLIPProcessor

# Step 1: Load Fine-Tuned Model from Hugging Face Model Hub
# Hub repository id of the fine-tuned CLIP checkpoint used for classification.
model_name = "quadranttechnologies/retail-content-safety-clip-finetuned"

print("Loading the fine-tuned model from Hugging Face Model Hub...")
# NOTE(review): trust_remote_code=True allows executing code shipped with the
# checkpoint repository — confirm the repo is trusted before deploying.
model = CLIPModel.from_pretrained(model_name, trust_remote_code=True)
processor = CLIPProcessor.from_pretrained(model_name)
print("Model loaded successfully.")

# Step 2: Define the Inference Function
def classify_image(image):
    """
    Classify an image as 'safe' or 'unsafe' with probabilities and subcategories.

    The image is scored twice with the module-level CLIP model/processor:
    first against the two main categories, then against the subcategories of
    whichever main category scored highest.

    Args:
        image (PIL.Image.Image): The input image.

    Returns:
        dict: {"Main Category": str,
               "Main Probabilities": {label: float},
               "Subcategory Probabilities": {label: float}}
    """
    # Predefined label sets.
    main_categories = ["safe", "unsafe"]
    safe_subcategories = ["retail product", "other safe content"]
    unsafe_subcategories = ["harmful", "violent", "sexual", "self harm"]

    # Score the image against the two main categories.
    main_result = _label_probabilities(image, main_categories)
    main_category = max(main_result, key=main_result.get)  # Either "safe" or "unsafe"

    # Refine within the winning main category only.
    subcategories = safe_subcategories if main_category == "safe" else unsafe_subcategories
    sub_result = _label_probabilities(image, subcategories)

    return {
        "Main Category": main_category,
        "Main Probabilities": main_result,
        "Subcategory Probabilities": sub_result,
    }


def _label_probabilities(image, labels):
    """Run one CLIP forward pass and map each label to its softmax probability.

    Factored out because the original duplicated this processor -> model ->
    softmax -> dict pipeline for both the main- and sub-category passes.

    Args:
        image (PIL.Image.Image): The input image.
        labels (list[str]): Candidate text labels to score against the image.

    Returns:
        dict: {label: probability} for the single input image.
    """
    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
    outputs = model(**inputs)
    # logits_per_image holds image-text similarity scores; softmax -> probabilities.
    probs = outputs.logits_per_image.softmax(dim=1)
    return {label: prob.item() for label, prob in zip(labels, probs[0])}

# Step 3: Set Up Gradio Interface
# User-facing help text shown under the app title.
_description = (
    "Classify images as 'safe' or 'unsafe' using a fine-tuned CLIP model. "
    "For 'safe', identify subcategories such as 'retail product'. "
    "For 'unsafe', identify subcategories such as 'harmful', 'violent', 'sexual', or 'self harm'."
)

# Single-image input, JSON view of classify_image's structured result.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs="json",
    title="Enhanced Content Safety Classification",
    description=_description,
)

# Step 4: Launch Gradio Interface
# Guarded so importing this module (e.g. from tests) does not start the server.
if __name__ == "__main__":
    iface.launch()