File size: 9,801 Bytes
6d98a7c
dc13b6d
 
f358820
dc13b6d
 
6d98a7c
05df099
8b84de4
2f31c84
 
 
a94f020
dc13b6d
da40aec
dc13b6d
e894885
812c80c
2a9fd77
cc1322f
dc13b6d
 
 
 
 
42dc39c
 
 
dc13b6d
 
a94f020
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c6147f
 
dc13b6d
a94f020
 
2f31c84
 
 
6eb5ae5
 
 
 
a94f020
dc13b6d
 
a94f020
6d98a7c
 
 
 
 
 
da40aec
6d98a7c
3350989
 
 
 
 
 
 
 
a94f020
05df099
6ec4350
3350989
 
 
216d8ce
4a676fe
216d8ce
 
 
 
 
4a676fe
 
a94f020
3350989
4a676fe
dc13b6d
3b87f8e
6d98a7c
 
 
 
 
 
 
 
 
 
 
 
44ac7e8
6d98a7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44ac7e8
6d98a7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109cc2f
 
 
 
 
 
 
ee21d1a
109cc2f
dc13b6d
6d98a7c
109cc2f
 
 
 
6d98a7c
 
 
041841e
44ac7e8
da40aec
44ac7e8
da40aec
 
a4c33c1
da40aec
 
ab41c65
da40aec
9c6147f
ab41c65
9c6147f
 
 
6d98a7c
9c6147f
 
a94f020
 
 
 
dc13b6d
6d98a7c
041841e
dc13b6d
9c6147f
 
 
 
 
 
a94f020
 
9c6147f
dc13b6d
 
6277588
a94f020
 
9c6147f
 
 
 
 
 
 
 
 
a94f020
 
9c6147f
6277588
 
dc13b6d
 
 
 
 
 
6d98a7c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
# app.py
import spaces
import gradio as gr
from gradio import update
from functools import lru_cache
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from opencc import OpenCC  # 用於簡體轉繁體
from math import gcd
from termcolor import cprint

# Simplified-to-Traditional Chinese converter ('s2t' config); suggest_next()
# uses it to convert generated continuations before showing them.
cc = OpenCC('s2t')
# Module-level tokenizer shared between functions: get_pipeline() assigns it,
# clean_suggestions() reads it. None until a pipeline has been loaded.
tokenizer = None

# Hugging Face model IDs offered in the "模型" dropdown; the first entry is the
# dropdown's default value.
MODEL_LIST = [
    "liswei/Taiwan-ELM-270M",
    "Mxode/SmolLM-Chinese-180M",
    "flyingfishinwater/chinese-baby-llama2",
    "unsloth/gemma-3-1b-pt",
    "ckiplab/gpt2-tiny-chinese",
    "ckiplab/gpt2-base-chinese",
    "liswei/Taiwan-ELM-1_1B",
    "benchang1110/Qwen2.5-Taiwan-1.5B-Instruct",
    "benchang1110/Taiwan-tinyllama-v1.0-base",
    "lianghsun/Llama-3.2-Taiwan-3B",
    "twinkle-ai/Llama-3.2-3B-F1-Instruct",
    "Epiculous/Violet_Twilight-v0.2",
]

def clean_suggestions(suggestions: list[str], max_levels: int) -> list[str]:
    """
    Collapse generated suggestions into short, de-duplicated token prefixes.

    Steps:
    1. Tokenize every suggestion with the module-level ``tokenizer``; if that
       fails (e.g. no tokenizer has been loaded yet) fall back to splitting
       on whitespace.
    2. Insert every token sequence into a prefix tree (trie), counting how
       many suggestions pass through each node.
    3. Collect every token prefix of length 1..max_levels. A suggestion that
       is shorter than ``max_levels`` tokens is therefore kept in full rather
       than dropped.
    4. Convert each prefix back to text, drop empty strings and strings that
       contain U+FFFD (a byte-level BPE prefix that cut a multi-byte character
       in half), and de-duplicate.

    Args:
        suggestions: candidate continuations to collapse.
        max_levels: maximum prefix length in tokens; values < 1 yield [].

    Returns:
        Unique prefix strings, ordered by how many suggestions share them
        (most first), then shorter first, then lexicographically, so the
        candidate order is stable from run to run.
    """

    class TrieNode:
        __slots__ = ("children", "count")

        def __init__(self):
            self.children: dict[str, TrieNode] = {}
            self.count: int = 0  # number of suggestions passing through this node

    max_levels = int(max_levels)
    root = TrieNode()
    used_split_fallback = False

    # Build the trie.
    for text in suggestions:
        try:
            toks = tokenizer.tokenize(text)
        except Exception:
            # No usable tokenizer: fall back to whitespace tokenization.
            toks = text.split()
            used_split_fallback = True
        if not toks:
            continue
        node = root
        node.count += 1
        for tok in toks:
            child = node.children.get(tok)
            if child is None:
                child = node.children[tok] = TrieNode()
            node = child
            node.count += 1

    def to_text(tokens: list[str]) -> str:
        """Turn a token prefix back into stripped text."""
        try:
            return tokenizer.convert_tokens_to_string(list(tokens)).strip()
        except Exception:
            # Undo the whitespace-split fallback with spaces; otherwise the
            # tokens are concatenated directly.
            joiner = " " if used_split_fallback else ""
            return joiner.join(tokens).strip()

    # Prefix text -> largest number of suggestions sharing it (several token
    # prefixes can decode to the same text).
    best_count: dict[str, int] = {}

    def collect(node: TrieNode, path: list[str]) -> None:
        depth = len(path)
        if depth > 0:
            prefix_text = to_text(path)
            # U+FFFD means the token prefix ended in the middle of a character.
            if prefix_text and "\ufffd" not in prefix_text:
                best_count[prefix_text] = max(best_count.get(prefix_text, 0), node.count)
        if depth >= max_levels:
            return
        for tok, child in node.children.items():
            path.append(tok)
            collect(child, path)
            path.pop()

    collect(root, [])
    return sorted(best_count, key=lambda s: (-best_count[s], len(s), s))

@lru_cache(maxsize=8)
def _load_pipeline(model_name):
    """
    Load ``model_name`` and wrap it in a text-generation pipeline.

    Results are cached by ``lru_cache`` (up to 8 models). The model is moved
    to CUDA when possible. If that fails, the pipeline is built on CPU
    (``device=-1``) instead of requesting ``device=0`` for a model that could
    not be moved there.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    mdl = AutoModelForCausalLM.from_pretrained(
        model_name, weights_only=False, trust_remote_code=True
    )
    device = 0
    try:
        mdl.to("cuda")
    except Exception as e:
        print(f'Error: {e}')
        device = -1  # transformers pipelines treat -1 as CPU
    return pipeline("text-generation", model=mdl, tokenizer=tok, device=device)


def get_pipeline(model_name):
    """
    Return the (cached) text-generation pipeline for ``model_name``.

    Also points the module-level ``tokenizer`` at this pipeline's tokenizer.
    This is done here, outside the cached loader, so that a cache hit (e.g.
    switching back to a previously used model) still updates ``tokenizer``.
    clean_suggestions() then tokenizes with the same vocabulary the model
    used.
    """
    global tokenizer
    gen_pipe = _load_pipeline(model_name)
    tokenizer = gen_pipe.tokenizer
    return gen_pipe

def _diverse_beam_groups(num_beams: int, requested: int) -> int:
    """
    Pick a ``num_beam_groups`` value that group beam search will accept.

    Group beam search needs at least 2 groups, and ``num_beams`` must be
    divisible by the group count. Prefer ``gcd(num_beams, requested)``; if
    that is 1, fall back to the smallest divisor of ``num_beams`` that is
    >= 2 (``num_beams`` itself when it is prime).

    Returns 0 when no valid grouping exists (``num_beams < 2``).
    """
    if num_beams < 2:
        return 0
    groups = gcd(num_beams, max(requested, 1))
    if groups >= 2:
        return groups
    return next(d for d in range(2, num_beams + 1) if num_beams % d == 0)


@spaces.GPU
def suggest_next(text, model_name, k, m, num_beam_groups, diversity_penalty, max_prefix_levels=2):
    """
    Generate next-text candidates for ``text`` using (diverse) beam search.

    Runs ``m`` beams and returns ``m`` sequences of at most ``k`` new tokens.
    When ``diversity_penalty > 0``, group beam search is enabled with a group
    count derived from ``num_beam_groups`` that is valid for ``m`` beams. The
    continuations are converted to Traditional Chinese, de-duplicated, and
    collapsed to prefixes of at most ``max_prefix_levels`` tokens.

    Returns:
        A Gradio update for the suggestions radio: new choices, no selection.
    """
    # Empty prompt (e.g. the textbox was cleared): nothing to continue.
    if not text:
        return update(choices=[], value=None)

    # Slider values may arrive as floats; generation parameters must be ints.
    k = int(k)
    m = int(m)
    num_beam_groups = int(num_beam_groups)
    max_prefix_levels = int(max_prefix_levels)

    gen_pipe = get_pipeline(model_name)
    gen_kwargs = {
        "max_new_tokens": k,
        "num_beams": m,
        "num_return_sequences": m,
        "do_sample": False,
        "early_stopping": True,
    }
    # Enable group beam search only when a diversity penalty is requested
    # and a valid grouping exists for m beams.
    if diversity_penalty and diversity_penalty > 0:
        groups = _diverse_beam_groups(m, num_beam_groups)
        if groups:
            gen_kwargs["num_beam_groups"] = groups
            gen_kwargs["diversity_penalty"] = float(diversity_penalty)

    outs = gen_pipe(text, **gen_kwargs)

    # Keep only the continuation after the prompt (the slice assumes
    # generated_text starts with the prompt), drop empty snippets, convert
    # to Traditional Chinese, and de-duplicate.
    suggestions = set()
    for out in outs:
        snippet = out["generated_text"][len(text):].strip()
        if not snippet:
            continue
        suggestions.add(cc.convert(snippet).strip())

    cleaned = clean_suggestions(list(suggestions), max_prefix_levels)
    return update(choices=cleaned, value=None)


def append_suggestion(current, choice):
    """
    Return ``current`` with the selected candidate ``choice`` appended.

    A ``None`` choice (no selection, e.g. after the candidate list is reset)
    leaves the text unchanged.
    """
    return current if choice is None else current + choice

# Custom CSS: styles the suggestion radio group like a classic Chinese IME
# candidate bar (a horizontal, scrollable row of candidates), sizes the input
# textarea, and enlarges touch targets on screens up to 600px wide.
# NOTE(review): the gr.Radio below gets elem_id="suggestions-bar" AND
# elem_classes="candidate-list". If Gradio renders both on the same wrapper
# element, the descendant selector `#suggestions-bar .candidate-list` matches
# nothing (it would need `#suggestions-bar.candidate-list`). Likewise,
# `input[type=radio]:checked + label` assumes the label is the input's next
# sibling. Confirm both against the rendered DOM.
custom_css = """
#suggestions-bar {
    width: 100%;
    margin-bottom: 8px;
}
#suggestions-bar .candidate-list {
    display: flex;
    gap: 8px;
    background: #fff;
    border: 1px solid #999;
    border-radius: 4px;
    padding: 6px;
    overflow-x: auto;
    white-space: nowrap;
}
#suggestions-bar .candidate-list label {
    cursor: pointer;
    padding: 6px 10px;
    font-size: 16px;
}
#suggestions-bar .candidate-list label:hover {
    background: #f5f5f5;
}
#suggestions-bar .candidate-list input[type=radio]:checked + label {
    background: #e6f7ff;
    border: 1px solid #1890ff;
}
#input-box textarea {
    width: 100%;
    font-size: 16px;
    padding: 6px;
    box-sizing: border-box;
    overflow: hidden;
    resize: none;
}
#predict-button {
    margin-top: 8px;
    width: 100%;
}
/* 手機響應式 */
@media only screen and (max-width: 600px) {
    #suggestions-bar .candidate-list label {
        padding: 8px;
        font-size: 18px;
    }
    #predict-button {
        font-size: 18px;
    }
}
"""

# Auto-grow script: on window load, finds the textarea inside #input-box and,
# on every input event, resets its height and sets it to its scrollHeight.
# NOTE(review): this is injected through gr.HTML. Browsers do not execute
# <script> elements inserted via innerHTML, and the textarea may not exist yet
# when `load` fires. Verify that the script actually runs.
auto_height_js = """
<script>
  window.addEventListener('load', () => {
    const textarea = document.querySelector('#input-box textarea');
    if (!textarea) return;
    textarea.style.height = 'auto';
    textarea.addEventListener('input', function() {
      this.style.height = 'auto';
      this.style.height = this.scrollHeight + 'px';
    });
  });
</script>
"""

with gr.Blocks(css=custom_css) as demo:
    # NOTE(review): <script> tags inserted through gr.HTML are usually not
    # executed by the browser; confirm the auto-height script actually runs.
    gr.HTML(auto_height_js)
    # The title line must end with a real "\n". A backslash-newline inside
    # the literal is only a line continuation, which merged the description
    # into the "##" heading.
    gr.Markdown(
        "## 🇹🇼 繁體中文 IME 加速器  \n"
        "結合小型語言模型與 ZeroGPU,提供即時輸入法風格候選欄。"
    )

    # Candidate bar (choices filled by suggest_next) above the input box.
    with gr.Column():
        suggestions = gr.Radio(
            [], label="", interactive=True, type="value",
            elem_id="suggestions-bar", elem_classes="candidate-list"
        )
        input_text = gr.Textbox(
            label="", placeholder="請輸入拼音或文字…",
            lines=1, max_lines=20, elem_id="input-box"
        )

    # The predict button is always shown; the checkbox toggles auto-predict.
    with gr.Row():
        auto_predict = gr.Checkbox(
            value=True, label="自動預測(內容變更時觸發)", elem_id="auto-predict"
        )
        predict_button = gr.Button(
            "預測", elem_id="predict-button"
        )

    # Advanced generation settings.
    with gr.Accordion("進階設定", open=False):
        model_selector = gr.Dropdown(
            MODEL_LIST, value=MODEL_LIST[0], label="模型"
        )
        k_slider = gr.Slider(
            minimum=1, maximum=50, step=1, value=10, label="K(最大新詞元數)"
        )
        m_slider = gr.Slider(
            minimum=1, maximum=30, step=1, value=20, label="M(建議數/Beam 數)"
        )
        group_slider = gr.Slider(
            minimum=1, maximum=30, step=1, value=6,
            label="Beam 群組數 (num_beam_groups)"
        )
        diversity_penalty_slider = gr.Slider(
            minimum=0.0, maximum=2.0, step=0.1, value=1.0,
            label="多樣性懲罰 (diversity_penalty)"
        )
        prefix_levels_slider = gr.Slider(
            minimum=1, maximum=5, step=1, value=2,
            label="Clean 前綴深度 (max_levels)"
        )

    # Event wiring.
    # Manual prediction via the button.
    predict_button.click(
        fn=suggest_next,
        inputs=[
            input_text,
            model_selector,
            k_slider,
            m_slider,
            group_slider,
            diversity_penalty_slider,
            prefix_levels_slider,  # -> max_prefix_levels
        ],
        outputs=suggestions,
    )
    # Auto prediction on every text change. When auto-predict is off, the
    # candidate bar is cleared instead.
    input_text.change(
        fn=lambda txt, mdl, k, m, g, d, auto, pl: (
            suggest_next(txt, mdl, k, m, g, d, pl)
            if auto else update(choices=[], value=None)
        ),
        inputs=[
            input_text,
            model_selector,
            k_slider,
            m_slider,
            group_slider,
            diversity_penalty_slider,
            auto_predict,
            prefix_levels_slider,  # -> max_prefix_levels
        ],
        outputs=suggestions,
    )
    # Selecting a candidate appends it to the input text.
    suggestions.change(
        fn=append_suggestion,
        inputs=[input_text, suggestions],
        outputs=input_text,
    )

# Launch after the Blocks context has closed, so the layout and event wiring
# are complete first.
demo.launch()