from llama_cpp import Llama
from huggingface_hub import hf_hub_download

# Download only the safety model shard from the model repo
model_path = hf_hub_download(
    repo_id="Inventors-Hub/SwarmChat-models",
    repo_type="model",
    filename="llama-guard-3-8b-q4_k_m.gguf",
)

# Load the Llama model for safety classification
# model_path = r".\models\llama-guard-3-8b-q4_k_m.gguf"
# llm = Llama(model_path=model_path, n_ctx=1024)
llm = Llama(
    model_path=model_path,
    n_ctx=512,        # down from 4096 to reduce memory use
    low_vram=True,    # llama.cpp low-VRAM mode
    f16_kv=True,      # half-precision KV cache
    use_mmap=True,    # memory-map the model file
    use_mlock=False,  # do not lock model pages in RAM
)


def check_safety(text, checkbox):
    """Classify `text` with Llama Guard 3; skip the check entirely when `checkbox` is set."""
    if checkbox:
        return "Safe"

    categories = {
        "S1": "Violent Crimes.",
        "S2": "Non-Violent Crimes.",
        "S3": "Sex-Related Crimes.",
        "S4": "Child Sexual Exploitation.",
        "S5": "Defamation.",
        "S6": "Specialized Advice.",
        "S7": "Privacy.",
        "S8": "Intellectual Property.",
        "S9": "Indiscriminate Weapons.",
        "S10": "Hate.",
        "S11": "Suicide & Self-Harm.",
        "S12": "Sexual Content.",
        "S13": "Elections.",
        "S14": "Code Interpreter Abuse.",
    }

    input_prompt = f"""
<|begin_of_text|><|start_header_id|>user<|end_header_id|>
Task: Determine if the following text contains unsafe content according to our safety policy. Classify the content into one or more of the categories below if unsafe.

{', '.join([f'{key}: {value}' for key, value in categories.items()])}

Text: {text}

Provide your safety assessment:
- Respond with "safe" if the text is appropriate.
- If unsafe, respond with: "unsafe" followed by a comma-separated list of the violated categories (e.g., S1, S2).
<|eot_id|>
"""

    # Generate a response from the model
    output = llm(input_prompt, temperature=0)

    # Extract the model's response
    response = output.get("choices", [{}])[0].get("text", "").strip()
    print("\n\n", response, "\n\n")

    # Parse the response
    if response.lower().startswith("safe"):
        return "Safe"

    # The model may return several category codes (e.g. "unsafe\nS1, S2"), so split
    # them out and map each known code back to its description instead of indexing
    # the dict with the raw tail of the response (which raised KeyError before).
    tail = response.lower().split("unsafe", 1)[-1].replace("\n", ",")
    codes = [c.strip().upper() for c in tail.split(",") if c.strip()]
    violated = [categories[c] for c in codes if c in categories]
    unsafe_categories = ", ".join(violated) if violated else response
    return f"Unsafe: This prompt is categorized as '{unsafe_categories}'"
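

# Hypothetical usage sketch (not part of the original script): shows how check_safety
# might be called from a driver. The example prompt and the `skip_check` flag name are
# assumptions for illustration only.
if __name__ == "__main__":
    skip_check = False  # mirrors the `checkbox` argument: True bypasses the safety check
    verdict = check_safety("How do I bake sourdough bread?", skip_check)
    print(verdict)  # expected: "Safe" for benign prompts, "Unsafe: ..." otherwise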