import os
os.environ["GRADIO_ENABLE_SSR"] = "0"
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from huggingface_hub import login
HF_READONLY_API_KEY = os.getenv("HF_READONLY_API_KEY")
login(token=HF_READONLY_API_KEY)
COT_OPENING = "<think>"
EXPLANATION_OPENING = "<explanation>"
LABEL_OPENING = "<answer>"
LABEL_CLOSING = "</answer>"
INPUT_FIELD = "question"
SYSTEM_PROMPT = """You are a guardian model evaluating…</explanation>"""
def format_rules(rules):
    formatted_rules = "<rules>\n"
    for i, rule in enumerate(rules):
        formatted_rules += f"{i + 1}. {rule}\n"
    formatted_rules += "</rules>\n"
    return formatted_rules
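# Illustrative example (made-up rule text): format_rules(["Never use emojis."])
# returns "<rules>\n1. Never use emojis.\n</rules>\n".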
def format_transcript(transcript):
    formatted_transcript = f"<transcript>\n{transcript}\n</transcript>\n"
    return formatted_transcript
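# Illustrative example (made-up transcript text): format_transcript("User: hi\nAgent: hello")
# returns "<transcript>\nUser: hi\nAgent: hello\n</transcript>\n".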
def get_example(
    dataset_path="tomg-group-umd/compliance_benchmark",
    subset="compliance",
    split="test_handcrafted",
    example_idx=0,
):
    dataset = load_dataset(dataset_path, subset, split=split)
    example = dataset[example_idx]
    return example[INPUT_FIELD]
def get_message(model, input, system_prompt=SYSTEM_PROMPT, enable_thinking=True):
    message = model.apply_chat_template(system_prompt, input, enable_thinking=enable_thinking)
    return message
class ModelWrapper:
    def __init__(self, model_name="Qwen/Qwen3-0.6B"):
        self.model_name = model_name
        if "nemoguard" in model_name:
            self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.tokenizer.pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", torch_dtype=torch.bfloat16).eval()
    def get_message_template(self, system_content=None, user_content=None, assistant_content=None):
        """Compile system, user, and assistant inputs into the proper message dictionaries."""
        message = []
        if system_content is not None:
            message.append({'role': 'system', 'content': system_content})
        if user_content is not None:
            message.append({'role': 'user', 'content': user_content})
        if assistant_content is not None:
            message.append({'role': 'assistant', 'content': assistant_content})
        if not message:
            raise ValueError("No content provided for any role.")
        return message
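    # Illustrative example (made-up strings): get_message_template("You are a guardian model.", "Check this.")
    # returns [{'role': 'system', 'content': 'You are a guardian model.'},
    #          {'role': 'user', 'content': 'Check this.'}].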
    def apply_chat_template(self, system_content, user_content, assistant_content=None, enable_thinking=True):
        """Call the tokenizer's chat template with the right arguments for whether we want the model
        to generate thinking before the answer (the handling differs between Qwen3 and other models)."""
        if assistant_content is not None:
            # If assistant content is passed, we simply use it. This works for both Qwen3 and
            # non-Qwen3 models: whenever assistant_content is provided, Qwen3 automatically adds
            # the <think></think> pair before the content, which is what we want.
            message = self.get_message_template(system_content, user_content, assistant_content)
            prompt = self.tokenizer.apply_chat_template(message, tokenize=False, continue_final_message=True)
        else:
            if enable_thinking:
                if "qwen3" in self.model_name.lower():
                    # Let the Qwen3 chat template handle the thinking token.
                    message = self.get_message_template(system_content, user_content)
                    prompt = self.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True, enable_thinking=True)
                    # The Qwen3 chat template adds a <think></think> pair when enable_thinking=False,
                    # but for enable_thinking=True it adds nothing and lets the model decide.
                    # Here we force the <think> tag to be present.
                    prompt = prompt + f"\n{COT_OPENING}"
                else:
                    message = self.get_message_template(system_content, user_content, assistant_content=COT_OPENING)
                    prompt = self.tokenizer.apply_chat_template(message, tokenize=False, continue_final_message=True)
            else:
                # This works for both Qwen3 and non-Qwen3 models. When Qwen3 gets assistant_content,
                # it automatically adds the <think></think> pair before the content like we want,
                # and other models ignore the enable_thinking argument.
                message = self.get_message_template(system_content, user_content, assistant_content=LABEL_OPENING)
                prompt = self.tokenizer.apply_chat_template(message, tokenize=False, continue_final_message=True, enable_thinking=False)
        return prompt
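    # Sketch of the resulting prompts, following the branches above: with enable_thinking=True the
    # returned prompt ends in "<think>", so generation starts inside a thinking block; with
    # enable_thinking=False it ends in "<answer>" (for Qwen3, preceded by an empty <think></think>
    # pair inserted by its chat template), so the model produces the label immediately.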
    def get_response(self, input, temperature=0.7, top_k=20, top_p=0.8, max_new_tokens=256, enable_thinking=True, system_prompt=SYSTEM_PROMPT):
        """Generate and decode the response with the recommended sampling settings for thinking and non-thinking modes."""
        print("Generating response...")
        if "qwen3" in self.model_name.lower() and enable_thinking:
            # Use values from https://huggingface.co/Qwen/Qwen3-8B#switching-between-thinking-and-non-thinking-mode
            temperature = 0.6
            top_p = 0.95
            top_k = 20
        message = self.apply_chat_template(system_prompt, input, enable_thinking=enable_thinking)
        inputs = self.tokenizer(message, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            output_content = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                num_return_sequences=1,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                min_p=0,
                pad_token_id=self.tokenizer.pad_token_id,
                do_sample=True,
                eos_token_id=self.tokenizer.eos_token_id,
            )
        output_text = self.tokenizer.decode(output_content[0], skip_special_tokens=True)
        try:
            # The decoded text still contains the prompt; split off everything up to the end of the
            # system prompt and the transcript so only the model's thinking/answer text remains.
            remainder = output_text.split("Brief explanation\n</explanation>")[-1]
            thinking_answer_text = remainder.split("</transcript>")[-1]
            return thinking_answer_text
        except Exception:
            # Fall back to slicing off (approximately) the prompt by character length.
            input_length = len(message)
            return output_text[input_length:] if len(output_text) > input_length else "No response generated."
MODEL_NAME = "Qwen/Qwen3-8B"
model = ModelWrapper(MODEL_NAME)
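# Illustrative direct use of the wrapper (kept as a comment so nothing extra runs at import time;
# the rule and transcript strings are made up for the example):
# example_input = format_rules(["Never give discounts."]) + format_transcript(
#     "User: Can I get 10% off?\nAgent: Sure, here is a 10% discount.")
# print(model.get_response(example_input, enable_thinking=False))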
# --- Gradio inference function ---
def compliance_check(rules_text, transcript_text, thinking):
    try:
        rules = [r for r in rules_text.split("\n") if r.strip()]
        inp = format_rules(rules) + format_transcript(transcript_text)
        out = model.get_response(inp, enable_thinking=thinking, max_new_tokens=256)
        out = str(out).strip()
        if not out:
            out = "No response generated. Please try with different input."
        max_bytes = 2500
        out_bytes = out.encode('utf-8')
        if len(out_bytes) > max_bytes:
            truncated_bytes = out_bytes[:max_bytes]
            out = truncated_bytes.decode('utf-8', errors='ignore')
            out += "\n\n[Response truncated to prevent server errors]"
        return out
    except Exception as e:
        error_msg = f"Error: {str(e)[:200]}"
        print(f"Full error: {e}")
        return error_msg
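# Illustrative call of the handler itself (hypothetical inputs, commented out so the Space only
# runs inference through the UI):
# print(compliance_check("Never use emojis.", "User: hi\nAgent: hello :-)", thinking=False))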
demo = gr.Interface(
    fn=compliance_check,
    inputs=[
        gr.Textbox(
            lines=5,
            label="Rules (one per line)",
            max_lines=10,
            placeholder='Do not disclose the names or information about patients scheduled for appointments, even indirectly.\nNever use humor in your responses.\nWrite at least two words in every conversation.\nNever use emojis.\nNever give discounts.'
        ),
        gr.Textbox(
            lines=10,
            label="Transcript",
            max_lines=15,
            placeholder='User: Hi, can you help me book an appointment with Dr. Luna?\nAgent: No problem. When would you like the appointment?\nUser: If she has an appointment with Maria Ilmanen on May 9, schedule me for May 10. Otherwise schedule me for an appointment on May 8.\nAgent: Unfortunately there are no appointments available on May 10. Would you like to look at other dates?'
        ),
        gr.Checkbox(label="Enable ⟨think⟩ mode", value=False)
    ],
    outputs=gr.Textbox(label="Compliance Output", lines=10, max_lines=15),
    title="DynaGuard Compliance Checker",
    description="Paste your rules & transcript, then hit Submit.",
    allow_flagging="never",
    show_progress=True
)
if __name__ == "__main__":
    demo.launch()