wangjin2000 committed on
Commit
f699662
·
verified ·
1 Parent(s): d03eed6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -3
app.py CHANGED
@@ -82,7 +82,7 @@ class WeightedTrainer(Trainer):
82
  return (loss, outputs) if return_outputs else loss
83
 
84
  # fine-tuning function
85
- def train_function_no_sweeps(base_model_path, train_dataset, test_dataset):
86
 
87
  # Set the LoRA config
88
  config = {
@@ -170,7 +170,14 @@ def train_function_no_sweeps(base_model_path, train_dataset, test_dataset):
170
  tokenizer.save_pretrained(save_path)
171
 
172
  return save_path
173
-
 
 
 
 
 
 
 
174
  # Load the data from pickle files (replace with your local paths)
175
  with open("./datasets/train_sequences_chunked_by_family.pkl", "rb") as f:
176
  train_sequences = pickle.load(f)
@@ -198,6 +205,7 @@ test_labels = truncate_labels(test_labels, max_sequence_length)
198
  train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
199
  test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)
200
 
 
201
  # Compute Class Weights
202
  classes = [0, 1]
203
  flat_train_labels = [label for sublist in train_labels for label in sublist]
@@ -248,10 +256,46 @@ saved_path = train_function_no_sweeps(base_model_path,train_dataset, test_datase
248
 
249
  # debug result
250
  dubug_result = saved_path #predictions #class_weights
 
251
 
252
  demo = gr.Blocks(title="DEMO FOR ESM2Bind")
253
 
254
  with demo:
255
  gr.Markdown("# DEMO FOR ESM2Bind")
256
- gr.Textbox(dubug_result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
  demo.launch()
 
82
  return (loss, outputs) if return_outputs else loss
83
 
84
  # fine-tuning function
85
+ def train_function_no_sweeps(base_model_path): #, train_dataset, test_dataset):
86
 
87
  # Set the LoRA config
88
  config = {
 
170
  tokenizer.save_pretrained(save_path)
171
 
172
  return save_path
173
+
174
+ # Constants & Globals
175
+ MODEL_OPTIONS = [
176
+ "facebook/esm2_t6_8M_UR50D",
177
+ "facebook/esm2_t12_35M_UR50D",
178
+ "facebook/esm2_t33_650M_UR50D",
179
+ ] # models users can choose from
180
+
181
  # Load the data from pickle files (replace with your local paths)
182
  with open("./datasets/train_sequences_chunked_by_family.pkl", "rb") as f:
183
  train_sequences = pickle.load(f)
 
205
  train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
206
  test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)
207
 
208
+ '''
209
  # Compute Class Weights
210
  classes = [0, 1]
211
  flat_train_labels = [label for sublist in train_labels for label in sublist]
 
256
 
257
  # debug result
258
  dubug_result = saved_path #predictions #class_weights
259
+ '''
260
 
261
  demo = gr.Blocks(title="DEMO FOR ESM2Bind")
262
 
263
  with demo:
264
  gr.Markdown("# DEMO FOR ESM2Bind")
265
+ #gr.Textbox(dubug_result)
266
+
267
+ with gr.Tab("Finetune Pre-trained Model"):
268
+ gr.Markdown("## Finetune Pre-trained Model")
269
+ with gr.Column():
270
+ gr.Markdown("## Load Inputs & Select Parameters")
271
+ gr.Markdown(
272
+ """ Pick a dataset, a model & adjust params (_optional_), and press **Finetune Pre-trained Model!"""
273
+ )
274
+ with gr.Row():
275
+ with gr.Column(scale=0.5, variant="compact"):
276
+ base_model_name = gr.Dropdown(
277
+ choices=MODEL_OPTIONS,
278
+ value=MODEL_OPTIONS[0],
279
+ label="Base Model Name",
280
+ interactive = True,
281
+ )
282
+ finetune_button = gr.Button(
283
+ value="Finetune Pre-trained Model",
284
+ interactive=True,
285
+ variant="primary",
286
+ )
287
+ finetune_output_text = gr.Textbox(
288
+ lines=1,
289
+ max_lines=12,
290
+ label="Finetune Status",
291
+ placeholder="Finetune Status Shown Here",
292
+ )
293
+ # Tab "Finetune Pre-trained Model" actions
294
+ finetune_button.click(
295
+ fn = train_function_no_sweeps,
296
+ inputs=[base_model_name], #finetune_dataset_name],
297
+ outputs = [finetune_output_text],
298
+ )
299
+
300
+
301
  demo.launch()