File size: 5,504 Bytes
ab7c317 54e69b4 ab7c317 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 |
import random
import numpy as np
import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import unicodedata
import re
import gradio as gr
from pprint import pprint
# Hub checkpoint and the pinned revision inspected by this debugger.
MODEL_ID = "livekit/turn-detector"
REVISION_ID = "v0.3.0-intl"

# Tokenizer and model are loaded from the same revision so token ids agree.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, revision=REVISION_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    revision=REVISION_ID,
    torch_dtype=torch.bfloat16,  # load the weights in bfloat16
    device_map="auto",  # let transformers choose the device placement
).eval()  # nn.Module.eval() returns the module itself; inference mode only
# Probability of <|im_end|> treated as the decision point; the UI centres its
# log-odds colouring on this value. ("EN" presumably means English — confirm.)
EN_THRESHOLD = 0.0049

START_TOKEN_ID = tokenizer.convert_tokens_to_ids('<|im_start|>')
EOU_TOKEN_ID = tokenizer.convert_tokens_to_ids("<|im_end|>")
# NOTE(review): byte-level BPE vocabularies usually do not contain a literal
# '\n' token, in which case this lookup yields an unk id / None — confirm.
NEWLINE_TOKEN_ID = tokenizer.convert_tokens_to_ids('\n')
# Role tokens that, directly after <|im_start|>, open a user turn.
USER_TOKEN_IDS = tuple(
    tokenizer.convert_tokens_to_ids(role_tok) for role_tok in ('user', '<|user|>')
)
# Structural token ids that are excluded from the user-turn prediction mask.
SPECIAL_TOKENS = {
    NEWLINE_TOKEN_ID,
    START_TOKEN_ID,
    tokenizer.convert_tokens_to_ids('user'),
    tokenizer.convert_tokens_to_ids('assistant'),
}
# Decoded token strings hidden from the rendered output.
CONTROL_TOKS = ['<|im_start|>', '<|im_end|>', 'user', 'assistant', '\n']
def normalize_text(text):
    """Normalise free text before it is placed into a user turn.

    Steps, in order: lower-case, Unicode NFKC normalisation, removal of every
    punctuation character (Unicode category ``P*``) except the apostrophe and
    the hyphen, then collapsing whitespace runs to single spaces and trimming.
    """
    folded = unicodedata.normalize("NFKC", text.lower())
    kept_chars = []
    for ch in folded:
        # Apostrophes and hyphens survive; all other punctuation is dropped.
        if unicodedata.category(ch).startswith('P') and ch not in "'-":
            continue
        kept_chars.append(ch)
    return re.sub(r'\s+', ' ', ''.join(kept_chars)).strip()
def format_input(text):
    """Return *text* as a chat-formatted prompt string.

    Text that already contains ``<|im_start|>`` is treated as a formatted
    conversation and returned unchanged. Anything else is normalised with
    ``normalize_text`` and wrapped as a single user message using the
    tokenizer's chat template, with the generation prompt appended.
    """
    if '<|im_start|>' in text:
        return text
    # Assume a single user turn.
    user_message = {'role': 'user', 'content': normalize_text(text)}
    return tokenizer.apply_chat_template(
        [user_message],
        tokenize=False,
        add_generation_prompt=True,
    )
def log_odds(p, eps=0):
    """Return the natural-log odds ``log(p / (1 - p + eps))`` of probability *p*.

    *eps* is added to the denominator only. With the default ``eps=0``, a
    Python-float ``p == 1`` raises ZeroDivisionError, so callers should keep
    ``p`` strictly below 1.
    """
    odds = p / (1 - p + eps)
    return np.log(odds)
def make_pred_mask(input_ids):
    """Mark which token positions belong to the content of a user turn.

    A turn is opened by ``START_TOKEN_ID``. It counts as a user turn when the
    token right after it is one of ``USER_TOKEN_IDS``. Within a user turn,
    every token not in ``SPECIAL_TOKENS`` is marked.

    Args:
        input_ids: Sequence of token ids (list or 1-D numpy array).

    Returns:
        List of bools, the same length as *input_ids*.
    """
    n_tokens = len(input_ids)
    user_mask = [False] * n_tokens
    is_user_role = False
    # Visit every position, including the last one. The original loop stopped
    # one early, so a conversation that ends inside a user turn lost the
    # prediction for its final token.
    for i, tok in enumerate(input_ids):
        if tok == START_TOKEN_ID:
            # Look ahead for the role token; an opener with no successor is
            # not a user turn.
            is_user_role = i + 1 < n_tokens and input_ids[i + 1] in USER_TOKEN_IDS
        user_mask[i] = bool(is_user_role and tok not in SPECIAL_TOKENS)
    return user_mask
def predict_eou(text):
    """Score each token of *text* with the model's end-of-utterance probability.

    Args:
        text: Raw user text, or an already chat-formatted conversation
            (see ``format_input``).

    Returns:
        DataFrame with one row per input token:
        ``token`` — the decoded token string;
        ``pred`` — the softmax probability that the NEXT token is
        ``<|im_end|>``, or NaN where ``make_pred_mask`` marks the position
        as outside user-turn content.
    """
    text = format_input(text)
    with torch.no_grad(), torch.amp.autocast(model.device.type):
        inputs = tokenizer.encode(
            text,
            add_special_tokens=False,
            return_tensors="pt",
        ).to(model.device)
        logits = model(inputs).logits
        # Upcast before the softmax so small probabilities near the threshold
        # are not quantised by bf16. Keep only the <|im_end|> column, so we
        # avoid copying the whole (seq, vocab) matrix to the host.
        eou_probs = torch.nn.functional.softmax(logits.float(), dim=-1)[..., EOU_TOKEN_ID]
    probs = eou_probs.cpu().numpy().flatten()
    input_ids = inputs.cpu().int().flatten().numpy()
    # dtype=bool keeps `~mask` valid even when the mask list is empty.
    mask = np.array(make_pred_mask(input_ids), dtype=bool)
    probs[~mask] = np.nan
    tokens = [tokenizer.decode(tok_id) for tok_id in input_ids]
    return pd.DataFrame({'token': tokens, 'pred': probs})
def make_styled_df(df, thresh=EN_THRESHOLD, cmap="coolwarm"):
    """Render per-token end-of-turn predictions as an HTML table.

    Args:
        df: Frame with a ``token`` column (decoded token strings) and a
            ``pred`` column (probability of ``<|im_end|>``, NaN where unscored).
        thresh: Probability mapped to log-odds 0, the midpoint of the colours.
        cmap: Colormap name passed to pandas ``Styler.bar`` / ``text_gradient``.

    Returns:
        HTML string of a table with columns ``token``, ``log_odds``
        (log-odds of ``pred`` minus log-odds of *thresh*) and
        ``Prob(EoT) as %``.
    """
    EPS = 1e-12
    # Drop chat-template control tokens. Take an explicit copy so the column
    # writes below do not hit a slice of another frame.
    df = df[~df.token.isin(CONTROL_TOKS)].copy()
    # Make bare newline / space tokens visible in the table.
    df['token'] = df.token.replace({"\n": "⏎", " ": "␠"})
    # `pred` arrives as float32, where adding 1e-12 is lost: p == 1.0 reached
    # log_odds as 1/0 and raised ZeroDivisionError. Upcast to float64 and clip
    # into the open interval, so the extremes map to large finite values.
    df['log_odds'] = (
        df.pred.astype('float64')
        .fillna(thresh)
        .clip(EPS, 1 - EPS)
        .apply(log_odds)
        .sub(log_odds(thresh))
        .mask(df.pred.isna())
    )
    df['Prob(EoT) as %'] = df.pred.mul(100).fillna(0).astype(int)
    # Colour range symmetric around zero, padded by 50%. Fall back to 1 when
    # there is no finite non-zero score (e.g. no user tokens), so the bar
    # still gets a valid range.
    vmax_abs = df.log_odds.abs().max() * 1.5
    if not np.isfinite(vmax_abs) or vmax_abs == 0:
        vmax_abs = 1.0
    styler = (
        df.drop(columns=['pred'])
        .style
        .bar(
            subset=['log_odds'],
            align="zero",
            vmin=-vmax_abs,
            vmax=vmax_abs,
            cmap=cmap,
            height=70,
            width=100,
        )
        .text_gradient(subset=['log_odds'], cmap=cmap, vmin=-vmax_abs, vmax=vmax_abs)
        .format(na_rep='', precision=1, subset=['log_odds'])
        .format("{:3d}", subset=['Prob(EoT) as %'])
        .hide(axis='index')
    )
    return styler.to_html()
def generate_highlighted_text(text, threshold=EN_THRESHOLD):
    """Gradio callback: score *text* and build both output components.

    Args:
        text: Raw user text or a chat-formatted conversation.
        threshold: Probability that maps to a score of 0.

    Returns:
        ``(highlights, html)``:
        *highlights* is a list of ``(token, score)`` pairs for
        ``gr.HighlightedText``. *score* is the threshold-centred log-odds
        divided by 1.5x the largest absolute score, or ``None`` for
        positions that are unscored or whose probability rounds to 0.00.
        *html* is the table produced by ``make_styled_df``.
    """
    eps = 1e-12
    if not text:
        # The interface has two outputs, so return one value for each.
        return [], ""
    df = predict_eou(text)
    # Turn role-name tokens into visible turn labels before the control
    # tokens are filtered out.
    df['token'] = df.token.replace({"user": "\nUSER:", "assistant": "\nAGENT:"})
    df = df[~df.token.isin(CONTROL_TOKS)].copy()
    # Work in float64: in float32 the +eps guard vanished and p == 1.0 crashed
    # log_odds with a division by zero.
    pred = df.pred.astype('float64')
    # Parenthesised on purpose: `a | b == 0` parses as `(a | b) == 0`, which
    # left NaN positions unmasked.
    hidden = pred.isna() | (pred.round(2) == 0)
    df['score'] = (
        pred.fillna(threshold)
        .clip(eps, 1 - eps)
        .apply(log_odds)
        .sub(log_odds(threshold))
        .mask(hidden)
    )
    max_abs_score = df['score'].abs().max() * 1.5
    if max_abs_score > 0:  # False for NaN (everything hidden) as well as 0
        df['score'] = df['score'] / max_abs_score
    styled_df = make_styled_df(df[['token', 'pred']], thresh=threshold)
    # Send None rather than NaN for positions that should stay unhighlighted.
    highlights = [
        (token, None if pd.isna(score) else float(score))
        for token, score in zip(df.token, df.score)
    ]
    return highlights, styled_df
# Default demo input: an already chat-formatted two-turn conversation. The
# assistant asks a question and the user's reply is the part that gets scored.
convo_text = (
    "<|im_start|>assistant\n"
    "what is your phone number<|im_end|>\n"
    "<|im_start|>user\n"
    "555 410 0423<|im_end|>"
)
# Gradio UI: one text box in; token highlights plus an HTML score table out.
demo = gr.Interface(
    fn=generate_highlighted_text,
    theme="soft",
    inputs=gr.Textbox(
        label="Input Text",
        info="If <|im_start|> is present it will treat input as formatted convo. if not it will format it as convo with 1 user message.",
        # Alternative single-turn example: "can you help me order some pizza"
        value=convo_text,
        lines=2,
    ),
    outputs=[
        gr.HighlightedText(
            label="EoT Predictions",
            # NOTE(review): Gradio documents color_map as a dict of label ->
            # colour; with numeric scores it appears unused — confirm/remove.
            color_map="coolwarm",
            scale=1.5,
        ),
        gr.HTML(label="Raw scores"),
    ],
    title="Turn Detector Debugger",
    description="Visualize predicted turn endings. The coloring is based on log-odds, centered on the threshold.\n Red means agent should reply; Blue means agent should wait",
    allow_flagging="never",
)

# Launch only when run as a script, so importing this module has no
# server side effect.
if __name__ == "__main__":
    demo.launch()
|