|
import random |
|
import numpy as np |
|
import pandas as pd |
|
import torch |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
import unicodedata |
|
import re |
|
import gradio as gr |
|
from pprint import pprint |
|
|
|
|
|
|
|
# --- Model & tokenizer setup -------------------------------------------------
# Loads the LiveKit end-of-utterance ("turn detector") model at a pinned
# revision so results stay reproducible across model releases.
MODEL_ID = "livekit/turn-detector"
REVISION_ID = "v0.3.0-intl"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, revision=REVISION_ID)
# bfloat16 weights + device_map="auto" keep memory low and let transformers
# place the model on whatever accelerator is available (CPU fallback).
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    revision=REVISION_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model.eval()  # inference only: disables dropout / training-mode behavior

# Probability threshold used to center the visualization for English input.
# NOTE(review): presumably tuned externally (model card / eval run) — confirm.
EN_THRESHOLD = 0.0049

# Token ids used to parse the ChatML-style template the model expects.
START_TOKEN_ID = tokenizer.convert_tokens_to_ids('<|im_start|>')
EOU_TOKEN_ID = tokenizer.convert_tokens_to_ids("<|im_end|>")  # end-of-utterance
NEWLINE_TOKEN_ID = tokenizer.convert_tokens_to_ids('\n')
# Either spelling of the user role token may follow <|im_start|>.
USER_TOKEN_IDS = (
    tokenizer.convert_tokens_to_ids('user'),
    tokenizer.convert_tokens_to_ids('<|user|>')
)
# Structural tokens that should never receive an EoT prediction.
# NOTE(review): '<|user|>' appears in USER_TOKEN_IDS but not here — confirm
# that omission is intentional.
SPECIAL_TOKENS = set([
    NEWLINE_TOKEN_ID,
    START_TOKEN_ID,
    tokenizer.convert_tokens_to_ids('user'),
    tokenizer.convert_tokens_to_ids('assistant'),
])
# String forms of the control tokens, used to filter display DataFrames.
CONTROL_TOKS = ['<|im_start|>', '<|im_end|>', 'user', 'assistant', '\n']
|
|
|
|
|
def normalize_text(text):
    """Lowercase, NFKC-normalize, strip punctuation, and collapse whitespace.

    Unicode punctuation (category ``P*``) is dropped, except apostrophes and
    hyphens, which are kept because they occur word-internally ("don't",
    "e-mail"). Runs of whitespace collapse to a single space.
    """
    normalized = unicodedata.normalize("NFKC", text.lower())
    kept = []
    for ch in normalized:
        is_punct = unicodedata.category(ch).startswith('P')
        if ch in ("'", "-") or not is_punct:
            kept.append(ch)
    return re.sub(r'\s+', ' ', ''.join(kept)).strip()
|
|
|
|
|
def format_input(text):
    """Return *text* as a ChatML-formatted conversation string.

    Input that already contains '<|im_start|>' is treated as a formatted
    conversation and passed through untouched. Anything else is normalized
    and wrapped as a single user message via the tokenizer's chat template
    (with the generation prompt appended).
    """
    if '<|im_start|>' in text:
        return text
    message = {'role': 'user', 'content': normalize_text(text)}
    return tokenizer.apply_chat_template(
        [message],
        tokenize=False,
        add_generation_prompt=True,
    )
|
|
|
|
|
def log_odds(p, eps=0):
    """Return the log-odds transform ``log(p / (1 - p + eps))``.

    ``eps`` (default 0) pads only the denominator, guarding the p == 1 case;
    callers are expected to pre-pad ``p`` itself if p == 0 is possible.
    Works elementwise on scalars and array-likes via numpy.
    """
    odds = p / (1 - p + eps)
    return np.log(odds)
|
|
|
|
|
def make_pred_mask(input_ids):
    """Build a boolean mask of positions where an EoT prediction is meaningful.

    A position is True when it lies inside a user turn (the role token right
    after the most recent <|im_start|> is one of USER_TOKEN_IDS) and the token
    itself is not a structural token. The final position is never examined,
    because the role lookahead reads input_ids[i + 1], so it stays False.
    """
    mask = [False] * len(input_ids)
    in_user_turn = False
    for idx, tok in enumerate(input_ids[:-1]):
        if tok == START_TOKEN_ID:
            in_user_turn = input_ids[idx + 1] in USER_TOKEN_IDS
        mask[idx] = in_user_turn and tok not in SPECIAL_TOKENS
    return mask
|
|
|
|
|
def predict_eou(text):
    """Run the model over *text* and return per-token EoT probabilities.

    Returns a DataFrame with one row per input token: 'token' (decoded string)
    and 'pred' (probability that <|im_end|> follows that token; NaN at
    positions where a prediction is not meaningful — see make_pred_mask).
    """
    text = format_input(text)
    with torch.no_grad():
        # Autocast on whatever device the model landed on (weights are bf16).
        with torch.amp.autocast(model.device.type):
            inputs = tokenizer.encode(
                text,
                add_special_tokens=False,  # the chat template already added them
                return_tensors="pt"
            ).to(model.device)
            outputs = model(inputs)
            # Softmax over the vocab, then keep only the EoT token's column.
            probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
            probs = probs.cpu().float().numpy()[:, :, EOU_TOKEN_ID].flatten()

    input_ids = inputs.cpu().int().flatten().numpy()
    # Blank out positions outside user turns / on structural tokens.
    mask = np.array(make_pred_mask(input_ids))
    probs[~mask] = np.nan

    tokens = [tokenizer.decode(id) for id in input_ids]
    res = {'token':tokens,'pred':probs}
    return pd.DataFrame(res)
|
|
|
|
|
def make_styled_df(df, thresh=EN_THRESHOLD, cmap="coolwarm"):
    """Render a token/pred DataFrame as styled HTML.

    Adds a log-odds column (centered on *thresh*) and an integer percent
    column, then colors log-odds with diverging data bars and a matching
    text gradient. Control tokens are dropped and whitespace tokens are
    replaced with visible placeholder glyphs.
    """
    eps = 1e-12
    work = df.copy()
    work = work[~work.token.isin(CONTROL_TOKS)]
    work.token = work.token.replace({"\n": "β"," ": "β ",})

    # Log-odds relative to the threshold; rows with no prediction stay NaN.
    shifted = work.pred.fillna(thresh).add(eps).apply(log_odds)
    work['log_odds'] = shifted.sub(log_odds(thresh)).mask(work.pred.isna())
    work['Prob(EoT) as %'] = work.pred.mul(100).fillna(0).astype(int)

    # Symmetric color range about zero, padded 50% so bars don't saturate.
    lo, hi = work.log_odds.min(), work.log_odds.max()
    bound = max(abs(lo), abs(hi)) * 1.5

    styler = work.drop(columns=['pred']).style
    styler = styler.bar(
        subset=['log_odds'],
        align="zero",
        vmin=-bound,
        vmax=bound,
        cmap=cmap,
        height=70,
        width=100,
    )
    styler = styler.text_gradient(subset=['log_odds'], cmap=cmap, vmin=-bound, vmax=bound)
    styler = styler.format(na_rep='', precision=1, subset=['log_odds'])
    styler = styler.format("{:3d}", subset=['Prob(EoT) as %'])
    styler = styler.hide(axis='index')
    return styler.to_html()
|
|
|
|
|
def generate_highlighted_text(text, threshold=EN_THRESHOLD):
    """Build both Gradio outputs for *text*.

    Returns a (token, score) list for gr.HighlightedText plus styled HTML for
    the raw-score table. Scores are threshold-centered log-odds rescaled into
    roughly [-2/3, 2/3]; tokens with no prediction (or a probability that
    rounds to 0) get NaN so they render uncolored.
    """
    eps = 1e-12
    if not text:
        # The interface declares two outputs (HighlightedText, HTML), so the
        # empty-input path must also return two values.
        return [], ""

    df = predict_eou(text)
    df.token = df.token.replace({"user": "\nUSER:", "assistant": "\nAGENT:"})
    df = df[~df.token.isin(CONTROL_TOKS)]

    # Hide NaN predictions and near-zero probabilities. The comparison must
    # be parenthesized: `|` binds tighter than `==`, so the unparenthesized
    # form `a | b == 0` would evaluate as `(a | b) == 0`.
    hide = df.pred.isna() | (df.pred.round(2) == 0)
    df['score'] = (
        df.pred.fillna(threshold)
        .add(eps)
        .apply(log_odds).sub(log_odds(threshold))
        .mask(hide)
    )
    # Pad the range 50% so the highlight colors don't saturate at the ends.
    max_abs_score = df['score'].abs().max() * 1.5

    if max_abs_score > 0:  # NaN (all-masked) or 0 fails this guard safely
        df.score = df.score / max_abs_score

    styled_df = make_styled_df(df[['token', 'pred']])
    return list(zip(df.token, df.score)), styled_df
|
|
|
|
|
|
|
# Example conversation used to pre-fill the input box (ChatML format).
# The blank lines are part of the string on purpose — do not reflow.
convo_text = """<|im_start|>assistant

what is your phone number<|im_end|>

<|im_start|>user

555 410 0423<|im_end|>"""


# Gradio UI: a single textbox in, highlighted tokens + raw-score table out.
demo = gr.Interface(
    fn=generate_highlighted_text,
    theme="soft",
    inputs=gr.Textbox(
        label="Input Text",
        info="If <|im_start|> is present it will treat input as formatted convo. if not it will format it as convo with 1 user message.",
        value=convo_text,
        lines=2
    ),
    outputs=[
        gr.HighlightedText(
            label="EoT Predictions",
            # NOTE(review): HighlightedText's color_map normally maps label ->
            # color (a dict); numeric scores are colored by a built-in scale —
            # confirm "coolwarm" is honored here. `scale` is documented as int.
            color_map="coolwarm",
            scale=1.5,
        ),
        gr.HTML(label="Raw scores",)
    ],
    title="Turn Detector Debugger",
    description="Visualize predicted turn endings. The coloring is based on log-odds, centered on the threshold.\n Red means agent should reply; Blue means agent should wait",
    # NOTE(review): allow_flagging is deprecated in newer Gradio releases in
    # favor of flagging_mode — confirm against the pinned gradio version.
    allow_flagging="never"
)

demo.launch()
|
|