"""Turn Detector Debugger.

Gradio demo that visualizes end-of-turn (EoT) predictions from the LiveKit
``turn-detector`` model: for every user-content token it shows the model's
probability that the turn ends there, colored by log-odds relative to the
decision threshold.
"""
import random
import re
import unicodedata
from pprint import pprint

import gradio as gr
import numpy as np
import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "livekit/turn-detector"
REVISION_ID = "v0.3.0-intl"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, revision=REVISION_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    revision=REVISION_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model.eval()

# Probability threshold above which the agent should treat the turn as over
# (value used by the English model).
EN_THRESHOLD = 0.0049

START_TOKEN_ID = tokenizer.convert_tokens_to_ids('<|im_start|>')
EOU_TOKEN_ID = tokenizer.convert_tokens_to_ids("<|im_end|>")
NEWLINE_TOKEN_ID = tokenizer.convert_tokens_to_ids('\n')
# Token ids that can name the "user" role right after <|im_start|>.
USER_TOKEN_IDS = (
    tokenizer.convert_tokens_to_ids('user'),
    tokenizer.convert_tokens_to_ids('<|user|>'),
)
# Token ids that never get an EoT prediction of their own (control/role tokens).
SPECIAL_TOKENS = set([
    NEWLINE_TOKEN_ID,
    START_TOKEN_ID,
    tokenizer.convert_tokens_to_ids('user'),
    tokenizer.convert_tokens_to_ids('assistant'),
])
# Literal control-token strings hidden from the rendered output.
CONTROL_TOKS = ['<|im_start|>', '<|im_end|>', 'user', 'assistant', '\n']


def normalize_text(text):
    """Lowercase, NFKC-normalize, drop punctuation (keeping ' and -), and
    collapse runs of whitespace to a single space."""
    text = unicodedata.normalize("NFKC", text.lower())
    text = ''.join(
        ch for ch in text
        if not (unicodedata.category(ch).startswith('P') and ch not in ["'", "-"])
    )
    text = re.sub(r'\s+', ' ', text).strip()
    return text


def format_input(text):
    """Return chat-template-formatted text ready for the model.

    If the input already contains <|im_start|> it is assumed to be a
    pre-formatted conversation and passed through unchanged; otherwise it is
    wrapped as a single normalized user turn via the tokenizer's chat
    template (with the generation prompt appended).
    """
    if '<|im_start|>' not in text:
        # assume single user turn
        text = {'role': 'user', 'content': normalize_text(text)}
        text = tokenizer.apply_chat_template(
            [text], tokenize=False, add_generation_prompt=True
        )
    return text


def log_odds(p, eps=0):
    """Return log(p / (1 - p + eps)).

    ``eps`` guards the denominator against division by zero when p == 1.
    """
    return np.log(p / (1 - p + eps))


def make_pred_mask(input_ids):
    """Build a boolean mask over ``input_ids`` marking the positions whose
    EoT probability should be kept.

    A position is True when it lies inside a user turn (the token following
    <|im_start|> names the role) and the token itself is not a control/role
    token. The final position is always False (the loop stops at len - 1 so
    the role lookahead ``i + 1`` stays in bounds).
    """
    user_mask = [False] * len(input_ids)
    is_user_role = False
    for i in range(len(input_ids) - 1):
        if input_ids[i] == START_TOKEN_ID:
            # The token right after <|im_start|> identifies whose turn it is.
            is_user_role = input_ids[i + 1] in USER_TOKEN_IDS
        # NOTE: simplified from the original if/else — same truth table.
        user_mask[i] = is_user_role and (input_ids[i] not in SPECIAL_TOKENS)
    return user_mask


def predict_eou(text):
    """Run the model and return a DataFrame with one row per token.

    Columns: ``token`` (decoded token string) and ``pred`` (probability that
    <|im_end|> follows at that position; NaN outside user-content positions).
    """
    text = format_input(text)
    with torch.no_grad(), torch.amp.autocast(model.device.type):
        inputs = tokenizer.encode(
            text, add_special_tokens=False, return_tensors="pt"
        ).to(model.device)
        outputs = model(inputs)
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
        # Keep only the probability mass assigned to the EoU token.
        probs = probs.cpu().float().numpy()[:, :, EOU_TOKEN_ID].flatten()
    input_ids = inputs.cpu().int().flatten().numpy()
    mask = np.array(make_pred_mask(input_ids))
    probs[~mask] = np.nan  # hide predictions outside user content
    tokens = [tokenizer.decode(id) for id in input_ids]
    res = {'token': tokens, 'pred': probs}
    return pd.DataFrame(res)


def make_styled_df(df, thresh=EN_THRESHOLD, cmap="coolwarm"):
    """Render per-token probabilities as an HTML table with diverging
    log-odds bars centered on the decision threshold."""
    EPS = 1e-12
    df = df.copy()
    df = df[~df.token.isin(CONTROL_TOKS)]
    # Make whitespace tokens visible in the table.
    df.token = df.token.replace({"\n": "⏎", " ": "␠"})
    # Log-odds relative to the threshold: positive => model says "turn over".
    df['log_odds'] = (
        df.pred.fillna(thresh)
        .add(EPS)
        .apply(log_odds).sub(log_odds(thresh))
        .mask(df.pred.isna())
    )
    df['Prob(EoT) as %'] = df.pred.mul(100).fillna(0).astype(int)
    vmin, vmax = df.log_odds.min(), df.log_odds.max()
    # Symmetric color range, padded so the extremes aren't fully saturated.
    vmax_abs = max(abs(vmin), abs(vmax)) * 1.5
    fmt = (
        df.drop(columns=['pred'])
        .style
        .bar(
            subset=['log_odds'],
            align="zero",
            vmin=-vmax_abs,
            vmax=vmax_abs,
            cmap=cmap,
            height=70,
            width=100,
        )
        .text_gradient(subset=['log_odds'], cmap=cmap, vmin=-vmax_abs, vmax=vmax_abs)
        .format(na_rep='', precision=1, subset=['log_odds'])
        .format("{:3d}", subset=['Prob(EoT) as %'])
        .hide(axis='index')
    )
    return fmt.to_html()


def generate_highlighted_text(text, threshold=EN_THRESHOLD):
    """Return (highlighted-token list, styled HTML table) for the Gradio UI.

    Scores are threshold-centered log-odds scaled into roughly [-1, 1] for
    the HighlightedText color map.
    """
    eps = 1e-12
    if not text:
        # BUG FIX: the Interface wires two outputs; return both empty
        # (the original returned a single empty list).
        return [], ""
    df = predict_eou(text)
    df.token = df.token.replace({"user": "\nUSER:", "assistant": "\nAGENT:"})
    df = df[~df.token.isin(CONTROL_TOKS)]
    # BUG FIX: '|' binds tighter than '==', so the original expression
    # evaluated as '(isna | round) == 0'; the comparison must be
    # parenthesized to mask NaNs OR near-zero probabilities.
    df['score'] = (
        df.pred.fillna(threshold)
        .add(eps)
        .apply(log_odds).sub(log_odds(threshold))
        .mask(df.pred.isna() | (df.pred.round(2) == 0))
    )
    max_abs_score = df['score'].abs().max() * 1.5
    if max_abs_score > 0:
        df.score = df.score / max_abs_score
    styled_df = make_styled_df(df[['token', 'pred']])
    return list(zip(df.token, df.score)), styled_df


# Default demo value: a pre-formatted two-turn ChatML conversation
# (newlines follow the <|im_start|>role / content layout the model expects).
convo_text = """<|im_start|>assistant
what is your phone number<|im_end|>
<|im_start|>user
555 410 0423<|im_end|>"""

demo = gr.Interface(
    fn=generate_highlighted_text,
    theme="soft",
    inputs=gr.Textbox(
        label="Input Text",
        info="If <|im_start|> is present it will treat input as formatted convo. if not it will format it as convo with 1 user message.",
        # value="can you help me order some pizza",
        value=convo_text,
        lines=2,
    ),
    outputs=[
        gr.HighlightedText(
            label="EoT Predictions",
            color_map="coolwarm",
            scale=1.5,
        ),
        gr.HTML(label="Raw scores"),
    ],
    title="Turn Detector Debugger",
    description="Visualize predicted turn endings. The coloring is based on log-odds, centered on the threshold.\n Red means agent should reply; Blue means agent should wait",
    allow_flagging="never",
)

demo.launch()