"""
lionguard2.py
"""

import torch
import torch.nn as nn

CATEGORIES = {
    "binary": ["binary"],
    "hateful": ["hateful_l1", "hateful_l2"],
    "insults": ["insults"],
    "sexual": [
        "sexual_l1",
        "sexual_l2",
    ],
    "physical_violence": ["physical_violence"],
    "self_harm": ["self_harm_l1", "self_harm_l2"],
    "all_other_misconduct": [
        "all_other_misconduct_l1",
        "all_other_misconduct_l2",
    ],
}

INPUT_DIMENSION = 3072  # dimension of OpenAI's text-embedding-3-large embeddings
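

# Hedged sketch (not part of the original module): one way to obtain the expected
# embeddings with the official `openai` Python client (>= 1.x). The function name
# and batching are illustrative assumptions; only the "text-embedding-3-large"
# model name and its 3072-dimensional output are taken from the module above.
def embed_texts(texts):
    """Embed a list of strings into N x 3072 vectors (requires the `openai` package)."""
    from openai import OpenAI  # local import so this module still loads without it

    client = OpenAI()  # reads OPENAI_API_KEY from the environment
    response = client.embeddings.create(
        model="text-embedding-3-large",
        input=texts,
    )
    return [item.embedding for item in response.data]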


class LionGuard2(nn.Module):
    def __init__(
        self,
        input_dim=INPUT_DIMENSION,
        label_names=CATEGORIES.keys(),
        categories=CATEGORIES,
    ):
        """
        LionGuard2 is a localised content moderation model that flags whether text violates the following categories:

        1. `hateful`: Text that discriminates, criticizes, insults, denounces, or dehumanizes a person or group on the basis of a protected identity.

        There are two sub-categories for the `hateful` category:
        a. `level_1_discriminatory`: Text that contains derogatory or generalized negative statements targeting a protected group.
        b. `level_2_hate_speech`: Text that explicitly calls for harm or violence against a protected group; or language praising or justifying violence against them.

        2. `insults`: Text that insults, demeans, humiliates, mocks, or belittles a person or group **without** referencing a legally protected trait.
        For example, this includes personal attacks on attributes such as someone’s appearance, intellect, behavior, or other non-protected characteristics.

        3. `sexual`: Text that depicts or indicates sexual interest, activity, or arousal, using direct or indirect references to body parts, sexual acts, or physical traits.
        This includes sexual content that may be inappropriate for certain audiences.

        There are two sub-categories for the `sexual` category:
        a. `level_1_not_appropriate_for_minors`: Text that contains mild-to-moderate sexual content that is generally adult-oriented or potentially unsuitable for those under 16.
            May include matter-of-fact discussions about sex, sexuality, or sexual preferences.
        b. `level_2_not_appropriate_for_all_ages`: Text that contains content aimed at adults and considered explicit, graphic, or otherwise inappropriate for a broad audience.
            May include explicit descriptions of sexual acts, detailed sexual fantasies, or highly sexualized content.

        4. `physical_violence`: Text that includes glorification of violence or threats to inflict physical harm or injury on a person, group, or entity.

        5. `self_harm`: Text that promotes, suggests, or expresses intent to self-harm or commit suicide.

        There are two sub-categories for the `self_harm` category:
        a. `level_1_self_harm_intent`: Text that expresses suicidal thoughts or self-harm intention; or content encouraging someone to self-harm.
        b. `level_2_self_harm_action`: Text that describes or indicates ongoing or imminent self-harm behavior.

        6. `all_other_misconduct`: This is a catch-all category for any other unsafe text that does not fit into the other categories.
        It includes text that seeks or provides information about engaging in misconduct, wrongdoing, or criminal activity, or that threatens to harm,
        defraud, or exploit others. This includes facilitating illegal acts (under Singapore law) or other forms of socially harmful activity.

        There are two sub-categories for the `all_other_misconduct` category:
        a. `level_1_not_socially_accepted`: Text that advocates or instructs on unethical/immoral activities that may not necessarily be illegal but are socially condemned.
        b. `level_2_illegal_activities`: Text that seeks or provides instructions to carry out clearly illegal activities or serious wrongdoing; includes credible threats of severe harm.

        Lastly, there is an additional `binary` category (#7) which flags whether the text is unsafe in general.

        The model takes as input text that has already been encoded with OpenAI's `text-embedding-3-large` embedding model.

        The model outputs the probability of each category being true. A usage sketch is included at the end of this file.

        ================================

        Args:
            input_dim: The dimension of the input embeddings. This defaults to 3072, which is the dimension of the embeddings from OpenAI's `text-embedding-3-large` model. This should not be changed.
            label_names: The names of the labels. This defaults to the keys of the CATEGORIES dictionary. This should not be changed.
            categories: The categories of the labels. This defaults to the CATEGORIES dictionary. This should not be changed.

        Returns:
            A LionGuard2 model.
        """
        super(LionGuard2, self).__init__()
        self.label_names = label_names
        self.n_outputs = len(label_names)
        self.categories = categories

        # Shared layers
        self.shared_layers = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
        )

        # Output heads for each label
        self.output_heads = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(128, 32),
                    nn.ReLU(),
                    nn.Linear(32, 2),  # 2 thresholds for ordinal classification
                    nn.Sigmoid(),
                )
                for _ in range(self.n_outputs)
            ]
        )
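
        # Each output head performs ordinal classification over severity levels
        # {0, 1, 2}: the two sigmoid outputs approximate P(y > 0) and P(y > 1).
        # `predict` later enforces P(y > 1) <= P(y > 0) as a post-processing step.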

    def forward(self, x):
        # Pass through shared layers
        h = self.shared_layers(x)
        # Pass through each output head; returns a list of n_outputs tensors,
        # each of shape (batch_size, 2) holding P(y > 0) and P(y > 1)
        return [head(h) for head in self.output_heads]

    def predict(self, embeddings):
        """
        Predict the probability that each sub-category label applies to each input text.

        Args:
            embeddings: A numpy array (or tensor) of embeddings with shape (N, INPUT_DIMENSION).

        Returns:
            A dictionary mapping each sub-category label to a list of N probabilities.
        """
        # Convert input to PyTorch tensor if not already
        if not isinstance(embeddings, torch.Tensor):
            x = torch.tensor(embeddings, dtype=torch.float32)
        else:
            x = embeddings

        # Disable dropout for deterministic inference, then pass through the model
        self.eval()
        with torch.no_grad():
            outputs = self.forward(x)

        # Stack outputs into a single tensor
        raw_predictions = torch.stack(outputs)  # shape: (n_outputs, N, 2)

        # Extract and format probabilities from raw predictions
        output = {}
        for i, main_cat in enumerate(self.label_names):
            sub_categories = self.categories[main_cat]
            for j, sub_cat in enumerate(sub_categories):
                # j=0 uses P(y>0)
                # j=1 uses P(y>1) if L2 category exists
                output[sub_cat] = raw_predictions[i, :, j]

            # Post-processing step:
            # If an L2 category exists and P(L2) > P(L1),
            # set both P(L1) and P(L2) to their average to maintain ordinal consistency
            if len(sub_categories) > 1:
                l1 = output[sub_categories[0]]
                l2 = output[sub_categories[1]]

                # Update probabilities on samples where P(L2) > P(L1)
                mask = l2 > l1
                mean_prob = (l1 + l2) / 2
                l1[mask] = mean_prob[mask]
                l2[mask] = mean_prob[mask]
                output[sub_categories[0]] = l1
                output[sub_categories[1]] = l2

        for key, value in output.items():
            output[key] = value.numpy().tolist()
        return output
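

# Hedged usage sketch showing how the pieces above fit together; the checkpoint
# path "lionguard2.pt" and the `embed_texts` helper defined earlier are
# illustrative assumptions, not part of the original module.
if __name__ == "__main__":
    model = LionGuard2()
    # In practice, trained weights would be loaded first, e.g.:
    # model.load_state_dict(torch.load("lionguard2.pt", map_location="cpu"))  # hypothetical path
    texts = ["an example sentence", "another example sentence"]
    embeddings = embed_texts(texts)  # N x 3072, requires an OpenAI API key
    scores = model.predict(embeddings)
    for label, probs in scores.items():
        print(f"{label}: {[round(p, 3) for p in probs]}")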