from llama_cpp import Llama
from huggingface_hub import hf_hub_download


# Download only the Llama Guard safety-classifier GGUF from the model repo
model_path = hf_hub_download(
    repo_id="Inventors-Hub/SwarmChat-models",
    repo_type="model",
    filename="llama-guard-3-8b-q4_k_m.gguf",
)


# Load the Llama Guard model for safety classification.
# To run from a local copy instead, point model_path at the file, e.g.:
# model_path = r".\models\llama-guard-3-8b-q4_k_m.gguf"

llm = Llama(
    model_path=model_path,
    n_ctx=512,             # reduced context window to save memory
    low_vram=True,         # llama.cpp low-VRAM mode
    f16_kv=True,           # half-precision KV cache
    use_mmap=True,         # memory-map the model file instead of loading it fully
    use_mlock=False,       # do not pin the model in RAM
)


def check_safety(text, checkbox):
    """Classify `text` with Llama Guard and return "Safe" or an "Unsafe: ..." message.

    If `checkbox` is set, the safety check is bypassed entirely.
    """
    if checkbox:
        return "Safe"

    categories = {
        "S1": "Violent Crimes.",
        "S2": "Non-Violent Crimes.",
        "S3": "Sex-Related Crimes.",
        "S4": "Child Sexual Exploitation.",
        "S5": "Defamation.",
        "S6": "Specialized Advice.",
        "S7": "Privacy.",
        "S8": "Intellectual Property.",
        "S9": "Indiscriminate Weapons.",
        "S10": "Hate.",
        "S11": "Suicide & Self-Harm.",
        "S12": "Sexual Content.",
        "S13": "Elections.",
        "S14": "Code Interpreter Abuse.",
    }

    input_prompt = f"""
    <|begin_of_text|><|start_header_id|>user<|end_header_id|>

    Task: Determine if the following text contains unsafe content according to our safety policy. Classify the content into one or more of the categories below if unsafe.

    <BEGIN UNSAFE CONTENT CATEGORIES>
    {', '.join([f'{key}: {value}' for key, value in categories.items()])}
    <END UNSAFE CONTENT CATEGORIES>

    Text: {text}

    Provide your safety assessment:
    - Respond with "safe" if the text is appropriate.
    - If unsafe, respond with: "unsafe" followed by a comma-separated list of the violated categories (e.g., S1, S2).

    <|eot_id|>
    """

    # Generate a response from the model
    output = llm(input_prompt, temperature=0)
    # Extract the model's response
    response = output.get("choices", [{}])[0].get("text", "").strip()
    print("\n\n", response, "\n\n")  # debug: show the raw model output

    # Parse the response. The model answers either "safe" or "unsafe" followed
    # by one or more category codes, e.g. "unsafe\nS1" or "unsafe S1, S10".
    if response.lower().startswith("safe"):
        return "Safe"
    else:
        code_part = response.lower().split("unsafe", 1)[-1]
        codes = [c.strip().upper() for c in code_part.replace("\n", ",").split(",") if c.strip()]
        labels = [categories[c] for c in codes if c in categories]
        unsafe_categories = ", ".join(labels) if labels else response
        return f"Unsafe: This prompt is categorized as '{unsafe_categories}'"
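

# Example usage (illustrative sketch; the sample prompts and the __main__ guard
# below are not part of the original module):
if __name__ == "__main__":
    # checkbox=True bypasses the classifier and always returns "Safe".
    print(check_safety("Tell me a joke about penguins.", checkbox=False))
    print(check_safety("Any text at all.", checkbox=True))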