# app.py
import spaces
import gradio as gr
from gradio import update
from functools import lru_cache
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from opencc import OpenCC  # Simplified-to-Traditional Chinese conversion
from math import gcd
import torch

# Initialize the Simplified-to-Traditional converter
cc = OpenCC('s2t')
tokenizer = None  # set by get_pipeline(); read by clean_suggestions()
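# e.g. cc.convert("简体中文") -> "簡體中文"; the 's2t' profile converts
# characters only, without regional vocabulary substitution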
# Available models
MODEL_LIST = [
    "liswei/Taiwan-ELM-270M",
    "Mxode/SmolLM-Chinese-180M",
    "flyingfishinwater/chinese-baby-llama2",
    "unsloth/gemma-3-1b-pt",
    "ckiplab/gpt2-tiny-chinese",
    "ckiplab/gpt2-base-chinese",
    "liswei/Taiwan-ELM-1_1B",
    "benchang1110/Qwen2.5-Taiwan-1.5B-Instruct",
    "benchang1110/Taiwan-tinyllama-v1.0-base",
    "lianghsun/Llama-3.2-Taiwan-3B",
    "twinkle-ai/Llama-3.2-3B-F1-Instruct",
    "Epiculous/Violet_Twilight-v0.2",
]
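# These range from roughly 180M to 3B parameters; the smaller checkpoints
# generally load and respond faster under ZeroGPU, at the cost of weaker
# suggestions.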
def clean_suggestions(suggestions: list[str], max_levels: int) -> list[str]:
    """
    Clean the suggestion list:
    1. Tokenize each suggestion with tokenizer.tokenize.
    2. Insert every token sequence into a prefix trie.
    3. Walk the trie and, at every node of depth <= max_levels that still has
       children, extract the corresponding token prefix.
    4. Convert those token prefixes back to text, deduplicate, and return them.
    """
    # Trie node structure
    class TrieNode:
        __slots__ = ("children", "count")

        def __init__(self):
            self.children: dict[str, TrieNode] = {}
            self.count: int = 0  # number of sequences passing through this node

    # Build the prefix trie
    root = TrieNode()
    for text in suggestions:
        # tokenizer.tokenize may return a list of subword tokens
        try:
            toks = tokenizer.tokenize(text)
        except Exception:
            # If the tokenizer cannot handle raw text directly, fall back to
            # basic whitespace splitting
            toks = text.split()
        if not toks:
            continue
        node = root
        node.count += 1
        for tok in toks:
            if tok not in node.children:
                node.children[tok] = TrieNode()
            node = node.children[tok]
            node.count += 1

    # Walk the trie and collect prefix sequences of depth <= max_levels whose
    # node still has children
    results_prefix_tokens: set[tuple[str, ...]] = set()

    def dfs(node: TrieNode, path: list[str], depth: int):
        # node: current TrieNode; path: tokens walked so far; depth: len(path)
        if depth > max_levels:
            return
        # A node with children at depth > 0 (excluding the root itself) is a
        # candidate prefix
        if depth > 0 and node.children:
            results_prefix_tokens.add(tuple(path))
        # Keep descending until depth == max_levels
        if depth == max_levels:
            return
        for tok, child in node.children.items():
            path.append(tok)
            dfs(child, path, depth + 1)
            path.pop()

    dfs(root, [], 0)

    # Convert the token prefixes back to strings
    cleaned: set[str] = set()
    for tok_prefix in results_prefix_tokens:
        try:
            # convert_tokens_to_string is supported by most tokenizers
            text_pref = tokenizer.convert_tokens_to_string(list(tok_prefix)).strip()
        except Exception:
            # Fallback: join tokens directly (spacing rules vary by tokenizer)
            text_pref = "".join(tok_prefix).strip()
        if text_pref:
            cleaned.add(text_pref)

    # Return the deduplicated list
    return list(cleaned)
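# A quick illustration (hypothetical per-character tokenization): for
# suggestions ["天氣很好", "天氣不錯"] with max_levels=2, both token sequences
# pass through the trie nodes for "天" and "天" -> "氣", and both nodes still
# have children, so the two full candidates collapse into their shared stems
# (roughly "天" and "天氣", modulo tokenizer-specific detokenization spacing).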
@lru_cache(maxsize=8)
def get_pipeline(model_name):
    global tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    mdl = AutoModelForCausalLM.from_pretrained(
        model_name, weights_only=False, trust_remote_code=True
    )
    try:
        mdl.to("cuda")
    except Exception as e:
        print(f"Error moving model to CUDA: {e}")
    # Fall back to CPU (device=-1) when CUDA is unavailable instead of
    # hard-coding GPU device 0
    device = 0 if torch.cuda.is_available() else -1
    return pipeline("text-generation", model=mdl, tokenizer=tokenizer, device=device)
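# Note: lru_cache(maxsize=8) keeps up to eight loaded pipelines resident, so
# switching between models in the UI pays the download/load cost only once
# per model until an entry is evicted. Each cached pipeline holds full model
# weights, so memory use grows with the number of distinct models tried.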
@spaces.GPU
def suggest_next(text, model_name, k, m, num_beam_groups, diversity_penalty, max_prefix_levels=2):
    """
    Generate m candidates with Diverse Beam Search:
    - num_beams = m
    - num_beam_groups and diversity_penalty tune the diversity
    Then convert to Traditional Chinese, deduplicate, merge common prefixes,
    and return the result.
    """
    global tokenizer
    gen_pipe = get_pipeline(model_name)
    # On a cache hit get_pipeline does not rerun, so re-sync the module-level
    # tokenizer with the pipeline actually in use
    tokenizer = gen_pipe.tokenizer
    # Build the generate kwargs; add the diversity options only when the
    # penalty is positive
    gen_kwargs = {
        "max_new_tokens": k,
        "num_beams": m,
        "num_return_sequences": m,
        "do_sample": False,
        "early_stopping": True,
    }
    if diversity_penalty and diversity_penalty > 0:
        # num_beams must be divisible by num_beam_groups, which the gcd
        # guarantees; a single group cannot be diverse, so fall back to plain
        # beam search when the gcd is 1
        valid_group = gcd(m, num_beam_groups)
        if valid_group > 1:
            gen_kwargs["num_beam_groups"] = valid_group
            gen_kwargs["diversity_penalty"] = float(diversity_penalty)
    outs = gen_pipe(text, **gen_kwargs)
    # Keep only the continuation, drop empty strings, convert to Traditional
    # Chinese, and strip whitespace
    suggestions = set()
    for out in outs:
        snippet = out["generated_text"][len(text):].strip()
        if not snippet:
            continue
        converted = cc.convert(snippet).strip()
        suggestions.add(converted)
    suggestions = list(suggestions)
    suggestions = clean_suggestions(suggestions, max_prefix_levels)
    return update(choices=suggestions, value=None)
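# Worked example with the UI defaults: m=20 beams and num_beam_groups=6 give
# gcd(20, 6) = 2, i.e. two beam groups of ten beams each. When the gcd is 1
# (say m=7, groups=3), the code above runs plain beam search instead of
# raising the divisibility error that transformers would otherwise throw.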
def append_suggestion(current, choice):
    if choice is None:
        return current
    # Append the selected candidate text directly
    return current + choice
# Custom CSS: mimic a classic Chinese IME candidate bar, tuned for mobile
# responsiveness and auto-sizing
custom_css = """
#suggestions-bar {
width: 100%;
margin-bottom: 8px;
}
#suggestions-bar .candidate-list {
display: flex;
gap: 8px;
background: #fff;
border: 1px solid #999;
border-radius: 4px;
padding: 6px;
overflow-x: auto;
white-space: nowrap;
}
#suggestions-bar .candidate-list label {
cursor: pointer;
padding: 6px 10px;
font-size: 16px;
}
#suggestions-bar .candidate-list label:hover {
background: #f5f5f5;
}
#suggestions-bar .candidate-list input[type=radio]:checked + label {
background: #e6f7ff;
border: 1px solid #1890ff;
}
#input-box textarea {
width: 100%;
font-size: 16px;
padding: 6px;
box-sizing: border-box;
overflow: hidden;
resize: none;
}
#predict-button {
margin-top: 8px;
width: 100%;
}
/* 手機響應式 */
@media only screen and (max-width: 600px) {
#suggestions-bar .candidate-list label {
padding: 8px;
font-size: 18px;
}
#predict-button {
font-size: 18px;
}
}
"""
# Auto-grow script for the input textarea (note: some Gradio versions
# sanitize gr.HTML content, so the inline <script> may not execute; an
# alternative sketch follows the string below)
auto_height_js = """
<script>
window.addEventListener('load', () => {
    const textarea = document.querySelector('#input-box textarea');
    if (!textarea) return;
    textarea.style.height = 'auto';
    textarea.addEventListener('input', function() {
        this.style.height = 'auto';
        this.style.height = this.scrollHeight + 'px';
    });
});
</script>
"""
with gr.Blocks(css=custom_css) as demo:
    gr.HTML(auto_height_js)
    gr.Markdown(
        "## 🇹🇼 繁體中文 IME 加速器\n"  # "Traditional Chinese IME accelerator"
        "結合小型語言模型與 ZeroGPU,提供即時輸入法風格候選欄。"
        # "Combines small language models with ZeroGPU for a real-time,
        #  IME-style candidate bar."
    )
    with gr.Column():
        suggestions = gr.Radio(
            [], label="", interactive=True, type="value",
            elem_id="suggestions-bar", elem_classes="candidate-list"
        )
        input_text = gr.Textbox(
            label="", placeholder="請輸入拼音或文字…",  # "Enter pinyin or text…"
            lines=1, max_lines=20, elem_id="input-box"
        )
        # Always show the predict button
        with gr.Row():
            auto_predict = gr.Checkbox(
                value=True, label="自動預測(內容變更時觸發)",  # "Auto-predict (fires on change)"
                elem_id="auto-predict"
            )
            predict_button = gr.Button(
                "預測", elem_id="predict-button"  # "Predict"
            )
        with gr.Accordion("進階設定", open=False):  # "Advanced settings"
            model_selector = gr.Dropdown(
                MODEL_LIST, value=MODEL_LIST[0], label="模型"  # "Model"
            )
            k_slider = gr.Slider(
                minimum=1, maximum=50, step=1, value=10, label="K(最大新詞元數)"  # "K (max new tokens)"
            )
            m_slider = gr.Slider(
                minimum=1, maximum=30, step=1, value=20, label="M(建議數/Beam 數)"  # "M (suggestions / beams)"
            )
            group_slider = gr.Slider(
                minimum=1, maximum=30, step=1, value=6,
                label="Beam 群組數 (num_beam_groups)"
            )
            diversity_penalty_slider = gr.Slider(
                minimum=0.0, maximum=2.0, step=0.1, value=1.0,
                label="多樣性懲罰 (diversity_penalty)"
            )
            prefix_levels_slider = gr.Slider(
                minimum=1, maximum=5, step=1, value=2,
                label="Clean 前綴深度 (max_levels)"
            )

    # Wire up events
    predict_button.click(
        fn=suggest_next,
        inputs=[
            input_text,
            model_selector,
            k_slider,
            m_slider,
            group_slider,
            diversity_penalty_slider,
            prefix_levels_slider,  # newly added prefix-depth control
        ],
        outputs=suggestions,
    )
    input_text.change(
        fn=lambda txt, mdl, k, m, g, d, auto, pl: (
            suggest_next(txt, mdl, k, m, g, d, pl)
            if auto else update(choices=[], value=None)
        ),
        inputs=[
            input_text,
            model_selector,
            k_slider,
            m_slider,
            group_slider,
            diversity_penalty_slider,
            auto_predict,
            prefix_levels_slider,  # newly added prefix-depth control
        ],
        outputs=suggestions,
    )
    suggestions.change(
        fn=append_suggestion,
        inputs=[input_text, suggestions],
        outputs=input_text,
    )

demo.launch()
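# To try this outside Hugging Face Spaces: install the dependencies (e.g.
# gradio, spaces, torch, transformers, and an OpenCC binding such as
# opencc-python-reimplemented) and run `python app.py`. The @spaces.GPU
# decorator should be a no-op outside a ZeroGPU environment, so generation
# falls back to whatever local device is available.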