# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

from cosmos_predict1.auxiliary.guardrail.aegis.categories import UNSAFE_CATEGORIES
from cosmos_predict1.auxiliary.guardrail.common.core import ContentSafetyGuardrail, GuardrailRunner
from cosmos_predict1.utils import log, misc

SAFE = misc.Color.green("SAFE")
UNSAFE = misc.Color.red("UNSAFE")


class Aegis(ContentSafetyGuardrail):
    def __init__(
        self,
        checkpoint_dir: str,
        device="cuda" if torch.cuda.is_available() else "cpu",
    ) -> None:
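        # Aegis ships as a LoRA adapter on top of the LlamaGuard-7b base model; both are
        # expected under `checkpoint_dir` in subdirectories named after their Hugging Face
        # repo IDs, and the adapter is attached to the base model via PEFT.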
        self.checkpoint_dir = checkpoint_dir
        self.device = device
        self.dtype = torch.bfloat16

        base_model_id = "meta-llama/LlamaGuard-7b"
        aegis_adapter = "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0"
        base_model_dir = os.path.join(self.checkpoint_dir, base_model_id)
        aegis_adapter_dir = os.path.join(self.checkpoint_dir, aegis_adapter)

        base_model = AutoModelForCausalLM.from_pretrained(base_model_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_dir)
        self.model = PeftModel.from_pretrained(base_model, aegis_adapter_dir)

        self.model.to(self.device, dtype=self.dtype).eval()

    def get_moderation_prompt(self, user_prompt: str) -> str:
        """Create the moderation prompt for the Aegis model."""
        unsafe_categories = "".join([f"{category['instruction']}" for category in UNSAFE_CATEGORIES])
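        # Build the LlamaGuard-style moderation prompt: the unsafe-content taxonomy wrapped
        # in the model's instruction template, followed by the conversation to assess and
        # the expected response format.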
        full_prompt = (
            "[INST] Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.\n"
            "<BEGIN UNSAFE CONTENT CATEGORIES>\n"
            f"{unsafe_categories}\n"
            "<END UNSAFE CONTENT CATEGORIES>\n"
            "<BEGIN CONVERSATION>\n"
        )
        full_prompt += f"User: {user_prompt}\n\n"
        full_prompt += """<END CONVERSATION>

        Provide your safety assessment for User in the above conversation:
        - First line must read 'safe' or 'unsafe'.
        - If unsafe, a second line must include a comma-separated list of violated categories. [/INST]"""

        return full_prompt

    def get_aegis_block_message(self, moderation_output: str) -> str:
        """Extract the blocked category and reason from the Aegis model output."""
        block_msg = "Prompt blocked by Aegis."
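        # The model is expected to answer "unsafe" on its first line and a category tag
        # such as "O3" on the second; the digit after "O" is a 1-based index into
        # UNSAFE_CATEGORIES. If parsing fails, the generic block message above is kept.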
        try:
            # Extract the blocked category and reason from the model output
            moderation_output_list = moderation_output.split("\n")
            if len(moderation_output_list) >= 2:
                blocked_category = int(moderation_output_list[1].lstrip("O")) - 1
                if 0 <= blocked_category < len(UNSAFE_CATEGORIES):
                    blocked_reason = UNSAFE_CATEGORIES[blocked_category]["blocked_reason"]
                    blocked_category_name = UNSAFE_CATEGORIES[blocked_category]["category"]
                    block_msg = f"{blocked_category_name}: {blocked_reason}"
        except Exception as e:
            log.warning(f"Unable to extract blocked category and reason from Aegis output: {e}")
        return block_msg

    def filter_aegis_output(self, prompt: str) -> tuple[bool, str]:
        """Filter the Aegis model output and return the safety status and message."""
        full_prompt = self.get_moderation_prompt(prompt)
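        # Tokenize the moderation prompt, generate up to 100 new tokens, and decode only
        # the continuation (tokens past the prompt length) as the moderation verdict.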
        inputs = self.tokenizer([full_prompt], add_special_tokens=False, return_tensors="pt").to(self.device)
        output = self.model.generate(**inputs, max_new_tokens=100, pad_token_id=self.tokenizer.eos_token_id)
        prompt_len = inputs["input_ids"].shape[-1]
        moderation_output = self.tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

        if "unsafe" in moderation_output.lower():
            block_msg = self.get_aegis_block_message(moderation_output)
            return False, block_msg
        else:
            return True, ""

    def is_safe(self, prompt: str) -> tuple[bool, str]:
        """Check if the input prompt is safe according to the Aegis model."""
        try:
            return self.filter_aegis_output(prompt)
        except Exception as e:
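            # Fail open: log the error and treat the prompt as safe rather than blocking it.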
            log.error(f"Unexpected error occurred when running Aegis guardrail: {e}")
            return True, "Unexpected error occurred when running Aegis guardrail."


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, required=True, help="Input prompt")
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        help="Path to the Aegis checkpoint folder",
    )
    return parser.parse_args()


def main(args):
    aegis = Aegis(checkpoint_dir=args.checkpoint_dir)
    runner = GuardrailRunner(safety_models=[aegis])
    with misc.timer("aegis safety check"):
        safety, message = runner.run_safety_check(args.prompt)
    log.info(f"Input is: {'SAFE' if safety else 'UNSAFE'}")
    if not safety:
        log.info(f"Message: {message}")


if __name__ == "__main__":
    args = parse_args()
    main(args)
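
# Example invocation (module path and checkpoint layout are illustrative and depend on
# how the repository and checkpoints are laid out):
#   python -m cosmos_predict1.auxiliary.guardrail.aegis.aegis \
#       --checkpoint_dir checkpoints \
#       --prompt "A robot arm stacks boxes in a warehouse."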