jeradf committed on
Commit
ab7c317
·
verified ·
1 Parent(s): 35876ba

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +194 -0
app.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import numpy as np
3
+ import pandas as pd
4
+ import torch
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer
6
+
7
+ import unicodedata
8
+ import re
9
+ import gradio as gr
10
+ from pprint import pprint
11
+
12
+
13
+
14
# Hugging Face Hub coordinates of the checkpoint that this app inspects.
MODEL_ID = "livekit/turn-detector"
REVISION_ID = "v0.3.0-intl"

# The tokenizer and the weights are both pinned to the same revision.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, revision=REVISION_ID)

# Weights are loaded as bfloat16 with `device_map="auto"`. `Module.eval()`
# returns the module itself, so switching to inference mode can be chained
# onto the load call.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    revision=REVISION_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
).eval()
25
+
26
+
27
# Default probability threshold that the log-odds visualizations are
# centered on.
EN_THRESHOLD = 0.0049

# Vocabulary ids of the chat markers, looked up once at import time.
START_TOKEN_ID = tokenizer.convert_tokens_to_ids("<|im_start|>")
EOU_TOKEN_ID = tokenizer.convert_tokens_to_ids("<|im_end|>")
NEWLINE_TOKEN_ID = tokenizer.convert_tokens_to_ids("\n")

# Ids that identify a user turn when they directly follow START_TOKEN_ID.
USER_TOKEN_IDS = tuple(
    tokenizer.convert_tokens_to_ids(role) for role in ("user", "<|user|>")
)

# Ids of structural tokens that are never given a prediction.
SPECIAL_TOKENS = {
    NEWLINE_TOKEN_ID,
    START_TOKEN_ID,
    tokenizer.convert_tokens_to_ids("user"),
    tokenizer.convert_tokens_to_ids("assistant"),
}

# Decoded token strings that are dropped from the displayed tables.
CONTROL_TOKS = ["<|im_start|>", "<|im_end|>", "user", "assistant", "\n"]
42
+
43
+
44
def normalize_text(text):
    """Return `text` lower-cased, NFKC-normalized and stripped of punctuation.

    Characters whose Unicode category starts with ``P`` are removed, except
    the apostrophe ``'`` and the hyphen ``-``. Afterwards every run of
    whitespace is collapsed to a single space and the ends are trimmed.
    """
    folded = unicodedata.normalize("NFKC", text.lower())

    kept = []
    for ch in folded:
        is_punctuation = unicodedata.category(ch).startswith("P")
        if ch in ("'", "-") or not is_punctuation:
            kept.append(ch)

    collapsed = re.sub(r"\s+", " ", "".join(kept))
    return collapsed.strip()
52
+
53
+
54
def format_input(text):
    """Return `text` as a chat-formatted prompt string.

    Text that already contains ``<|im_start|>`` is returned unchanged.
    Anything else is normalized, wrapped as a single user message and
    rendered through the tokenizer's chat template (untokenized, with the
    generation prompt appended).
    """
    if "<|im_start|>" in text:
        # Already chat-formatted by the caller: pass it through untouched.
        return text

    # assume single user turn
    message = {"role": "user", "content": normalize_text(text)}
    return tokenizer.apply_chat_template(
        [message],
        tokenize=False,
        add_generation_prompt=True,
    )
64
+
65
+
66
def log_odds(p, eps=0):
    """Return the natural logarithm of ``p / (1 - p + eps)``.

    `eps` is added to the denominator only; it defaults to 0.
    """
    complement = 1 - p + eps
    return np.log(p / complement)
68
+
69
+
70
def make_pred_mask(input_ids, *, start_token_id=None, user_token_ids=None,
                   special_tokens=None):
    """Return a per-token mask selecting the content tokens of user turns.

    A position is ``True`` when it lies inside a turn whose role token (the
    token directly after a start marker) is one of `user_token_ids`, and the
    token at that position is not one of `special_tokens`.

    Args:
        input_ids: Sequence of token ids (list or 1-D array).
        start_token_id: Id that opens a turn. Defaults to ``START_TOKEN_ID``.
        user_token_ids: Role ids that identify a user turn. Defaults to
            ``USER_TOKEN_IDS``.
        special_tokens: Ids that are never selected. Defaults to
            ``SPECIAL_TOKENS``.

    Returns:
        A list of ``bool`` with the same length as `input_ids`.
    """
    if start_token_id is None:
        start_token_id = START_TOKEN_ID
    if user_token_ids is None:
        user_token_ids = USER_TOKEN_IDS
    if special_tokens is None:
        special_tokens = SPECIAL_TOKENS

    n_tokens = len(input_ids)
    user_mask = [False] * n_tokens
    in_user_turn = False
    # Visit every position, including the last one: when the input ends on a
    # user content token, that position must stay selected. Only the
    # one-token lookahead for the role needs a bounds check.
    for i, token_id in enumerate(input_ids):
        if token_id == start_token_id:
            # A start marker at the very end has no role token after it.
            in_user_turn = (
                i + 1 < n_tokens and input_ids[i + 1] in user_token_ids
            )
        user_mask[i] = in_user_turn and token_id not in special_tokens
    return user_mask
81
+
82
+
83
def predict_eou(text):
    """Return a per-token table of end-of-turn probabilities for `text`.

    The text is passed through `format_input`, encoded without extra special
    tokens and run through the model once. At every position the softmax
    probability assigned to ``EOU_TOKEN_ID`` is kept; positions rejected by
    `make_pred_mask` are set to NaN.

    Returns:
        A DataFrame with one row per input token and two columns:
        ``token`` (the decoded token string) and ``pred`` (the probability,
        or NaN).
    """
    prompt = format_input(text)

    with torch.no_grad(), torch.amp.autocast(model.device.type):
        inputs = tokenizer.encode(
            prompt,
            add_special_tokens=False,
            return_tensors="pt",
        ).to(model.device)
        logits = model(inputs).logits
        vocab_probs = torch.nn.functional.softmax(logits, dim=-1)
        # Keep only the end-of-turn column, flattened to one value per token.
        eou_probs = vocab_probs.cpu().float().numpy()[:, :, EOU_TOKEN_ID].flatten()

    input_ids = inputs.cpu().int().flatten().numpy()
    keep = np.array(make_pred_mask(input_ids))
    eou_probs[~keep] = np.nan

    decoded = [tokenizer.decode(token_id) for token_id in input_ids]
    return pd.DataFrame({"token": decoded, "pred": eou_probs})
103
+
104
+
105
def make_styled_df(df, thresh=EN_THRESHOLD, cmap="coolwarm"):
    """Render the per-token predictions as a styled HTML table.

    Args:
        df: DataFrame with a ``token`` column and a ``pred`` column of
            probabilities (NaN where there is no prediction).
        thresh: Probability that maps to a log-odds value of zero.
        cmap: Colormap name used for both the bars and the text gradient.

    Returns:
        The HTML produced by the pandas Styler. Rows whose token is in
        ``CONTROL_TOKS`` are dropped, the ``pred`` column is replaced by a
        ``log_odds`` column (relative to `thresh`) and an integer
        ``Prob(EoT) as %`` column, and the index is hidden.
    """
    EPS = 1e-12
    # Copy *after* filtering so the column assignments below write to an
    # independent frame: no SettingWithCopyWarning, and the caller's
    # DataFrame is left untouched.
    df = df[~df.token.isin(CONTROL_TOKS)].copy()
    df["token"] = df["token"].replace({"\n": "⏎", " ": "␠"})

    df["log_odds"] = (
        df.pred.fillna(thresh)
        .add(EPS)
        .apply(log_odds)
        .sub(log_odds(thresh))
        .mask(df.pred.isna())
    )
    df["Prob(EoT) as %"] = df.pred.mul(100).fillna(0).astype(int)

    # Symmetric colour range around zero, with 50% headroom.
    vmax_abs = df.log_odds.abs().max() * 1.5
    if not np.isfinite(vmax_abs) or vmax_abs == 0:
        # No finite log-odds (empty or all-NaN table), or all exactly zero:
        # fall back to a fixed range so the Styler gets valid bounds.
        vmax_abs = 1.0

    fmt = (
        df.drop(columns=["pred"])
        .style
        .bar(
            subset=["log_odds"],
            align="zero",
            vmin=-vmax_abs,
            vmax=vmax_abs,
            cmap=cmap,
            height=70,
            width=100,
        )
        .text_gradient(subset=["log_odds"], cmap=cmap, vmin=-vmax_abs, vmax=vmax_abs)
        .format(na_rep="", precision=1, subset=["log_odds"])
        .format("{:3d}", subset=["Prob(EoT) as %"])
        .hide(axis="index")
    )
    return fmt.to_html()
139
+
140
+
141
def generate_highlighted_text(text, threshold=EN_THRESHOLD):
    """Compute the two outputs of the debugger UI for `text`.

    Args:
        text: Plain utterance or chat-formatted conversation.
        threshold: Probability that maps to a score of zero.

    Returns:
        A ``(highlights, html)`` pair. ``highlights`` is a list of
        ``(token, score)`` tuples where ``score`` is the log-odds relative to
        `threshold`, scaled by 1.5 times the largest absolute value, or
        ``None`` when the token has no prediction or its probability rounds
        to 0.00. ``html`` is the table produced by `make_styled_df`. Empty
        input yields ``([], "")``.
    """
    eps = 1e-12
    if not text:
        # The interface declares two outputs, so two values must be returned
        # on this path as well.
        return [], ""

    df = predict_eou(text)
    # Relabel the role tokens before the control-token filter so they survive
    # it and show up as turn headers.
    df["token"] = df["token"].replace({"user": "\nUSER:", "assistant": "\nAGENT:"})
    # Copy the filtered frame so the column assignments below do not write
    # to a slice.
    df = df[~df.token.isin(CONTROL_TOKS)].copy()

    # `|` binds tighter than `==`, so the comparison needs its own
    # parentheses (here: its own variable) to mean "NaN or rounds to zero".
    rounds_to_zero = df.pred.round(2) == 0
    df["score"] = (
        df.pred.fillna(threshold)
        .add(eps)
        .apply(log_odds)
        .sub(log_odds(threshold))
        .mask(df.pred.isna() | rounds_to_zero)
    )
    max_abs_score = df["score"].abs().max() * 1.5

    # False both for 0 and for NaN (no finite score): nothing to rescale.
    if max_abs_score > 0:
        df["score"] = df["score"] / max_abs_score

    # Use the caller's threshold for the table too, so both views agree.
    styled_df = make_styled_df(df[["token", "pred"]], thresh=threshold)

    # Hand back plain floats, with None instead of NaN for unscored tokens.
    scores = [None if pd.isna(s) else float(s) for s in df["score"]]
    return list(zip(df["token"], scores)), styled_df
163
+
164
+
165
+
166
# Example two-turn conversation used to pre-fill the input box.
convo_text = (
    "<|im_start|>assistant\n"
    "what is your phone number<|im_end|>\n"
    "<|im_start|>user\n"
    "555 410 0423<|im_end|>"
)
170
+
171
+
172
def _build_demo():
    """Assemble the Gradio interface: one text input, two outputs."""
    # Pre-filled with the example conversation; a bare utterance such as
    # "can you help me order some pizza" can be typed instead.
    text_input = gr.Textbox(
        label="Input Text",
        value=convo_text,
        lines=2,
    )
    highlighted = gr.HighlightedText(
        label="EoT Predictions",
        color_map="coolwarm",
        scale=1.5,
    )
    raw_scores = gr.HTML(label="Raw scores")

    return gr.Interface(
        fn=generate_highlighted_text,
        theme="soft",
        inputs=text_input,
        outputs=[highlighted, raw_scores],
        title="Turn Detector Debugger",
        description=(
            "Visualize predicted turn endings. The coloring is based on log-odds, centered on the threshold.\n"
            " Red means agent should reply; Blue means agent should wait"
        ),
        allow_flagging="never",
    )


demo = _build_demo()

demo.launch()