dror201031 commited on
Commit
f9ce600
verified
1 Parent(s): 2d6369f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -18
app.py CHANGED
@@ -33,15 +33,38 @@ except ImportError as e:
33
  print(f"שגיאה ביבוא הספריות: {str(e)}")
34
  sys.exit(1)
35
 
36
- # רשימת מודלים תומכי עברית או רב-לשוניים קטנים
37
  MODELS = {
38
  "facebook/opt-125m": "מודל קטן (125M) כללי - מהיר",
39
- "onlplab/alephbert-base": "מודל BERT בעברית",
40
- "avichr/heBERT": "מודל BERT עברי נוסף",
41
- "google/mt5-small": "מודל T5 רב-לשוני תומך עברית",
42
- "xlm-roberta-base": "מודל רב-לשוני תומך עברית",
43
- "google/flan-t5-small": "מודל הנחיות קטן תומך במגוון שפות",
44
- "distilgpt2": "מודל GPT-2 קטן ומהיר"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  }
46
 
47
  # מתחבר לחשבון Hugging Face
@@ -62,6 +85,15 @@ generator = None
62
  def load_model(model_name, status_box=None):
63
  """טעינת מודל ועדכון סטטוס"""
64
  global generator, current_model_name
 
 
 
 
 
 
 
 
 
65
  current_model_name = model_name
66
 
67
  if status_box is not None:
@@ -129,28 +161,32 @@ def ask_model(prompt):
129
  outputs = generator(prompt, top_k=5)
130
  return "\n".join([f"{item['token_str']} (ודאות: {item['score']:.4f})" for item in outputs])
131
 
132
- elif "t5" in current_model_name.lower() or "mt5" in current_model_name.lower() or "flan" in current_model_name.lower():
133
  # טיפול במודלי T5
134
  outputs = generator(
135
  prompt,
136
- max_length=100,
137
- do_sample=True,
138
- temperature=0.7,
139
- top_p=0.95
 
 
140
  )
141
  if isinstance(outputs, list) and len(outputs) > 0:
142
  return outputs[0]["generated_text"]
143
  else:
144
  return str(outputs)
145
  else:
146
- # טיפול במודלים רגילים
147
  outputs = generator(
148
  prompt,
149
- max_new_tokens=100,
150
- do_sample=True,
151
- temperature=0.7,
152
- top_p=0.95,
153
- return_full_text=False
 
 
154
  )
155
 
156
  # מחזיר את הטקסט שנוצר
 
33
  print(f"שגיאה ביבוא הספריות: {str(e)}")
34
  sys.exit(1)
35
 
36
+ # רשימת מודלים תומכי עברית או רב-לשוניים קטנים עם המלצות להגדרות
37
  MODELS = {
38
  "facebook/opt-125m": "מודל קטן (125M) כללי - מהיר",
39
+ "onlplab/alephbert-base": "מודל BERT בעברית - מתאים להשלמת מילים",
40
+ "avichr/heBERT": "מודל BERT עברי נוסף - מתאים להשלמת מילים",
41
+ "google/mt5-small": "מודל T5 רב-לשוני תומך עברית - מתאים לתרגום ושאלות",
42
+ "xlm-roberta-base": "מודל רב-לשוני תומך עברית - מתאים להשלמת מילים",
43
+ "google/flan-t5-small": "מודל הנחיות קטן תומך במגוון שפות - טוב לשאלות ותשובות",
44
+ "distilgpt2": "מודל GPT-2 קטן ומהיר - טוב ליצירת טקסט"
45
+ }
46
+
47
+ # מילון הגדרות אופטימליות לפי סוגי מודלים
48
+ MODEL_CONFIGS = {
49
+ "bert": { # להשלמת מסכות
50
+ "top_k": 5
51
+ },
52
+ "t5": { # למודלי T5 ו-MT5
53
+ "max_length": 150,
54
+ "do_sample": True,
55
+ "temperature": 0.6,
56
+ "top_p": 0.92,
57
+ "repetition_penalty": 1.2,
58
+ "num_beams": 4
59
+ },
60
+ "default": { # למודלי OPT, GPT וכו'
61
+ "max_new_tokens": 150,
62
+ "do_sample": True,
63
+ "temperature": 0.7,
64
+ "top_p": 0.92,
65
+ "repetition_penalty": 1.1,
66
+ "no_repeat_ngram_size": 2
67
+ }
68
  }
69
 
70
  # מתחבר לחשבון Hugging Face
 
85
  def load_model(model_name, status_box=None):
86
  """טעינת מודל ועדכון סטטוס"""
87
  global generator, current_model_name
88
+
89
+ # שחרור משאבים של מודל קודם אם קיים
90
+ if generator is not None:
91
+ import gc
92
+ del generator
93
+ gc.collect()
94
+ if torch.cuda.is_available():
95
+ torch.cuda.empty_cache()
96
+
97
  current_model_name = model_name
98
 
99
  if status_box is not None:
 
161
  outputs = generator(prompt, top_k=5)
162
  return "\n".join([f"{item['token_str']} (ודאות: {item['score']:.4f})" for item in outputs])
163
 
164
+ elif "t5" in current_model_name.lower() or "mt5" in current_model_name.lower() or "flan-t5" in current_model_name.lower():
165
  # טיפול במודלי T5
166
  outputs = generator(
167
  prompt,
168
+ max_length=150, # אורך תוצאה מקסימלי
169
+ do_sample=True, # דגימה אקראית במקום greedy
170
+ temperature=0.6, # טמפרטורה נמוכה יותר לתשובות יותר מדויקות
171
+ top_p=0.92, # נוקליוס דגימה - שומר על מגוון תשובות
172
+ repetition_penalty=1.2, # מונע חזרות
173
+ num_beams=4 # חיפוש קרן לתוצאות איכותיות יותר
174
  )
175
  if isinstance(outputs, list) and len(outputs) > 0:
176
  return outputs[0]["generated_text"]
177
  else:
178
  return str(outputs)
179
  else:
180
+ # טיפול במודלים רגילים (OPT, GPT)
181
  outputs = generator(
182
  prompt,
183
+ max_new_tokens=150, # אורך תוצאה מקסימלי
184
+ do_sample=True, # דגימה אקראית במקום greedy
185
+ temperature=0.7, # איזון בין דיוק ויצירתיות
186
+ top_p=0.92, # נוקליוס דגימה - שומר על מגוון תשובות
187
+ repetition_penalty=1.1, # מונע חזרות
188
+ no_repeat_ngram_size=2, # מניעת חזרה על ביגרמות (זוגות מילים)
189
+ return_full_text=False # מחזיר רק את הטקסט החדש שנוצר
190
  )
191
 
192
  # מחזיר את הטקסט שנוצר