File size: 5,432 Bytes
9c6147f
dc13b6d
 
f358820
dc13b6d
 
2f31c84
 
 
 
dc13b6d
da40aec
dc13b6d
e894885
812c80c
2a9fd77
cc1322f
dc13b6d
 
 
 
 
42dc39c
 
 
dc13b6d
 
9c6147f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc13b6d
 
2f31c84
 
 
dc13b6d
 
 
 
9c6147f
da40aec
3350989
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe26a7b
 
 
 
9c6147f
 
 
 
 
72f7051
3350989
 
dc13b6d
3350989
 
44ac7e8
3350989
 
 
44ac7e8
3350989
 
109cc2f
 
 
 
 
 
 
ee21d1a
109cc2f
dc13b6d
109cc2f
 
 
 
3350989
041841e
44ac7e8
da40aec
44ac7e8
da40aec
 
a4c33c1
da40aec
 
a4c33c1
da40aec
9c6147f
 
 
 
 
 
 
 
dc13b6d
041841e
dc13b6d
9c6147f
 
 
 
 
 
 
 
dc13b6d
 
6277588
9c6147f
 
 
 
 
 
 
 
 
 
 
 
 
6277588
 
dc13b6d
 
 
 
 
 
3350989
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
# app.py
import spaces
import gradio as gr
from gradio import update
from functools import lru_cache
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from opencc import OpenCC  # 用於簡體轉繁體

# Simplified -> Traditional Chinese converter, applied to model outputs
# before they are shown as suggestions.
cc = OpenCC('s2t')

# Model ids selectable from the UI dropdown (first entry is the default).
# Mix of small Chinese/Taiwanese causal LMs loadable via transformers.
MODEL_LIST = [
    "liswei/Taiwan-ELM-270M",
    "Mxode/SmolLM-Chinese-180M",
    "flyingfishinwater/chinese-baby-llama2",
    "unsloth/gemma-3-1b-pt",
    "ckiplab/gpt2-tiny-chinese",
    "ckiplab/gpt2-base-chinese",
    "liswei/Taiwan-ELM-1_1B",
    "benchang1110/Qwen2.5-Taiwan-1.5B-Instruct",
    "benchang1110/Taiwan-tinyllama-v1.0-base",
    "lianghsun/Llama-3.2-Taiwan-3B",
    "twinkle-ai/Llama-3.2-3B-F1-Instruct",
    "Epiculous/Violet_Twilight-v0.2",
]

def merge_common_prefixes(suggestions, min_len=2):
    """Collapse pairs of suggestions that share a long common prefix.

    For every pair whose longest common prefix is at least ``min_len``
    characters, the prefix is kept and both full suggestions are dropped;
    suggestions not involved in any such pair pass through unchanged.

    Args:
        suggestions: Candidate strings, in display order.
        min_len: Minimum prefix length required to trigger a merge.

    Returns:
        list[str]: De-duplicated prefixes followed by the untouched
        remainder, preserving first-seen order.
    """
    prefixes = []
    to_remove = set()
    for i in range(len(suggestions)):
        for j in range(i + 1, len(suggestions)):
            s1, s2 = suggestions[i], suggestions[j]
            # Bug fix: a common *prefix* must stop at the first mismatch.
            # The previous code joined every positionally-equal character,
            # so characters matching *after* a mismatch were glued on
            # (e.g. "abxcd"/"abycd" incorrectly yielded "abcd").
            common_len = 0
            for c1, c2 in zip(s1, s2):
                if c1 != c2:
                    break
                common_len += 1
            common = s1[:common_len]
            if len(common) >= min_len:
                prefixes.append(common)
                to_remove.update([s1, s2])
    # De-duplicate prefixes while preserving insertion order.
    unique_prefixes = []
    for p in prefixes:
        if p not in unique_prefixes:
            unique_prefixes.append(p)
    remainder = [s for s in suggestions if s not in to_remove]
    return unique_prefixes + remainder

@lru_cache(maxsize=8)
def get_pipeline(model_name):
    """Load and cache a CUDA text-generation pipeline for *model_name*.

    The LRU cache keeps at most 8 loaded models alive, so switching
    between models in the UI does not reload weights every time.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        weights_only=False,
        trust_remote_code=True,
    )
    model.to("cuda")
    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device=0,
    )

@spaces.GPU
def suggest_next(text, model_name, k, m, num_beam_groups, diversity_penalty):
    """Generate continuation suggestions for *text* with beam search.

    Args:
        text: Current input text used as the generation prompt.
        model_name: Hugging Face model id (one of MODEL_LIST).
        k: Maximum number of new tokens per suggestion.
        m: Number of beams and returned sequences.
        num_beam_groups: Beam groups for diverse beam search.
        diversity_penalty: If > 0, diverse beam search is enabled.

    Returns:
        A gradio ``update`` setting the Radio choices to the deduplicated,
        prefix-merged, Traditional-Chinese suggestions (value reset to None).
    """
    gen_pipe = get_pipeline(model_name)
    # Only pass the diversity options when the penalty is actually enabled;
    # transformers rejects group beam search with an unset/zero penalty.
    gen_kwargs = {
        "max_new_tokens": k,
        "num_beams": m,
        "num_return_sequences": m,
        "do_sample": False,
        "early_stopping": True,
    }
    if diversity_penalty and diversity_penalty > 0:
        gen_kwargs["num_beam_groups"] = num_beam_groups
        gen_kwargs["diversity_penalty"] = diversity_penalty

    outs = gen_pipe(text, **gen_kwargs)

    # Strip the prompt, drop empty continuations, convert to Traditional
    # Chinese. Compute the continuation once per output (the original
    # sliced and stripped each generated_text twice).
    suggestions = []
    for out in outs:
        continuation = out["generated_text"][len(text):].strip()
        if continuation:
            suggestions.append(cc.convert(continuation))

    # Order-preserving dedupe (dict keys keep insertion order).
    unique_suggestions = list(dict.fromkeys(suggestions))

    # Collapse suggestion pairs sharing a common prefix.
    final_suggestions = merge_common_prefixes(unique_suggestions, min_len=2)

    return update(choices=final_suggestions, value=None)

def append_suggestion(text, choice):
    """Append the picked suggestion *choice* to *text*.

    Both prediction handlers reset the Radio to ``value=None``, which
    re-fires its ``change`` event with ``choice=None``; guard against
    that so we never evaluate ``str + None`` (TypeError).
    """
    if not choice:
        return text
    return text + choice

# --- UI layout and event wiring ---------------------------------------
# CSS renders the Radio as a horizontal, scrollable candidate bar
# (IME-style) above the input box.
with gr.Blocks(css="""
#suggestions-bar { width: 100%; margin-bottom: 8px; }
#suggestions-bar .candidate-list {
    display: flex; gap: 8px; background: #fff;
    border: 1px solid #999; border-radius: 4px;
    padding: 6px; overflow-x: auto; white-space: nowrap;
}
#suggestions-bar .candidate-list label { cursor: pointer; }
""") as demo:
    with gr.Column():
        # Candidate bar: populated by suggest_next, cleared on selection.
        suggestions = gr.Radio(
            [], label="", interactive=True, type="value",
            elem_id="suggestions-bar", elem_classes="candidate-list"
        )
        # Main text entry; placeholder says "enter pinyin or text".
        input_text = gr.Textbox(
            label="", placeholder="請輸入拼音或文字…",
            lines=1, max_lines=20, elem_id="input-box"
        )

    with gr.Row():
        # When checked, predictions refresh automatically on every change.
        auto_predict = gr.Checkbox(
            value=True, label="自動預測(內容變更時觸發)", elem_id="auto-predict"
        )
        # Manual trigger ("predict") for when auto-predict is off.
        predict_button = gr.Button("預測", elem_id="predict-button")

    # Advanced settings: model choice and generation hyper-parameters.
    with gr.Accordion("進階設定", open=False):
        model_selector = gr.Dropdown(
            MODEL_LIST, value=MODEL_LIST[0], label="模型"
        )
        # K: max new tokens per suggestion.
        k_slider = gr.Slider(
            minimum=1, maximum=50, step=1, value=10, label="K(最大新詞元數)"
        )
        # M: number of beams / returned suggestions.
        m_slider = gr.Slider(
            minimum=1, maximum=30, step=1, value=30, label="M(建議數/Beam 數)"
        )
        group_slider = gr.Slider(
            minimum=1, maximum=30, step=1, value=30,
            label="Beam 群組數 (num_beam_groups)"
        )
        diversity_penalty_slider = gr.Slider(
            minimum=0.0, maximum=2.0, step=0.1, value=1.0,
            label="多樣性懲罰 (diversity_penalty)"
        )

    # Manual prediction: button click runs suggest_next once.
    predict_button.click(
        fn=suggest_next,
        inputs=[
            input_text,
            model_selector,
            k_slider,
            m_slider,
            group_slider,
            diversity_penalty_slider
        ],
        outputs=suggestions,
    )
    # Auto prediction: every text change either predicts (auto on) or
    # clears the candidate bar (auto off).
    input_text.change(
        fn=lambda txt, mdl, k, m, g, d, auto: (
            suggest_next(txt, mdl, k, m, g, d)
            if auto else update(choices=[], value=None)
        ),
        inputs=[
            input_text,
            model_selector,
            k_slider,
            m_slider,
            group_slider,
            diversity_penalty_slider,
            auto_predict
        ],
        outputs=suggestions,
    )
    # Selecting a candidate appends it to the input text.
    suggestions.change(
        fn=append_suggestion,
        inputs=[input_text, suggestions],
        outputs=input_text,
    )

    # NOTE(review): launch() is called *inside* the Blocks context; the
    # conventional placement is after the `with` block — confirm this is
    # intentional for the Spaces deployment.
    demo.launch()