# NOTE: Hugging Face Spaces page scrape residue removed here
# (page chrome "Running on Zero", file size, git blame hashes, line-number gutter).
# app.py
import spaces
import gradio as gr
from gradio import update
from functools import lru_cache
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from opencc import OpenCC # 用於簡體轉繁體
# Simplified-to-Traditional Chinese converter (OpenCC "s2t" conversion profile).
cc = OpenCC('s2t')
# Candidate causal-LM checkpoints selectable from the UI dropdown.
# Mostly small Chinese / Taiwanese models suitable for fast next-text suggestion.
MODEL_LIST = [
    "liswei/Taiwan-ELM-270M",
    "Mxode/SmolLM-Chinese-180M",
    "flyingfishinwater/chinese-baby-llama2",
    "unsloth/gemma-3-1b-pt",
    "ckiplab/gpt2-tiny-chinese",
    "ckiplab/gpt2-base-chinese",
    "liswei/Taiwan-ELM-1_1B",
    "benchang1110/Qwen2.5-Taiwan-1.5B-Instruct",
    "benchang1110/Taiwan-tinyllama-v1.0-base",
    "lianghsun/Llama-3.2-Taiwan-3B",
    "twinkle-ai/Llama-3.2-3B-F1-Instruct",
    "Epiculous/Violet_Twilight-v0.2",
]
def merge_common_prefixes(suggestions, min_len=2):
    """Collapse suggestions that share a common prefix of at least *min_len* chars.

    For every pair of suggestions whose longest common prefix has length
    >= min_len, the prefix replaces both originals.  Suggestions not merged
    into any prefix are kept, after the prefixes, in their original order.

    BUG FIX: the previous implementation joined *all* position-wise matching
    characters (e.g. "abcd"/"abed" -> "abd"), which is not a prefix of either
    string.  We now stop at the first mismatch, yielding a true common prefix.

    Args:
        suggestions: list of candidate strings.
        min_len: minimum prefix length required to merge a pair.

    Returns:
        list of unique common prefixes followed by unmerged suggestions.
    """
    prefixes = []
    to_remove = set()
    for i in range(len(suggestions)):
        for j in range(i + 1, len(suggestions)):
            s1, s2 = suggestions[i], suggestions[j]
            # Length of the true common prefix: advance until first mismatch.
            k = 0
            for c1, c2 in zip(s1, s2):
                if c1 != c2:
                    break
                k += 1
            if k >= min_len:
                prefixes.append(s1[:k])
                to_remove.update([s1, s2])
    # Order-preserving de-duplication of the collected prefixes.
    unique_prefixes = list(dict.fromkeys(prefixes))
    remainder = [s for s in suggestions if s not in to_remove]
    return unique_prefixes + remainder
@lru_cache(maxsize=8)
def get_pipeline(model_name):
    """Build (and memoize, up to 8 entries) a GPU text-generation pipeline.

    The tokenizer and model are loaded from the Hub for *model_name*; the
    model is moved to CUDA and wrapped in a transformers pipeline on device 0.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # NOTE(review): trust_remote_code=True runs repo-provided code at load
    # time — confirm every entry in MODEL_LIST is a trusted source.
    model = AutoModelForCausalLM.from_pretrained(
        model_name, weights_only=False, trust_remote_code=True
    )
    model.to("cuda")
    return pipeline(
        "text-generation", model=model, tokenizer=tokenizer, device=0
    )
@spaces.GPU
def suggest_next(text, model_name, k, m, num_beam_groups, diversity_penalty):
    """Generate up to *m* continuation suggestions for *text* via beam search.

    Args:
        text: current input text (the prompt).
        model_name: Hub id of the model to use (cached by get_pipeline).
        k: max_new_tokens per suggestion.
        m: number of beams and of returned sequences.
        num_beam_groups: beam groups for diverse beam search.
        diversity_penalty: if > 0, enables grouped (diverse) beam search.

    Returns:
        gradio update() setting the Radio choices (selection cleared).
    """
    gen_pipe = get_pipeline(model_name)
    gen_kwargs = {
        "max_new_tokens": k,
        "num_beams": m,
        "num_return_sequences": m,
        "do_sample": False,
        "early_stopping": True,
    }
    # Diversity args are only valid for grouped beam search; transformers
    # rejects num_beam_groups > 1 with diversity_penalty == 0.
    if diversity_penalty and diversity_penalty > 0:
        gen_kwargs["num_beam_groups"] = num_beam_groups
        gen_kwargs["diversity_penalty"] = diversity_penalty
    outs = gen_pipe(text, **gen_kwargs)
    # Strip the prompt from each output once, drop empty continuations,
    # convert Simplified -> Traditional.
    suggestions = []
    for out in outs:
        continuation = out["generated_text"][len(text):].strip()
        if continuation:
            suggestions.append(cc.convert(continuation))
    # Order-preserving de-duplication.
    unique_suggestions = list(dict.fromkeys(suggestions))
    # Merge candidates sharing a common prefix (>= 2 chars).
    final_suggestions = merge_common_prefixes(unique_suggestions, min_len=2)
    return update(choices=final_suggestions, value=None)
def append_suggestion(text, choice):
    """Append the selected suggestion to the input text.

    Robustness fix: the Radio's change event also fires when its value is
    programmatically cleared to None (suggest_next returns value=None), and
    ``text + None`` would raise TypeError — return *text* unchanged then.
    """
    if not choice:
        return text
    return text + choice
# UI layout and event wiring.  CSS styles the suggestion Radio as a
# horizontal, scrollable candidate bar (IME-style).
with gr.Blocks(css="""
#suggestions-bar { width: 100%; margin-bottom: 8px; }
#suggestions-bar .candidate-list {
    display: flex; gap: 8px; background: #fff;
    border: 1px solid #999; border-radius: 4px;
    padding: 6px; overflow-x: auto; white-space: nowrap;
}
#suggestions-bar .candidate-list label { cursor: pointer; }
""") as demo:
    with gr.Column():
        # Candidate bar: suggestions populated by suggest_next.
        suggestions = gr.Radio(
            [], label="", interactive=True, type="value",
            elem_id="suggestions-bar", elem_classes="candidate-list"
        )
        # Main text input (placeholder: "enter pinyin or text").
        input_text = gr.Textbox(
            label="", placeholder="請輸入拼音或文字…",
            lines=1, max_lines=20, elem_id="input-box"
        )
        with gr.Row():
            # Checkbox: auto-predict on every text change.
            auto_predict = gr.Checkbox(
                value=True, label="自動預測(內容變更時觸發)", elem_id="auto-predict"
            )
            # Manual "predict" trigger.
            predict_button = gr.Button("預測", elem_id="predict-button")
        # Advanced settings accordion (collapsed by default).
        with gr.Accordion("進階設定", open=False):
            model_selector = gr.Dropdown(
                MODEL_LIST, value=MODEL_LIST[0], label="模型"
            )
            # K: max new tokens per suggestion.
            k_slider = gr.Slider(
                minimum=1, maximum=50, step=1, value=10, label="K(最大新詞元數)"
            )
            # M: number of beams / returned suggestions.
            m_slider = gr.Slider(
                minimum=1, maximum=30, step=1, value=30, label="M(建議數/Beam 數)"
            )
            group_slider = gr.Slider(
                minimum=1, maximum=30, step=1, value=30,
                label="Beam 群組數 (num_beam_groups)"
            )
            diversity_penalty_slider = gr.Slider(
                minimum=0.0, maximum=2.0, step=0.1, value=1.0,
                label="多樣性懲罰 (diversity_penalty)"
            )
    # Manual prediction: button click always runs suggest_next.
    predict_button.click(
        fn=suggest_next,
        inputs=[
            input_text,
            model_selector,
            k_slider,
            m_slider,
            group_slider,
            diversity_penalty_slider
        ],
        outputs=suggestions,
    )
    # Auto prediction: on any text change, run suggest_next only when the
    # auto_predict checkbox is on; otherwise clear the candidate bar.
    input_text.change(
        fn=lambda txt, mdl, k, m, g, d, auto: (
            suggest_next(txt, mdl, k, m, g, d)
            if auto else update(choices=[], value=None)
        ),
        inputs=[
            input_text,
            model_selector,
            k_slider,
            m_slider,
            group_slider,
            diversity_penalty_slider,
            auto_predict
        ],
        outputs=suggestions,
    )
    # Selecting a candidate appends it to the input text.  NOTE(review):
    # this also fires when suggest_next clears the value to None — see
    # append_suggestion for why that matters.
    suggestions.change(
        fn=append_suggestion,
        inputs=[input_text, suggestions],
        outputs=input_text,
    )
demo.launch()