Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import numpy as np
|
3 |
+
import pandas as pd
|
4 |
+
import torch
|
5 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
6 |
+
|
7 |
+
import unicodedata
|
8 |
+
import re
|
9 |
+
import gradio as gr
|
10 |
+
from pprint import pprint
|
11 |
+
|
12 |
+
|
13 |
+
|
14 |
+
# Hugging Face Hub repository and pinned revision of the turn-detection model.
MODEL_ID = "livekit/turn-detector"
REVISION_ID = "v0.3.0-intl"

# Both objects are created once, at import time, and shared by every request.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, revision=REVISION_ID)

# `eval()` returns the module itself, so the model can be switched to
# inference mode in the same expression that loads it.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    revision=REVISION_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
).eval()
|
25 |
+
|
26 |
+
|
27 |
+
# Default probability threshold; the visualisations centre their log-odds
# scale on this value.
EN_THRESHOLD = 0.0049

# Vocabulary ids of the chat-template markers.
START_TOKEN_ID = tokenizer.convert_tokens_to_ids('<|im_start|>')
EOU_TOKEN_ID = tokenizer.convert_tokens_to_ids("<|im_end|>")
# NOTE(review): this looks up the literal string '\n' in the vocabulary.
# Confirm the tokenizer really stores newline under that spelling; if it
# does not, this id will not match the newline tokens seen in encoded input.
NEWLINE_TOKEN_ID = tokenizer.convert_tokens_to_ids('\n')

_USER_ROLE_ID = tokenizer.convert_tokens_to_ids('user')
_ASSISTANT_ROLE_ID = tokenizer.convert_tokens_to_ids('assistant')

# A token that follows <|im_start|> and is one of these opens a user turn.
USER_TOKEN_IDS = (
    _USER_ROLE_ID,
    tokenizer.convert_tokens_to_ids('<|user|>'),
)

# Ids that never receive a prediction, even inside a user turn.
SPECIAL_TOKENS = {
    NEWLINE_TOKEN_ID,
    START_TOKEN_ID,
    _USER_ROLE_ID,
    _ASSISTANT_ROLE_ID,
}

# Decoded token strings that are dropped from the rendered output.
CONTROL_TOKS = ['<|im_start|>', '<|im_end|>', 'user', 'assistant', '\n']
|
42 |
+
|
43 |
+
|
44 |
+
def normalize_text(text):
    """Lower-case, NFKC-normalise and strip punctuation from *text*.

    Apostrophes and hyphens are kept; every other character whose Unicode
    category starts with ``P`` is removed. Runs of whitespace are collapsed
    to a single space and the result is trimmed.
    """
    kept_punctuation = ("'", "-")
    folded = unicodedata.normalize("NFKC", text.lower())

    surviving = []
    for ch in folded:
        is_punctuation = unicodedata.category(ch).startswith('P')
        if ch in kept_punctuation or not is_punctuation:
            surviving.append(ch)

    collapsed = re.sub(r'\s+', ' ', ''.join(surviving))
    return collapsed.strip()
|
52 |
+
|
53 |
+
|
54 |
+
def format_input(text):
    """Return *text* as a chat-template string ready for tokenisation.

    Text that already contains ``<|im_start|>`` is treated as a fully
    formatted conversation and returned untouched. Anything else is taken
    to be a single user turn: it is normalised, wrapped in a user message
    and rendered through the tokenizer's chat template with the generation
    prompt appended.
    """
    if '<|im_start|>' in text:
        return text

    user_message = {'role': 'user', 'content': normalize_text(text)}
    return tokenizer.apply_chat_template(
        [user_message],
        tokenize=False,
        add_generation_prompt=True,
    )
|
64 |
+
|
65 |
+
|
66 |
+
def log_odds(p, eps=0):
    """Return ``log(p / (1 - p + eps))``.

    *eps* is added to the denominator only; it defaults to 0.
    """
    complement = 1 - p + eps
    return np.log(p / complement)
|
68 |
+
|
69 |
+
|
70 |
+
def make_pred_mask(input_ids):
    """Return a per-token list of booleans marking user-turn content tokens.

    A token is marked ``True`` when it lies inside a turn whose role token
    (the token right after ``<|im_start|>``) is in ``USER_TOKEN_IDS`` and
    the token itself is not in ``SPECIAL_TOKENS``.

    The final token is evaluated as well. Previously the loop stopped one
    short of the end (to keep the ``i + 1`` look-ahead in bounds), so a
    conversation ending in an unfinished user turn never got a prediction
    on its last token.

    Args:
        input_ids: indexable sequence of token ids.

    Returns:
        A list of ``bool`` with the same length as *input_ids*.
    """
    last_index = len(input_ids) - 1
    user_mask = []
    is_user_role = False
    for i, token_id in enumerate(input_ids):
        # Only look ahead for the role token when there is a next token.
        if token_id == START_TOKEN_ID and i < last_index:
            is_user_role = input_ids[i + 1] in USER_TOKEN_IDS
        user_mask.append(bool(is_user_role and token_id not in SPECIAL_TOKENS))
    return user_mask
|
81 |
+
|
82 |
+
|
83 |
+
def predict_eou(text):
    """Score every token of *text* with the model's end-of-turn probability.

    The text is passed through ``format_input``, tokenised without extra
    special tokens and run through the model once. For each position the
    softmax probability assigned to ``EOU_TOKEN_ID`` is kept; positions not
    selected by ``make_pred_mask`` are set to NaN.

    Returns:
        A ``pandas.DataFrame`` with one row per input token and two columns:
        ``token`` (the decoded token string) and ``pred`` (the probability,
        or NaN where masked).
    """
    text = format_input(text)
    inputs = tokenizer.encode(
        text,
        add_special_tokens=False,
        return_tensors="pt",
    ).to(model.device)

    with torch.no_grad(), torch.amp.autocast(model.device.type):
        outputs = model(inputs)
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
        # Pick the single end-of-turn column before leaving the device so
        # only seq_len values are converted and copied, not the full vocab.
        probs = probs[:, :, EOU_TOKEN_ID].float().cpu().numpy().flatten()

    input_ids = inputs.cpu().int().flatten().numpy()
    mask = np.array(make_pred_mask(input_ids))
    probs[~mask] = np.nan

    tokens = [tokenizer.decode(token_id) for token_id in input_ids]
    return pd.DataFrame({'token': tokens, 'pred': probs})
|
103 |
+
|
104 |
+
|
105 |
+
def make_styled_df(df, thresh=EN_THRESHOLD, cmap="coolwarm"):
    """Render per-token predictions as a styled HTML table.

    Args:
        df: DataFrame with a ``token`` column and a ``pred`` column
            (probabilities, NaN where there is no prediction).
        thresh: probability on which the log-odds scale is centred.
        cmap: colormap name used for the bars and the text gradient.

    Returns:
        An HTML string. Rows whose token is in ``CONTROL_TOKS`` are dropped.
        The table shows the token, ``log_odds`` (log-odds of ``pred`` minus
        log-odds of *thresh*, blank where ``pred`` is NaN) and
        ``Prob(EoT) as %`` (``pred`` as an integer percentage, 0 where NaN).
    """
    EPS = 1e-12
    # Copy *after* filtering so the column assignments below write to an
    # independent frame instead of a slice of the caller's data.
    df = df[~df.token.isin(CONTROL_TOKS)].copy()
    df['token'] = df.token.replace({"\n": "⏎", " ": "␠"})

    df['log_odds'] = (
        df.pred.fillna(thresh)
        .add(EPS)
        .apply(log_odds).sub(log_odds(thresh))
        .mask(df.pred.isna())
    )
    df['Prob(EoT) as %'] = df.pred.mul(100).fillna(0).astype(int)

    vmin, vmax = df.log_odds.min(), df.log_odds.max()
    vmax_abs = max(abs(vmin), abs(vmax)) * 1.5
    # With no predictions at all (or all exactly on the threshold) the range
    # would be NaN or zero-width; fall back to a unit range so the colour
    # scale stays well-defined.
    if not np.isfinite(vmax_abs) or vmax_abs == 0:
        vmax_abs = 1.0

    fmt = (
        df.drop(columns=['pred'])
        .style
        .bar(
            subset=['log_odds'],
            align="zero",
            vmin=-vmax_abs,
            vmax=vmax_abs,
            cmap=cmap,
            height=70,
            width=100,
        )
        .text_gradient(subset=['log_odds'], cmap=cmap, vmin=-vmax_abs, vmax=vmax_abs)
        .format(na_rep='', precision=1, subset=['log_odds'])
        .format("{:3d}", subset=['Prob(EoT) as %'])
        .hide(axis='index')
    )
    return fmt.to_html()
|
139 |
+
|
140 |
+
|
141 |
+
def generate_highlighted_text(text, threshold=EN_THRESHOLD):
    """Produce the two outputs shown by the UI for *text*.

    Args:
        text: raw user text or a full ``<|im_start|>`` conversation.
        threshold: probability on which the log-odds scores are centred.

    Returns:
        A 2-tuple ``(highlights, html)``. ``highlights`` is a list of
        ``(token, score)`` pairs, where ``score`` is the threshold-centred
        log-odds divided by 1.5x the largest absolute score. The score is
        NaN where there is no prediction or the probability rounds to 0.00.
        ``html`` is the table from ``make_styled_df``. Empty input yields
        ``([], "")`` so both outputs always receive a value.
    """
    eps = 1e-12
    if not text:
        # One value per declared output component.
        return [], ""

    df = predict_eou(text)
    df.token = df.token.replace({"user": "\nUSER:", "assistant": "\nAGENT:"})
    # Copy so the new column is written to an independent frame.
    df = df[~df.token.isin(CONTROL_TOKS)].copy()

    # Parentheses are required: `|` binds tighter than `==`, so without them
    # the expression compares the OR-ed result with 0 instead of OR-ing the
    # two conditions.
    no_score = df.pred.isna() | (df.pred.round(2) == 0)
    df['score'] = (
        df.pred.fillna(threshold)
        .add(eps)
        .apply(log_odds).sub(log_odds(threshold))
        .mask(no_score)
    )
    max_abs_score = df['score'].abs().max() * 1.5

    # False for NaN as well, which leaves an all-NaN column untouched.
    if max_abs_score > 0:
        df['score'] = df.score / max_abs_score

    # Use the same threshold for the table as for the highlights.
    styled_df = make_styled_df(df[['token', 'pred']], thresh=threshold)
    return list(zip(df.token, df.score)), styled_df
|
163 |
+
|
164 |
+
|
165 |
+
|
166 |
+
# Sample two-turn conversation pre-filled in the input box.
convo_text = (
    "<|im_start|>assistant\n"
    "what is your phone number<|im_end|>\n"
    "<|im_start|>user\n"
    "555 410 0423<|im_end|>"
)

# Input component; a plain sentence such as
# "can you help me order some pizza" also works as input.
_input_box = gr.Textbox(
    label="Input Text",
    value=convo_text,
    lines=2,
)

# Output components, in the order returned by generate_highlighted_text.
_highlight_view = gr.HighlightedText(
    label="EoT Predictions",
    color_map="coolwarm",
    scale=1.5,
)
_raw_scores_view = gr.HTML(label="Raw scores")

demo = gr.Interface(
    fn=generate_highlighted_text,
    theme="soft",
    inputs=_input_box,
    outputs=[_highlight_view, _raw_scores_view],
    title="Turn Detector Debugger",
    description="Visualize predicted turn endings. The coloring is based on log-odds, centered on the threshold.\n Red means agent should reply; Blue means agent should wait",
    allow_flagging="never",
)

demo.launch()
|