Luigi committed on
Commit 05d7c3d · 1 Parent(s): a94f020

Revert "add suggestions cleanning"

This reverts commit a94f020eba50b5aec3746c327d42f6b8c6344ac1.

Files changed (1)
  1. app.py +8 -88
app.py CHANGED
@@ -10,7 +10,6 @@ from termcolor import cprint
 
  # Initialize the Simplified-to-Traditional Chinese converter
  cc = OpenCC('s2t')
- tokenizer = None
 
  # List of selectable models
  MODEL_LIST = [
@@ -28,82 +27,10 @@ MODEL_LIST = [
      "Epiculous/Violet_Twilight-v0.2",
  ]
 
- def clean_suggestions(suggestions: list[str], max_levels: int) -> list[str]:
-     """
-     Clean the suggestion list:
-     1. Tokenize each suggestion with tokenizer.tokenize to get its token sequence.
-     2. Build a prefix tree (trie) and insert every token sequence.
-     3. Traverse the trie and, only where the depth <= max_levels and the node has children, extract the corresponding token prefix.
-     4. Convert these token prefixes back to text, deduplicate, and return them as a list.
-     """
-     # Define the trie node structure
-     class TrieNode:
-         __slots__ = ("children", "count")
-         def __init__(self):
-             self.children: dict[str, TrieNode] = {}
-             self.count: int = 0  # optionally records how many sequences pass through this node
-
-     # Build the trie
-     root = TrieNode()
-     token_seqs: list[list[str]] = []
-
-     for text in suggestions:
-         # tokenizer.tokenize may return a list of subword tokens
-         try:
-             toks = tokenizer.tokenize(text)
-         except Exception:
-             # If the tokenizer cannot tokenize raw text directly, fall back to basic tokenization, e.g. splitting on whitespace
-             toks = text.split()
-         if not toks:
-             continue
-         token_seqs.append(toks)
-         node = root
-         node.count += 1
-         for tok in toks:
-             if tok not in node.children:
-                 node.children[tok] = TrieNode()
-             node = node.children[tok]
-             node.count += 1
-
-     # Traverse the trie and collect prefix sequences with depth <= max_levels that still have children
-     results_prefix_tokens: set[tuple[str, ...]] = set()
-
-     def dfs(node: TrieNode, path: list[str], depth: int):
-         # node: current TrieNode; path: tokens traversed so far; depth: len(path)
-         if depth > max_levels:
-             return
-         # If the current node has children and depth > 0 (excluding the root itself), record a candidate prefix
-         if depth > 0 and node.children:
-             results_prefix_tokens.add(tuple(path))
-         # Keep descending until depth == max_levels
-         if depth == max_levels:
-             return
-         for tok, child in node.children.items():
-             path.append(tok)
-             dfs(child, path, depth + 1)
-             path.pop()
-
-     dfs(root, [], 0)
-
-     # Convert the token prefixes back to strings
-     cleaned: set[str] = set()
-     for tok_prefix in results_prefix_tokens:
-         try:
-             # tokenizer.convert_tokens_to_string is supported by most tokenizers
-             text_pref = tokenizer.convert_tokens_to_string(list(tok_prefix)).strip()
-         except Exception:
-             # Fallback: join the tokens directly (spaces or plain concatenation may be needed depending on the tokenizer)
-             text_pref = "".join(tok_prefix).strip()
-         if text_pref:
-             cleaned.add(text_pref)
-
-     # Return the deduplicated list
-     return list(cleaned)
 
  @lru_cache(maxsize=8)
  def get_pipeline(model_name):
-     global tokenizer
-     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     tok = AutoTokenizer.from_pretrained(model_name)
      mdl = AutoModelForCausalLM.from_pretrained(
          model_name, weights_only=False, trust_remote_code=True
      )
@@ -111,10 +38,10 @@ def get_pipeline(model_name):
          mdl.to("cuda")
      except Exception as e:
          print(f'Error: {e}')
-     return pipeline("text-generation", model=mdl, tokenizer=tokenizer, device=0)
+     return pipeline("text-generation", model=mdl, tokenizer=tok, device=0)
 
  @spaces.GPU
- def suggest_next(text, model_name, k, m, num_beam_groups, diversity_penalty, max_prefix_levels=2):
+ def suggest_next(text, model_name, k, m, num_beam_groups, diversity_penalty):
      """
      Use Diverse Beam Search to generate m candidates:
      - num_beams = m
@@ -131,7 +58,7 @@ def suggest_next(text, model_name, k, m, num_beam_groups, diversity_penalty, max_prefix_levels=2):
          "early_stopping": True,
      }
      if diversity_penalty and diversity_penalty > 0:
-         valid_group = max(gcd(m, num_beam_groups),2)
+         valid_group = gcd(m, num_beam_groups)
          gen_kwargs["num_beam_groups"] = valid_group
          gen_kwargs["diversity_penalty"] = float(diversity_penalty)
 
@@ -146,7 +73,6 @@ def suggest_next(text, model_name, k, m, num_beam_groups, diversity_penalty, max_prefix_levels=2):
          converted = cc.convert(snippet).strip()
          suggestions.add(converted)
      suggestions = list(suggestions)
-     suggestions = clean_suggestions(suggestions, max_prefix_levels)
 
      return update(choices=suggestions, value=None)
 
@@ -269,10 +195,6 @@ with gr.Blocks(css=custom_css) as demo:
          minimum=0.0, maximum=2.0, step=0.1, value=1.0,
          label="多樣性懲罰 (diversity_penalty)"
      )
-     prefix_levels_slider = gr.Slider(
-         minimum=1, maximum=5, step=1, value=2,
-         label="Clean 前綴深度 (max_levels)"
-     )
 
      # Bind UI events
      predict_button.click(
@@ -283,14 +205,13 @@ with gr.Blocks(css=custom_css) as demo:
              k_slider,
              m_slider,
              group_slider,
-             diversity_penalty_slider,
-             prefix_levels_slider  # newly added
+             diversity_penalty_slider
          ],
          outputs=suggestions,
      )
      input_text.change(
-         fn=lambda txt, mdl, k, m, g, d, auto, pl: (
-             suggest_next(txt, mdl, k, m, g, d, pl)
+         fn=lambda txt, mdl, k, m, g, d, auto: (
+             suggest_next(txt, mdl, k, m, g, d)
              if auto else update(choices=[], value=None)
          ),
          inputs=[
@@ -300,8 +221,7 @@ with gr.Blocks(css=custom_css) as demo:
              m_slider,
              group_slider,
              diversity_penalty_slider,
-             auto_predict,
-             prefix_levels_slider  # newly added
+             auto_predict
          ],
          outputs=suggestions,
      )
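
Note on the restored valid_group = gcd(m, num_beam_groups) line: Hugging Face transformers only runs diverse beam search when num_beams is divisible by num_beam_groups, and taking the gcd of the two slider values is one way to guarantee that. The sketch below illustrates the constraint with standalone generate() kwargs; the model name, prompt, and numeric values are placeholders chosen for illustration and are not taken from app.py.

from math import gcd
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder model and prompt, chosen only to keep the sketch small.
tok = AutoTokenizer.from_pretrained("gpt2")
mdl = AutoModelForCausalLM.from_pretrained("gpt2")

m = 6                  # number of candidates requested (num_beams)
num_beam_groups = 4    # group count requested by the slider
diversity_penalty = 1.0

gen_kwargs = {
    "max_new_tokens": 8,
    "num_beams": m,
    "num_return_sequences": m,
    "early_stopping": True,
}
if diversity_penalty and diversity_penalty > 0:
    # transformers requires num_beams % num_beam_groups == 0, so clamp the
    # group count to a common divisor: gcd(6, 4) == 2 -> 2 groups of 3 beams.
    # If the gcd is 1, diversity_penalty is ignored and this degenerates to
    # ordinary beam search.
    gen_kwargs["num_beam_groups"] = gcd(m, num_beam_groups)
    gen_kwargs["diversity_penalty"] = float(diversity_penalty)

inputs = tok("Once upon a time", return_tensors="pt")
for seq in mdl.generate(**inputs, **gen_kwargs):
    print(tok.decode(seq, skip_special_tokens=True))

The gcd equals the requested group count whenever that count divides the beam count; otherwise it falls back to the largest group count that still splits the beams evenly.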
 