Revert "add suggestions cleanning"
This reverts commit a94f020eba50b5aec3746c327d42f6b8c6344ac1.
app.py
CHANGED
@@ -10,7 +10,6 @@ from termcolor import cprint
 
 # Initialize the Simplified-to-Traditional Chinese converter
 cc = OpenCC('s2t')
-tokenizer = None
 
 # List of selectable models
 MODEL_LIST = [
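Note on the context lines above: OpenCC('s2t') builds a Simplified-to-Traditional converter, which suggest_next later applies to every generated snippet. A minimal sketch of that conversion, assuming the opencc Python package; the sample string is illustrative, not from the repo:

    from opencc import OpenCC

    cc = OpenCC('s2t')            # Simplified -> Traditional config
    print(cc.convert('建议列表'))  # -> '建議列表' (sample input)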
@@ -28,82 +27,10 @@ MODEL_LIST = [
     "Epiculous/Violet_Twilight-v0.2",
 ]
 
-def clean_suggestions(suggestions: list[str], max_levels: int) -> list[str]:
-    """
-    Clean the suggestion list:
-    1. Tokenize each suggestion with tokenizer.tokenize to get a token sequence.
-    2. Build a prefix tree and insert every token sequence.
-    3. Walk the trie; only at depth <= max_levels, and only where a node has children, extract the corresponding token prefix.
-    4. Convert these token prefixes back to text, deduplicate, and return the list.
-    """
-    # Trie node structure
-    class TrieNode:
-        __slots__ = ("children", "count")
-        def __init__(self):
-            self.children: dict[str, TrieNode] = {}
-            self.count: int = 0  # optionally tracks how many sequences pass through this node
-
-    # Build the prefix tree
-    root = TrieNode()
-    token_seqs: list[list[str]] = []
-
-    for text in suggestions:
-        # tokenizer.tokenize may return a list of subword tokens
-        try:
-            toks = tokenizer.tokenize(text)
-        except Exception:
-            # If the tokenizer cannot tokenize raw text directly, fall back to basic tokenization such as whitespace splitting
-            toks = text.split()
-        if not toks:
-            continue
-        token_seqs.append(toks)
-        node = root
-        node.count += 1
-        for tok in toks:
-            if tok not in node.children:
-                node.children[tok] = TrieNode()
-            node = node.children[tok]
-            node.count += 1
-
-    # Walk the trie and collect prefix sequences at depth <= max_levels that have children
-    results_prefix_tokens: set[tuple[str, ...]] = set()
-
-    def dfs(node: TrieNode, path: list[str], depth: int):
-        # node: current TrieNode; path: tokens visited so far; depth: len(path)
-        if depth > max_levels:
-            return
-        # If the current node has children and depth > 0 (excluding the root itself), it is a candidate prefix
-        if depth > 0 and node.children:
-            results_prefix_tokens.add(tuple(path))
-        # Keep descending until depth == max_levels
-        if depth == max_levels:
-            return
-        for tok, child in node.children.items():
-            path.append(tok)
-            dfs(child, path, depth + 1)
-            path.pop()
-
-    dfs(root, [], 0)
-
-    # Convert the token prefixes back to strings
-    cleaned: set[str] = set()
-    for tok_prefix in results_prefix_tokens:
-        try:
-            # tokenizer.convert_tokens_to_string is supported by most tokenizers
-            text_pref = tokenizer.convert_tokens_to_string(list(tok_prefix)).strip()
-        except Exception:
-            # Fallback: join the tokens directly (may need spaces or direct concatenation depending on the tokenizer)
-            text_pref = "".join(tok_prefix).strip()
-        if text_pref:
-            cleaned.add(text_pref)
-
-    # Return the deduplicated list
-    return list(cleaned)
 
 @lru_cache(maxsize=8)
 def get_pipeline(model_name):
-
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    tok = AutoTokenizer.from_pretrained(model_name)
     mdl = AutoModelForCausalLM.from_pretrained(
         model_name, weights_only=False, trust_remote_code=True
     )
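The removed clean_suggestions helper kept every proper token prefix up to max_levels deep (trie nodes that still have children), converted back to text and deduplicated. A minimal standalone sketch of the same trie idea, using plain whitespace tokens in place of a Hugging Face tokenizer; the sample suggestions are hypothetical:

    from collections import defaultdict

    def shared_prefixes(suggestions: list[str], max_levels: int) -> list[str]:
        # A recursive defaultdict works as a lightweight trie node.
        def new_node():
            return defaultdict(new_node)

        root = new_node()
        for text in suggestions:
            node = root
            for tok in text.split():          # stand-in for tokenizer.tokenize
                node = node[tok]

        # Collect prefixes of depth <= max_levels whose node still has children.
        found: set[str] = set()
        def dfs(node, path):
            if path and node:                 # non-root node with children
                found.add(" ".join(path))     # stand-in for convert_tokens_to_string
            if len(path) == max_levels:
                return
            for tok, child in node.items():
                dfs(child, path + [tok])

        dfs(root, [])
        return sorted(found)

    print(shared_prefixes(["the cat sat", "the cat ran", "the dog"], 2))
    # -> ['the', 'the cat']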
@@ -111,10 +38,10 @@ def get_pipeline(model_name):
         mdl.to("cuda")
     except Exception as e:
         print(f'Error: {e}')
-    return pipeline("text-generation", model=mdl, tokenizer=tokenizer, device=0)
+    return pipeline("text-generation", model=mdl, tokenizer=tok, device=0)
 
 @spaces.GPU
-def suggest_next(text, model_name, k, m, num_beam_groups, diversity_penalty, max_prefix_levels):
+def suggest_next(text, model_name, k, m, num_beam_groups, diversity_penalty):
     """
     Use Diverse Beam Search to produce m candidates:
     - num_beams = m
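The docstring that begins here maps directly onto transformers generate kwargs. A minimal sketch of diverse beam search through the pipeline API, with gpt2 as a placeholder model (the Space draws models from MODEL_LIST instead):

    from transformers import pipeline

    gen = pipeline("text-generation", model="gpt2")  # placeholder model

    outputs = gen(
        "The meaning of life is",
        max_new_tokens=8,
        num_beams=4,              # m candidate beams
        num_return_sequences=4,   # return all m beams
        num_beam_groups=2,        # must divide num_beams evenly
        diversity_penalty=1.0,    # push groups toward different tokens
        do_sample=False,          # diverse beam search requires no sampling
    )
    for out in outputs:
        print(out["generated_text"])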
@@ -131,7 +58,7 @@ def suggest_next(text, model_name, k, m, num_beam_groups, diversity_penalty, max
         "early_stopping": True,
     }
     if diversity_penalty and diversity_penalty > 0:
-        valid_group =
+        valid_group = gcd(m, num_beam_groups)
         gen_kwargs["num_beam_groups"] = valid_group
         gen_kwargs["diversity_penalty"] = float(diversity_penalty)
 
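On the gcd(m, num_beam_groups) line restored above: transformers rejects generation configs where num_beams is not divisible by num_beam_groups, and the two sliders move independently, so the code snaps the group count to gcd(m, num_beam_groups), which always divides m. A quick illustration:

    from math import gcd

    m, num_beam_groups = 6, 4      # slider values that would be rejected as-is
    valid_group = gcd(m, num_beam_groups)
    assert m % valid_group == 0    # the gcd is a divisor of m by definition
    print(valid_group)             # -> 2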
@@ -146,7 +73,6 @@ def suggest_next(text, model_name, k, m, num_beam_groups, diversity_penalty, max
         converted = cc.convert(snippet).strip()
         suggestions.add(converted)
     suggestions = list(suggestions)
-    suggestions = clean_suggestions(suggestions, max_prefix_levels)
 
     return update(choices=suggestions, value=None)
 
@@ -269,10 +195,6 @@ with gr.Blocks(css=custom_css) as demo:
                 minimum=0.0, maximum=2.0, step=0.1, value=1.0,
                 label="Diversity penalty (diversity_penalty)"
             )
-            prefix_levels_slider = gr.Slider(
-                minimum=1, maximum=5, step=1, value=2,
-                label="Clean prefix depth (max_levels)"
-            )
 
     # Bind events
     predict_button.click(
@@ -283,14 +205,13 @@ with gr.Blocks(css=custom_css) as demo:
             k_slider,
             m_slider,
             group_slider,
-            diversity_penalty_slider,
-            prefix_levels_slider  # newly added
+            diversity_penalty_slider
         ],
         outputs=suggestions,
     )
     input_text.change(
-        fn=lambda txt, mdl, k, m, g, d, auto, p: (
-            suggest_next(txt, mdl, k, m, g, d, p)
+        fn=lambda txt, mdl, k, m, g, d, auto: (
+            suggest_next(txt, mdl, k, m, g, d)
             if auto else update(choices=[], value=None)
         ),
         inputs=[
@@ -300,8 +221,7 @@ with gr.Blocks(css=custom_css) as demo:
             m_slider,
             group_slider,
             diversity_penalty_slider,
-            auto_predict,
-            prefix_levels_slider  # newly added
+            auto_predict
         ],
         outputs=suggestions,
     )
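The restored input_text.change wiring shows the gating pattern this Space uses: one lambda fans the component values into suggest_next and short-circuits to an empty choice list when auto-predict is off. A minimal self-contained sketch of the same pattern, with illustrative component names and a stub handler rather than the Space's own:

    import gradio as gr

    def suggest(text):
        # Stand-in for suggest_next: return fake suggestions.
        return gr.update(choices=[text + "!", text + "?"], value=None)

    with gr.Blocks() as demo:
        auto_predict = gr.Checkbox(label="Auto predict", value=True)
        input_text = gr.Textbox(label="Input")
        suggestions = gr.Radio(label="Suggestions", choices=[])
        input_text.change(
            fn=lambda txt, auto: suggest(txt) if auto else gr.update(choices=[], value=None),
            inputs=[input_text, auto_predict],
            outputs=suggestions,
        )

    demo.launch()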