Update app.py
app.py
CHANGED
@@ -33,15 +33,38 @@ except ImportError as e:
     print(f"Error importing the libraries: {str(e)}")
     sys.exit(1)

-# List of small Hebrew-capable or multilingual models
+# List of small Hebrew-capable or multilingual models, with recommended settings
 MODELS = {
     "facebook/opt-125m": "Small general-purpose model (125M) - fast",
-    "onlplab/alephbert-base": "Hebrew BERT model",
-    "avichr/heBERT": "Another Hebrew BERT model",
-    "google/mt5-small": "Multilingual T5 model that supports Hebrew",
-    "xlm-roberta-base": "Multilingual model that supports Hebrew",
-    "google/flan-t5-small": "Small instruction-tuned model covering many languages",
-    "distilgpt2": "Small, fast GPT-2 model"
+    "onlplab/alephbert-base": "Hebrew BERT model - suited to word completion",
+    "avichr/heBERT": "Another Hebrew BERT model - suited to word completion",
+    "google/mt5-small": "Multilingual T5 model that supports Hebrew - suited to translation and questions",
+    "xlm-roberta-base": "Multilingual model that supports Hebrew - suited to word completion",
+    "google/flan-t5-small": "Small instruction-tuned model covering many languages - good for Q&A",
+    "distilgpt2": "Small, fast GPT-2 model - good for text generation"
+}
+
+# Dictionary of optimal settings per model type
+MODEL_CONFIGS = {
+    "bert": {  # for mask filling
+        "top_k": 5
+    },
+    "t5": {  # for T5 and MT5 models
+        "max_length": 150,
+        "do_sample": True,
+        "temperature": 0.6,
+        "top_p": 0.92,
+        "repetition_penalty": 1.2,
+        "num_beams": 4
+    },
+    "default": {  # for OPT, GPT and similar models
+        "max_new_tokens": 150,
+        "do_sample": True,
+        "temperature": 0.7,
+        "top_p": 0.92,
+        "repetition_penalty": 1.1,
+        "no_repeat_ngram_size": 2
+    }
 }

 # Connect to the Hugging Face account
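This hunk defines MODEL_CONFIGS, but the generation hunks below still pass the same parameters inline. A minimal sketch of how the dictionary could be consumed - the helper get_generation_config is an illustration, not part of this commit:

# Hypothetical helper (not in this commit): pick the MODEL_CONFIGS entry for a model name.
def get_generation_config(model_name: str) -> dict:
    name = model_name.lower()
    if "bert" in name or "roberta" in name:  # fill-mask models
        return MODEL_CONFIGS["bert"]
    if "t5" in name or "mt5" in name:        # seq2seq models ("flan-t5" matches "t5" too)
        return MODEL_CONFIGS["t5"]
    return MODEL_CONFIGS["default"]          # causal LMs such as OPT and GPT-2

# Usage: outputs = generator(prompt, **get_generation_config(current_model_name))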
@@ -62,6 +85,15 @@ generator = None
 def load_model(model_name, status_box=None):
     """Load a model and update the status"""
     global generator, current_model_name
+
+    # Release the previous model's resources, if one is loaded
+    if generator is not None:
+        import gc
+        del generator
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
     current_model_name = model_name

     if status_box is not None:
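The added cleanup follows the usual PyTorch teardown order: drop the last reference to the pipeline, let the garbage collector reclaim cycles that may still pin tensors, then release cached GPU blocks. The same pattern as a self-contained sketch (assuming torch is imported at module level, as the torch.cuda calls in the diff imply):

import gc
import torch

def reclaim_memory():
    """Reclaim memory after the last reference to a model has been dropped."""
    gc.collect()                  # collect cycles that may still hold tensors
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # return cached GPU memory to the driver

# Usage mirrors the diff: clear the reference first, then reclaim.
# generator = None
# reclaim_memory()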
@@ -129,28 +161,32 @@ def ask_model(prompt):
         outputs = generator(prompt, top_k=5)
         return "\n".join([f"{item['token_str']} (confidence: {item['score']:.4f})" for item in outputs])

-    elif "t5" in current_model_name.lower() or "mt5" in current_model_name.lower() or "flan" in current_model_name.lower():
+    elif "t5" in current_model_name.lower() or "mt5" in current_model_name.lower() or "flan-t5" in current_model_name.lower():
         # Handling T5 models
         outputs = generator(
             prompt,
-            max_length=
-            do_sample=True,
-            temperature=0.
-            top_p=0.
+            max_length=150,          # maximum output length
+            do_sample=True,          # random sampling instead of greedy decoding
+            temperature=0.6,         # lower temperature for more precise answers
+            top_p=0.92,              # nucleus sampling - keeps variety in the answers
+            repetition_penalty=1.2,  # discourages repetition
+            num_beams=4              # beam search for higher-quality results
         )
         if isinstance(outputs, list) and len(outputs) > 0:
             return outputs[0]["generated_text"]
         else:
             return str(outputs)
     else:
-        # Handling regular models
+        # Handling regular models (OPT, GPT)
         outputs = generator(
             prompt,
-            max_new_tokens=
-            do_sample=True,
-            temperature=0.7,
-            top_p=0.
-
+            max_new_tokens=150,      # maximum output length
+            do_sample=True,          # random sampling instead of greedy decoding
+            temperature=0.7,         # balance between precision and creativity
+            top_p=0.92,              # nucleus sampling - keeps variety in the answers
+            repetition_penalty=1.1,  # discourages repetition
+            no_repeat_ngram_size=2,  # blocks repeated bigrams (word pairs)
+            return_full_text=False   # return only the newly generated text
         )

         # Return the generated text
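For context, the "default" branch boils down to a standard transformers text-generation call. A runnable sketch with the same parameters - the model and prompt here are placeholders, not taken from the commit:

from transformers import pipeline

# distilgpt2 stands in for whichever causal LM the app has loaded.
generator = pipeline("text-generation", model="distilgpt2")
outputs = generator(
    "Hello, world",
    max_new_tokens=150,       # cap on newly generated tokens
    do_sample=True,           # sample instead of greedy decoding
    temperature=0.7,
    top_p=0.92,
    repetition_penalty=1.1,
    no_repeat_ngram_size=2,
    return_full_text=False,   # omit the prompt from the output
)
print(outputs[0]["generated_text"])

Note that in the T5 branch, do_sample=True combined with num_beams=4 makes transformers use beam-search multinomial sampling rather than deterministic beam search.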