File size: 8,238 Bytes
6d98a7c
dc13b6d
 
f358820
dc13b6d
 
6d98a7c
05df099
8b84de4
2f31c84
 
 
dc13b6d
da40aec
dc13b6d
e894885
812c80c
2a9fd77
cc1322f
dc13b6d
 
 
 
 
42dc39c
 
 
dc13b6d
 
9c6147f
6d98a7c
 
 
 
 
9c6147f
 
6d98a7c
9c6147f
 
 
6d98a7c
9c6147f
 
 
 
6d98a7c
 
9c6147f
 
 
 
6d98a7c
 
9c6147f
 
 
 
dc13b6d
 
2f31c84
 
 
6eb5ae5
 
 
 
dc13b6d
 
 
9c6147f
6d98a7c
 
 
 
 
 
da40aec
6d98a7c
3350989
 
 
 
 
 
 
 
c6bc247
05df099
6ec4350
3350989
 
 
216d8ce
 
 
 
 
 
 
 
3350989
216d8ce
fe26a7b
216d8ce
 
 
 
 
 
9c6147f
 
6c195b9
b237d33
6c195b9
216d8ce
 
 
 
 
 
 
 
 
72f7051
46d41f7
dc13b6d
3b87f8e
6d98a7c
 
 
 
 
 
 
 
 
 
 
 
44ac7e8
6d98a7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44ac7e8
6d98a7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109cc2f
 
 
 
 
 
 
ee21d1a
109cc2f
dc13b6d
6d98a7c
109cc2f
 
 
 
6d98a7c
 
 
041841e
44ac7e8
da40aec
44ac7e8
da40aec
 
a4c33c1
da40aec
 
ab41c65
da40aec
9c6147f
ab41c65
9c6147f
 
 
6d98a7c
9c6147f
 
dc13b6d
6d98a7c
041841e
dc13b6d
9c6147f
 
 
 
 
 
 
 
dc13b6d
 
6277588
9c6147f
 
 
 
 
 
 
 
 
 
 
 
 
6277588
 
dc13b6d
 
 
 
 
 
6d98a7c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
# app.py
from functools import lru_cache
from math import gcd

import gradio as gr
import spaces
import torch
from gradio import update
from opencc import OpenCC  # 用於簡體轉繁體
from termcolor import cprint
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Simplified→Traditional Chinese converter (OpenCC "s2t" profile),
# shared by all requests.
cc = OpenCC('s2t')

# Selectable model checkpoints (Hugging Face Hub ids) offered in the UI
# dropdown; mostly small Chinese / Taiwanese-Mandarin causal LMs.
MODEL_LIST = [
    "liswei/Taiwan-ELM-270M",
    "Mxode/SmolLM-Chinese-180M",
    "flyingfishinwater/chinese-baby-llama2",
    "unsloth/gemma-3-1b-pt",
    "ckiplab/gpt2-tiny-chinese",
    "ckiplab/gpt2-base-chinese",
    "liswei/Taiwan-ELM-1_1B",
    "benchang1110/Qwen2.5-Taiwan-1.5B-Instruct",
    "benchang1110/Taiwan-tinyllama-v1.0-base",
    "lianghsun/Llama-3.2-Taiwan-3B",
    "twinkle-ai/Llama-3.2-3B-F1-Instruct",
    "Epiculous/Violet_Twilight-v0.2",
]

def merge_common_prefixes(suggestions, min_len=2):
    """
    Merge suggestions that share a common prefix.

    For every pair of suggestions whose character-level common prefix is
    at least ``min_len`` characters long, that prefix is promoted to a
    suggestion of its own and both originals are dropped.

    Args:
        suggestions: list of candidate strings.
        min_len: minimum prefix length required to merge (default 2).

    Returns:
        Deduplicated prefixes (in discovery order) followed by the
        suggestions that were not merged into any prefix.
    """
    prefixes = []
    to_remove = set()

    for i in range(len(suggestions)):
        for j in range(i + 1, len(suggestions)):
            s1, s2 = suggestions[i], suggestions[j]
            # True common prefix: stop at the first mismatching character.
            # The previous version kept every position where the characters
            # happened to match, even *after* a mismatch ("abcX"/"abdX"
            # yielded "abX"), producing strings that were not prefixes.
            common_chars = []
            for c1, c2 in zip(s1, s2):
                if c1 != c2:
                    break
                common_chars.append(c1)
            common = ''.join(common_chars)
            if len(common) >= min_len:
                prefixes.append(common)
                to_remove.update([s1, s2])

    # Deduplicate prefixes while preserving first-seen order.
    unique_prefixes = []
    for p in prefixes:
        if p not in unique_prefixes:
            unique_prefixes.append(p)

    # Suggestions that were never merged stay as-is.
    remainder = [s for s in suggestions if s not in to_remove]
    return unique_prefixes + remainder

@lru_cache(maxsize=8)
def get_pipeline(model_name):
    """
    Build and memoize a text-generation pipeline for *model_name*.

    The lru_cache keeps up to 8 loaded models so repeated requests with
    the same model do not reload weights.  The device is chosen once and
    used for both the model and the pipeline: the original always passed
    ``device=0`` to the pipeline, which crashes on CPU-only hosts even
    though the failed ``.to("cuda")`` was caught.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    mdl = AutoModelForCausalLM.from_pretrained(
        model_name, weights_only=False, trust_remote_code=True
    )
    # 0 = first CUDA device, -1 = CPU (transformers pipeline convention).
    device = 0 if torch.cuda.is_available() else -1
    if device == 0:
        try:
            mdl.to("cuda")
        except Exception as e:
            print(f'Error: {e}')
            device = -1  # fall back to CPU consistently
    return pipeline("text-generation", model=mdl, tokenizer=tok, device=device)

@spaces.GPU
def suggest_next(text, model_name, k, m, num_beam_groups, diversity_penalty):
    """
    Generate up to *m* candidate continuations of *text* with beam search.

    - ``num_beams = num_return_sequences = m``
    - ``num_beam_groups`` / ``diversity_penalty`` enable Diverse Beam
      Search when the penalty is positive.

    The raw continuations are converted to Traditional Chinese,
    de-duplicated, merged by common prefix, and returned as the new
    ``choices`` for the suggestion Radio component.
    """
    gen_pipe = get_pipeline(model_name)
    # Build generate() kwargs; diversity settings are only added when a
    # positive penalty is requested.
    gen_kwargs = {
        "max_new_tokens": k,
        "num_beams": m,
        "num_return_sequences": m,
        "do_sample": False,
        "early_stopping": True,
    }
    if diversity_penalty and diversity_penalty > 0:
        # num_beam_groups must evenly divide num_beams; gcd guarantees that.
        valid_group = gcd(m, num_beam_groups)
        # transformers rejects diversity_penalty together with a single
        # beam group, so only enable group beam search for >1 groups.
        if valid_group > 1:
            gen_kwargs["num_beam_groups"] = valid_group
            gen_kwargs["diversity_penalty"] = float(diversity_penalty)

    outs = gen_pipe(text, **gen_kwargs)

    # Strip the prompt, drop empty continuations, convert to Traditional
    # Chinese, and strip again (conversion may alter surrounding spaces).
    raw_suggestions = []
    for out in outs:
        snippet = out["generated_text"][len(text):].strip()
        if not snippet:
            continue
        raw_suggestions.append(cc.convert(snippet).strip())

    # De-duplicate while preserving generation order.
    unique_suggestions = []
    seen = set()
    for s in raw_suggestions:
        if s not in seen:
            seen.add(s)
            unique_suggestions.append(s)

    # Merge candidates that share a common prefix (min_len=1: any shared
    # leading character is enough to collapse a pair).
    cprint(f'unique_suggestions: {unique_suggestions}','blue')
    merged_prefixes = merge_common_prefixes(unique_suggestions, min_len=1)
    cprint(f"merged_prefixes: {merged_prefixes}",'red')

    # Final pass: strip, drop empties, de-duplicate once more.
    final_suggestions = []
    seen_final = set()
    for s in merged_prefixes:
        key = s.strip()
        if key and key not in seen_final:
            seen_final.add(key)
            final_suggestions.append(key)

    return update(choices=final_suggestions, value=None)


def append_suggestion(current, choice):
    """Append the selected candidate to the current input text.

    Returns *current* unchanged when no candidate is selected.
    """
    if choice is None:
        return current
    # Concatenate the chosen candidate directly onto the existing text.
    return f"{current}{choice}"

# Custom CSS: mimics a classic Chinese IME candidate bar, with mobile
# responsiveness and auto-height tweaks for the input textarea.
custom_css = """
#suggestions-bar {
    width: 100%;
    margin-bottom: 8px;
}
#suggestions-bar .candidate-list {
    display: flex;
    gap: 8px;
    background: #fff;
    border: 1px solid #999;
    border-radius: 4px;
    padding: 6px;
    overflow-x: auto;
    white-space: nowrap;
}
#suggestions-bar .candidate-list label {
    cursor: pointer;
    padding: 6px 10px;
    font-size: 16px;
}
#suggestions-bar .candidate-list label:hover {
    background: #f5f5f5;
}
#suggestions-bar .candidate-list input[type=radio]:checked + label {
    background: #e6f7ff;
    border: 1px solid #1890ff;
}
#input-box textarea {
    width: 100%;
    font-size: 16px;
    padding: 6px;
    box-sizing: border-box;
    overflow: hidden;
    resize: none;
}
#predict-button {
    margin-top: 8px;
    width: 100%;
}
/* 手機響應式 */
@media only screen and (max-width: 600px) {
    #suggestions-bar .candidate-list label {
        padding: 8px;
        font-size: 18px;
    }
    #predict-button {
        font-size: 18px;
    }
}
"""

# Auto-grow script: resizes the input textarea to fit its content on
# every input event (injected into the page via gr.HTML).
auto_height_js = """
<script>
  window.addEventListener('load', () => {
    const textarea = document.querySelector('#input-box textarea');
    if (!textarea) return;
    textarea.style.height = 'auto';
    textarea.addEventListener('input', function() {
      this.style.height = 'auto';
      this.style.height = this.scrollHeight + 'px';
    });
  });
</script>
"""

with gr.Blocks(css=custom_css) as demo:
    gr.HTML(auto_height_js)
    gr.Markdown(
        # "  \n" is a hard Markdown line break.  The original ended this
        # literal with a backslash-newline continuation, which deletes the
        # newline and merged the subtitle into the H2 heading line.
        "## 🇹🇼 繁體中文 IME 加速器  \n"
        "結合小型語言模型與 ZeroGPU,提供即時輸入法風格候選欄。"
    )

    with gr.Column():
        # IME-style candidate bar rendered above the input box.
        suggestions = gr.Radio(
            [], label="", interactive=True, type="value",
            elem_id="suggestions-bar", elem_classes="candidate-list"
        )
        input_text = gr.Textbox(
            label="", placeholder="請輸入拼音或文字…",
            lines=1, max_lines=20, elem_id="input-box"
        )

    # Always-visible prediction controls.
    with gr.Row():
        auto_predict = gr.Checkbox(
            value=True, label="自動預測(內容變更時觸發)", elem_id="auto-predict"
        )
        predict_button = gr.Button(
            "預測", elem_id="predict-button"
        )

    with gr.Accordion("進階設定", open=False):
        model_selector = gr.Dropdown(
            MODEL_LIST, value=MODEL_LIST[0], label="模型"
        )
        k_slider = gr.Slider(
            minimum=1, maximum=50, step=1, value=10, label="K(最大新詞元數)"
        )
        m_slider = gr.Slider(
            minimum=1, maximum=30, step=1, value=20, label="M(建議數/Beam 數)"
        )
        group_slider = gr.Slider(
            minimum=1, maximum=30, step=1, value=6,
            label="Beam 群組數 (num_beam_groups)"
        )
        diversity_penalty_slider = gr.Slider(
            minimum=0.0, maximum=2.0, step=0.1, value=1.0,
            label="多樣性懲罰 (diversity_penalty)"
        )

    # Event wiring: explicit button click always predicts ...
    predict_button.click(
        fn=suggest_next,
        inputs=[
            input_text,
            model_selector,
            k_slider,
            m_slider,
            group_slider,
            diversity_penalty_slider
        ],
        outputs=suggestions,
    )
    # ... while text changes predict only when auto-predict is enabled,
    # otherwise the candidate bar is cleared.
    input_text.change(
        fn=lambda txt, mdl, k, m, g, d, auto: (
            suggest_next(txt, mdl, k, m, g, d)
            if auto else update(choices=[], value=None)
        ),
        inputs=[
            input_text,
            model_selector,
            k_slider,
            m_slider,
            group_slider,
            diversity_penalty_slider,
            auto_predict
        ],
        outputs=suggestions,
    )
    # Selecting a candidate appends it to the input text.
    suggestions.change(
        fn=append_suggestion,
        inputs=[input_text, suggestions],
        outputs=input_text,
    )

# Launch after the Blocks context closes so the app graph is fully built.
demo.launch()