# eot-viz / app.py
# Hugging Face Space page header (scraped): committer jeradf, commit 54e69b4
# "Update app.py" (verified).
import random
import numpy as np
import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import unicodedata
import re
import gradio as gr
from pprint import pprint
# Hugging Face checkpoint of the turn detector and the pinned revision
# (the "-intl" suffix presumably denotes the multilingual variant — confirm).
MODEL_ID = "livekit/turn-detector"
REVISION_ID = "v0.3.0-intl"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, revision=REVISION_ID)

# Load in bfloat16 and let transformers pick the device(s); nn.Module.eval()
# returns the module itself, so inference mode is set in the same expression.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    revision=REVISION_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
).eval()
# Decision threshold on P(<|im_end|>): the UI centres its log-odds colouring on
# it (above = agent should reply, below = wait). The EN_ prefix suggests it was
# tuned for English — confirm before reusing for other languages.
EN_THRESHOLD = 0.0049

# Chat-template token ids used to find role headers and turn boundaries.
START_TOKEN_ID = tokenizer.convert_tokens_to_ids('<|im_start|>')
EOU_TOKEN_ID = tokenizer.convert_tokens_to_ids("<|im_end|>")
# NOTE(review): byte-level BPE vocabularies usually store newline as 'Ċ', so
# this lookup may return the unknown-token id (or None) instead of the real
# newline id — verify against the tokenizer.
NEWLINE_TOKEN_ID = tokenizer.convert_tokens_to_ids('\n')

# Role-name ids that, right after <|im_start|>, open a user turn.
USER_TOKEN_IDS = tuple(
    tokenizer.convert_tokens_to_ids(role) for role in ('user', '<|user|>')
)

# Structural ids that are never scored, even inside a user turn.
SPECIAL_TOKENS = {
    NEWLINE_TOKEN_ID,
    START_TOKEN_ID,
    USER_TOKEN_IDS[0],  # the 'user' role token
    tokenizer.convert_tokens_to_ids('assistant'),
}

# Decoded token strings dropped from both visualizations.
CONTROL_TOKS = ['<|im_start|>', '<|im_end|>', 'user', 'assistant', '\n']
def normalize_text(text):
    """Lower-case, NFKC-normalize and strip punctuation from *text*.

    Every Unicode punctuation character (category ``P*``) is removed except the
    ASCII apostrophe and hyphen. Runs of whitespace are then collapsed to a
    single space, and leading/trailing whitespace is stripped.
    """
    keep = ("'", "-")
    folded = unicodedata.normalize("NFKC", text.lower())
    kept_chars = [
        ch for ch in folded
        if ch in keep or not unicodedata.category(ch).startswith('P')
    ]
    return re.sub(r'\s+', ' ', ''.join(kept_chars)).strip()
def format_input(text):
    """Return *text* as a chat-template prompt string for the turn detector.

    Input that already contains ``<|im_start|>`` is treated as a formatted
    conversation and returned unchanged; no normalization is applied to it.
    Any other input is normalized with ``normalize_text`` and rendered as a
    single user turn through the tokenizer's chat template, with the
    generation prompt appended.
    """
    if '<|im_start|>' in text:
        return text
    single_turn = [{'role': 'user', 'content': normalize_text(text)}]
    return tokenizer.apply_chat_template(
        single_turn,
        tokenize=False,
        add_generation_prompt=True,
    )
def log_odds(p, eps=0):
    """Return the natural log-odds ``log(p / (1 - p))`` of probability *p*.

    Works elementwise on scalars, NumPy arrays and pandas Series. *eps* is
    added to the denominator only. With ``eps > 0``, ``p == 1`` therefore stays
    finite, while ``p == 0`` still yields ``-inf``.
    """
    odds = p / (1 - p + eps)
    return np.log(odds)
def make_pred_mask(input_ids, *, start_token_id=None, user_token_ids=None,
                   special_tokens=None):
    """Flag the token positions whose end-of-turn prediction should be shown.

    A position is flagged when it lies inside a user turn and is not itself a
    structural token. "Inside a user turn" means the role token right after
    the most recent start token is one of *user_token_ids*.

    Args:
        input_ids: Sequence of token ids for one conversation.
        start_token_id: Id that opens a turn. Defaults to ``START_TOKEN_ID``.
        user_token_ids: Ids naming the user role. Defaults to
            ``USER_TOKEN_IDS``.
        special_tokens: Ids that are never flagged. Defaults to
            ``SPECIAL_TOKENS``.

    Returns:
        List of bools, one per entry of *input_ids*.
    """
    if start_token_id is None:
        start_token_id = START_TOKEN_ID
    if user_token_ids is None:
        user_token_ids = USER_TOKEN_IDS
    if special_tokens is None:
        special_tokens = SPECIAL_TOKENS

    n_tokens = len(input_ids)
    user_mask = []
    in_user_turn = False
    for i, tok in enumerate(input_ids):
        # The role name is the token right after the start token. A trailing
        # start token has no role yet, so the current state is kept.
        if tok == start_token_id and i + 1 < n_tokens:
            in_user_turn = input_ids[i + 1] in user_token_ids
        # Every position is visited, including the last one. Stopping one
        # short dropped the final token's score when the conversation ends
        # mid-user-turn.
        user_mask.append(bool(in_user_turn and tok not in special_tokens))
    return user_mask
def predict_eou(text):
    """Score every token of *text* with the model's end-of-turn probability.

    Args:
        text: Raw user utterance, or a conversation already formatted with
            ``<|im_start|>`` / ``<|im_end|>`` markers (see ``format_input``).

    Returns:
        DataFrame with one row per input token:
        ``token`` is the decoded token string, and ``pred`` is the model's
        probability that ``<|im_end|>`` is the next token. ``pred`` is NaN
        wherever ``make_pred_mask`` does not flag the position (outside user
        turns, and structural tokens).
    """
    formatted = format_input(text)
    input_ids = tokenizer.encode(
        formatted,
        add_special_tokens=False,
        return_tensors="pt",
    ).to(model.device)

    with torch.no_grad():
        with torch.amp.autocast(model.device.type):
            logits = model(input_ids).logits
        # Normalize in float32, outside autocast. bf16 keeps only ~3
        # significant digits, which quantizes probabilities around the tiny
        # decision threshold and can round confident predictions to 1.0.
        probs = torch.nn.functional.softmax(logits.float(), dim=-1)

    # Batch size is 1: keep P(<|im_end|>) at every position of the sequence.
    eou_probs = probs[0, :, EOU_TOKEN_ID].cpu().numpy()
    ids = input_ids[0].cpu().tolist()

    mask = np.array(make_pred_mask(ids), dtype=bool)
    eou_probs[~mask] = np.nan

    tokens = [tokenizer.decode(tok_id) for tok_id in ids]
    return pd.DataFrame({'token': tokens, 'pred': eou_probs})
def make_styled_df(df, thresh=EN_THRESHOLD, cmap="coolwarm"):
    """Render per-token end-of-turn probabilities as a styled HTML table.

    Args:
        df: DataFrame with a ``token`` column (decoded strings) and a ``pred``
            column (probability, NaN where no prediction is made). It is not
            modified.
        thresh: Decision threshold. Log-odds are shown relative to it, so
            positive values are above the threshold and negative values below.
        cmap: Matplotlib colormap name for the bars and the text gradient.

    Returns:
        HTML string of a table with columns ``token``, ``log_odds`` and
        ``Prob(EoT) as %``.
    """
    EPS = 1e-12
    # Copy after filtering so the column assignments below act on an
    # independent frame (no SettingWithCopyWarning, caller's df untouched).
    df = df[~df.token.isin(CONTROL_TOKS)].copy()
    # Show bare newline / space tokens with visible glyphs.
    df['token'] = df.token.replace({"\n": "⏎", " ": "␠"})
    # Clip (rather than add EPS) so pred == 0 and pred == 1 both give a large
    # finite log-odds. 1.0 + EPS exceeds 1, and log_odds then returns NaN.
    df['log_odds'] = (
        df.pred.fillna(thresh)
        .clip(EPS, 1 - EPS)
        .pipe(log_odds)
        .sub(log_odds(thresh))
        .mask(df.pred.isna())
    )
    df['Prob(EoT) as %'] = df.pred.mul(100).fillna(0).astype(int)

    # Symmetric colour range around zero (= the threshold), with 50% headroom.
    vmax_abs = df.log_odds.abs().max() * 1.5
    if not np.isfinite(vmax_abs) or vmax_abs == 0:
        # All scores missing or all exactly at the threshold: use a unit range
        # so bar() / text_gradient() get a valid, non-degenerate scale.
        vmax_abs = 1.0

    fmt = (
        df.drop(columns=['pred'])
        .style
        .bar(
            subset=['log_odds'],
            align="zero",
            vmin=-vmax_abs,
            vmax=vmax_abs,
            cmap=cmap,
            height=70,
            width=100,
        )
        .text_gradient(subset=['log_odds'], cmap=cmap, vmin=-vmax_abs, vmax=vmax_abs)
        .format(na_rep='', precision=1, subset=['log_odds'])
        .format("{:3d}", subset=['Prob(EoT) as %'])
        .hide(axis='index')
    )
    return fmt.to_html()
def generate_highlighted_text(text, threshold=EN_THRESHOLD):
    """Gradio callback: score *text* and build both output views.

    Args:
        text: Raw user utterance or pre-formatted conversation.
        threshold: End-of-turn decision threshold. Scores are log-odds
            relative to it.

    Returns:
        Tuple ``(highlights, html)``, one value per interface output.
        ``highlights`` is a list of ``(token, score)`` pairs for
        ``gr.HighlightedText``. ``score`` is the threshold-centred log-odds,
        scaled into [-2/3, 2/3], or ``None`` for tokens that are not
        highlighted. ``html`` is the raw-score table from ``make_styled_df``.
    """
    eps = 1e-12
    if not text:
        # The interface has two outputs, so an empty result must still
        # supply a value for each of them.
        return [], ""

    df = predict_eou(text)
    # Turn role-name tokens into visible speaker labels before the remaining
    # control tokens are dropped.
    df['token'] = df.token.replace({"user": "\nUSER:", "assistant": "\nAGENT:"})
    df = df[~df.token.isin(CONTROL_TOKS)].copy()

    # Hide unscored tokens and tokens whose probability rounds to 0.00.
    # The parentheses matter: `|` binds tighter than `==`, so the unbracketed
    # form let NaN-pred tokens escape the mask.
    hidden = df.pred.isna() | (df.pred.round(2) == 0)
    # Clip (rather than add eps) so pred == 1.0 gives a finite score, not NaN.
    df['score'] = (
        df.pred.fillna(threshold)
        .clip(eps, 1 - eps)
        .pipe(log_odds)
        .sub(log_odds(threshold))
        .mask(hidden)
    )
    max_abs_score = df['score'].abs().max() * 1.5
    if max_abs_score > 0:  # False when every score is NaN
        df['score'] = df['score'] / max_abs_score

    styled_df = make_styled_df(df[['token', 'pred']], thresh=threshold)
    # NaN is not valid JSON. None is HighlightedText's "no highlight" value.
    highlights = [
        (tok, None if pd.isna(score) else float(score))
        for tok, score in zip(df.token, df.score)
    ]
    return highlights, styled_df
# Default demo input: a pre-formatted two-turn conversation (agent question,
# then user answer), written with explicit chat-template markers.
convo_text = (
    "<|im_start|>assistant\n"
    "what is your phone number<|im_end|>\n"
    "<|im_start|>user\n"
    "555 410 0423<|im_end|>"
)
# Gradio UI: one text box in; two views out (token highlights + raw table).
_input_box = gr.Textbox(
    label="Input Text",
    info="If <|im_start|> is present it will treat input as formatted convo. if not it will format it as convo with 1 user message.",
    # Alternative single-turn default: "can you help me order some pizza"
    value=convo_text,
    lines=2,
)
_highlight_view = gr.HighlightedText(
    label="EoT Predictions",
    color_map="coolwarm",
    scale=1.5,
)
_raw_view = gr.HTML(label="Raw scores")

demo = gr.Interface(
    fn=generate_highlighted_text,
    inputs=_input_box,
    outputs=[_highlight_view, _raw_view],
    title="Turn Detector Debugger",
    description="Visualize predicted turn endings. The coloring is based on log-odds, centered on the threshold.\n Red means agent should reply; Blue means agent should wait",
    theme="soft",
    allow_flagging="never",
)
demo.launch()