Luigi commited on
Commit
46d41f7
·
1 Parent(s): ab3ef88

add suggestion cleanning

Browse files
Files changed (1) hide show
  1. app.py +40 -119
app.py CHANGED
@@ -1,10 +1,9 @@
1
- # app.py
2
  import spaces
3
  import gradio as gr
4
  from gradio import update
5
  from functools import lru_cache
6
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
7
- from opencc import OpenCC # 用於簡體轉繁體
8
 
9
  # 初始化簡體到繁體轉換器
10
  cc = OpenCC('s2t')
@@ -26,30 +25,19 @@ MODEL_LIST = [
26
  ]
27
 
28
  def merge_common_prefixes(suggestions, min_len=2):
29
- """
30
- 合併具有共同前綴的建議:
31
- - 找出所有長度 ≥ min_len 的共同前綴
32
- - 將這些前綴作為新建議,移除原有被合併的項目
33
- """
34
  prefixes = []
35
  to_remove = set()
36
-
37
  for i in range(len(suggestions)):
38
  for j in range(i+1, len(suggestions)):
39
  s1, s2 = suggestions[i], suggestions[j]
40
- # 計算字元級共同前綴
41
  common = ''.join(c1 for c1, c2 in zip(s1, s2) if c1 == c2)
42
  if len(common) >= min_len:
43
  prefixes.append(common)
44
  to_remove.update([s1, s2])
45
-
46
- # 去重前綴
47
  unique_prefixes = []
48
  for p in prefixes:
49
  if p not in unique_prefixes:
50
  unique_prefixes.append(p)
51
-
52
- # 剩下沒有被合併的建議
53
  remainder = [s for s in suggestions if s not in to_remove]
54
  return unique_prefixes + remainder
55
 
@@ -64,14 +52,7 @@ def get_pipeline(model_name):
64
 
65
  @spaces.GPU
66
  def suggest_next(text, model_name, k, m, num_beam_groups, diversity_penalty):
67
- """
68
- 使用 Diverse Beam Search 產生 m 條候選:
69
- - num_beams = m
70
- - num_beam_groups, diversity_penalty 可調整多樣性
71
- 之後轉繁體、去重、合併共同前綴後回傳。
72
- """
73
  gen_pipe = get_pipeline(model_name)
74
- # 構造 generate 參數字典,僅在 penalty>0 時加入 diversity 相關
75
  gen_kwargs = {
76
  "max_new_tokens": k,
77
  "num_beams": m,
@@ -81,110 +62,54 @@ def suggest_next(text, model_name, k, m, num_beam_groups, diversity_penalty):
81
  }
82
  if diversity_penalty and diversity_penalty > 0:
83
  gen_kwargs["num_beam_groups"] = num_beam_groups
84
- gen_kwargs["diversity_penalty"] = float(diversity_penalty)
85
 
86
  outs = gen_pipe(text, **gen_kwargs)
87
 
88
- # 提取純下文、過濾空字串、繁體化
89
- suggestions = [
90
- cc.convert(out["generated_text"][len(text):].strip())
91
- for out in outs
92
- if out["generated_text"][len(text):].strip()
93
- ]
 
 
94
 
95
- # 去重
96
  unique_suggestions = []
97
- for s in suggestions:
98
- if s not in unique_suggestions:
99
- unique_suggestions.append(s)
 
 
 
100
 
101
  # 合併共同前綴
102
- final_suggestions = merge_common_prefixes(unique_suggestions, min_len=2)
103
 
104
- return update(choices=final_suggestions, value=None)
 
 
 
 
 
 
 
105
 
 
106
 
107
- def append_suggestion(current, choice):
108
- if choice is None:
109
- return current
110
- # 直接插入選中的候選文字
111
- return current + choice
112
 
113
- # 自訂 CSS:模擬經典中文輸入法候選欄樣式,並優化手機響應與自動高度
114
- custom_css = """
115
- #suggestions-bar {
116
- width: 100%;
117
- margin-bottom: 8px;
118
- }
119
  #suggestions-bar .candidate-list {
120
- display: flex;
121
- gap: 8px;
122
- background: #fff;
123
- border: 1px solid #999;
124
- border-radius: 4px;
125
- padding: 6px;
126
- overflow-x: auto;
127
- white-space: nowrap;
128
- }
129
- #suggestions-bar .candidate-list label {
130
- cursor: pointer;
131
- padding: 6px 10px;
132
- font-size: 16px;
133
- }
134
- #suggestions-bar .candidate-list label:hover {
135
- background: #f5f5f5;
136
- }
137
- #suggestions-bar .candidate-list input[type=radio]:checked + label {
138
- background: #e6f7ff;
139
- border: 1px solid #1890ff;
140
- }
141
- #input-box textarea {
142
- width: 100%;
143
- font-size: 16px;
144
- padding: 6px;
145
- box-sizing: border-box;
146
- overflow: hidden;
147
- resize: none;
148
  }
149
- #predict-button {
150
- margin-top: 8px;
151
- width: 100%;
152
- }
153
- /* 手機響應式 */
154
- @media only screen and (max-width: 600px) {
155
- #suggestions-bar .candidate-list label {
156
- padding: 8px;
157
- font-size: 18px;
158
- }
159
- #predict-button {
160
- font-size: 18px;
161
- }
162
- }
163
- """
164
-
165
- # 自動增高腳本
166
- auto_height_js = """
167
- <script>
168
- window.addEventListener('load', () => {
169
- const textarea = document.querySelector('#input-box textarea');
170
- if (!textarea) return;
171
- textarea.style.height = 'auto';
172
- textarea.addEventListener('input', function() {
173
- this.style.height = 'auto';
174
- this.style.height = this.scrollHeight + 'px';
175
- });
176
- });
177
- </script>
178
- """
179
-
180
- with gr.Blocks(css=custom_css) as demo:
181
- gr.HTML(auto_height_js)
182
- gr.Markdown(
183
- "## 🇹🇼 繁體中文 IME 加速器 \
184
- "
185
- "結合小型語言模型與 ZeroGPU,提供即時輸入法風格候選欄。"
186
- )
187
-
188
  with gr.Column():
189
  suggestions = gr.Radio(
190
  [], label="", interactive=True, type="value",
@@ -195,14 +120,11 @@ with gr.Blocks(css=custom_css) as demo:
195
  lines=1, max_lines=20, elem_id="input-box"
196
  )
197
 
198
- # 永遠顯示預測按鈕
199
  with gr.Row():
200
  auto_predict = gr.Checkbox(
201
  value=True, label="自動預測(內容變更時觸發)", elem_id="auto-predict"
202
  )
203
- predict_button = gr.Button(
204
- "預測", elem_id="predict-button"
205
- )
206
 
207
  with gr.Accordion("進階設定", open=False):
208
  model_selector = gr.Dropdown(
@@ -215,15 +137,14 @@ with gr.Blocks(css=custom_css) as demo:
215
  minimum=1, maximum=30, step=1, value=30, label="M(建議數/Beam 數)"
216
  )
217
  group_slider = gr.Slider(
218
- minimum=1, maximum=30, step=1, value=30,
219
  label="Beam 群組數 (num_beam_groups)"
220
  )
221
  diversity_penalty_slider = gr.Slider(
222
- minimum=0.0, maximum=2.0, step=0.1, value=1.0,
223
  label="多樣性懲罰 (diversity_penalty)"
224
  )
225
 
226
- # 綁定事件
227
  predict_button.click(
228
  fn=suggest_next,
229
  inputs=[
@@ -258,4 +179,4 @@ with gr.Blocks(css=custom_css) as demo:
258
  outputs=input_text,
259
  )
260
 
261
- demo.launch()
 
 
1
  import spaces
2
  import gradio as gr
3
  from gradio import update
4
  from functools import lru_cache
5
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
6
+ from opencc import OpenCC # 用於簡體到繁體轉換
7
 
8
  # 初始化簡體到繁體轉換器
9
  cc = OpenCC('s2t')
 
25
  ]
26
 
27
  def merge_common_prefixes(suggestions, min_len=2):
 
 
 
 
 
28
  prefixes = []
29
  to_remove = set()
 
30
  for i in range(len(suggestions)):
31
  for j in range(i+1, len(suggestions)):
32
  s1, s2 = suggestions[i], suggestions[j]
 
33
  common = ''.join(c1 for c1, c2 in zip(s1, s2) if c1 == c2)
34
  if len(common) >= min_len:
35
  prefixes.append(common)
36
  to_remove.update([s1, s2])
 
 
37
  unique_prefixes = []
38
  for p in prefixes:
39
  if p not in unique_prefixes:
40
  unique_prefixes.append(p)
 
 
41
  remainder = [s for s in suggestions if s not in to_remove]
42
  return unique_prefixes + remainder
43
 
 
52
 
53
  @spaces.GPU
54
  def suggest_next(text, model_name, k, m, num_beam_groups, diversity_penalty):
 
 
 
 
 
 
55
  gen_pipe = get_pipeline(model_name)
 
56
  gen_kwargs = {
57
  "max_new_tokens": k,
58
  "num_beams": m,
 
62
  }
63
  if diversity_penalty and diversity_penalty > 0:
64
  gen_kwargs["num_beam_groups"] = num_beam_groups
65
+ gen_kwargs["diversity_penalty"] = diversity_penalty
66
 
67
  outs = gen_pipe(text, **gen_kwargs)
68
 
69
+ # 提取純下文、過濾空字串、繁體化、確保 strip 處理
70
+ raw_suggestions = []
71
+ for out in outs:
72
+ snippet = out["generated_text"][len(text):].strip()
73
+ if not snippet:
74
+ continue
75
+ converted = cc.convert(snippet).strip()
76
+ raw_suggestions.append(converted)
77
 
78
+ # 去重 (基於 strip 後內容)
79
  unique_suggestions = []
80
+ seen = set()
81
+ for s in raw_suggestions:
82
+ key = s
83
+ if key not in seen:
84
+ seen.add(key)
85
+ unique_suggestions.append(key)
86
 
87
  # 合併共同前綴
88
+ merged_prefixes = merge_common_prefixes(unique_suggestions, min_len=2)
89
 
90
+ # 最終去重並移除空項 (基於 strip 後內容)
91
+ final_suggestions = []
92
+ seen_final = set()
93
+ for s in merged_prefixes:
94
+ key = s.strip()
95
+ if key and key not in seen_final:
96
+ seen_final.add(key)
97
+ final_suggestions.append(key)
98
 
99
+ return update(choices=final_suggestions, value=None)
100
 
101
+ def append_suggestion(text, choice):
102
+ return text + choice
 
 
 
103
 
104
+ with gr.Blocks(css="""
105
+ #suggestions-bar { width: 100%; margin-bottom: 8px; }
 
 
 
 
106
  #suggestions-bar .candidate-list {
107
+ display: flex; gap: 8px; background: #fff;
108
+ border: 1px solid #999; border-radius: 4px;
109
+ padding: 6px; overflow-x: auto; white-space: nowrap;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  }
111
+ #suggestions-bar .candidate-list label { cursor: pointer; }
112
+ """) as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  with gr.Column():
114
  suggestions = gr.Radio(
115
  [], label="", interactive=True, type="value",
 
120
  lines=1, max_lines=20, elem_id="input-box"
121
  )
122
 
 
123
  with gr.Row():
124
  auto_predict = gr.Checkbox(
125
  value=True, label="自動預測(內容變更時觸發)", elem_id="auto-predict"
126
  )
127
+ predict_button = gr.Button("預測", elem_id="predict-button")
 
 
128
 
129
  with gr.Accordion("進階設定", open=False):
130
  model_selector = gr.Dropdown(
 
137
  minimum=1, maximum=30, step=1, value=30, label="M(建議數/Beam 數)"
138
  )
139
  group_slider = gr.Slider(
140
+ minimum=1, maximum=30, step=1, value=5,
141
  label="Beam 群組數 (num_beam_groups)"
142
  )
143
  diversity_penalty_slider = gr.Slider(
144
+ minimum=0.0, maximum=2.0, step=0.1, value=0.3,
145
  label="多樣性懲罰 (diversity_penalty)"
146
  )
147
 
 
148
  predict_button.click(
149
  fn=suggest_next,
150
  inputs=[
 
179
  outputs=input_text,
180
  )
181
 
182
+ demo.launch()