Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -82,7 +82,7 @@ class WeightedTrainer(Trainer):
|
|
82 |
return (loss, outputs) if return_outputs else loss
|
83 |
|
84 |
# fine-tuning function
|
85 |
-
def train_function_no_sweeps(base_model_path, train_dataset, test_dataset):
|
86 |
|
87 |
# Set the LoRA config
|
88 |
config = {
|
@@ -170,7 +170,14 @@ def train_function_no_sweeps(base_model_path, train_dataset, test_dataset):
|
|
170 |
tokenizer.save_pretrained(save_path)
|
171 |
|
172 |
return save_path
|
173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
# Load the data from pickle files (replace with your local paths)
|
175 |
with open("./datasets/train_sequences_chunked_by_family.pkl", "rb") as f:
|
176 |
train_sequences = pickle.load(f)
|
@@ -198,6 +205,7 @@ test_labels = truncate_labels(test_labels, max_sequence_length)
|
|
198 |
train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
|
199 |
test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)
|
200 |
|
|
|
201 |
# Compute Class Weights
|
202 |
classes = [0, 1]
|
203 |
flat_train_labels = [label for sublist in train_labels for label in sublist]
|
@@ -248,10 +256,46 @@ saved_path = train_function_no_sweeps(base_model_path,train_dataset, test_datase
|
|
248 |
|
249 |
# debug result
|
250 |
dubug_result = saved_path #predictions #class_weights
|
|
|
251 |
|
252 |
demo = gr.Blocks(title="DEMO FOR ESM2Bind")
|
253 |
|
254 |
with demo:
|
255 |
gr.Markdown("# DEMO FOR ESM2Bind")
|
256 |
-
gr.Textbox(dubug_result)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
257 |
demo.launch()
|
|
|
82 |
return (loss, outputs) if return_outputs else loss
|
83 |
|
84 |
# fine-tuning function
|
85 |
+
def train_function_no_sweeps(base_model_path): #, train_dataset, test_dataset):
|
86 |
|
87 |
# Set the LoRA config
|
88 |
config = {
|
|
|
170 |
tokenizer.save_pretrained(save_path)
|
171 |
|
172 |
return save_path
|
173 |
+
|
174 |
+
# Constants & Globals
|
175 |
+
MODEL_OPTIONS = [
|
176 |
+
"facebook/esm2_t6_8M_UR50D",
|
177 |
+
"facebook/esm2_t12_35M_UR50D",
|
178 |
+
"facebook/esm2_t33_650M_UR50D",
|
179 |
+
] # models users can choose from
|
180 |
+
|
181 |
# Load the data from pickle files (replace with your local paths)
|
182 |
with open("./datasets/train_sequences_chunked_by_family.pkl", "rb") as f:
|
183 |
train_sequences = pickle.load(f)
|
|
|
205 |
train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
|
206 |
test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)
|
207 |
|
208 |
+
'''
|
209 |
# Compute Class Weights
|
210 |
classes = [0, 1]
|
211 |
flat_train_labels = [label for sublist in train_labels for label in sublist]
|
|
|
256 |
|
257 |
# debug result
|
258 |
dubug_result = saved_path #predictions #class_weights
|
259 |
+
'''
|
260 |
|
261 |
demo = gr.Blocks(title="DEMO FOR ESM2Bind")
|
262 |
|
263 |
with demo:
|
264 |
gr.Markdown("# DEMO FOR ESM2Bind")
|
265 |
+
#gr.Textbox(dubug_result)
|
266 |
+
|
267 |
+
with gr.Tab("Finetune Pre-trained Model"):
|
268 |
+
gr.Markdown("## Finetune Pre-trained Model")
|
269 |
+
with gr.Column():
|
270 |
+
gr.Markdown("## Load Inputs & Select Parameters")
|
271 |
+
gr.Markdown(
|
272 |
+
""" Pick a dataset, a model & adjust params (_optional_), and press **Finetune Pre-trained Model**!"""
|
273 |
+
)
|
274 |
+
with gr.Row():
|
275 |
+
with gr.Column(scale=0.5, variant="compact"):
|
276 |
+
base_model_name = gr.Dropdown(
|
277 |
+
choices=MODEL_OPTIONS,
|
278 |
+
value=MODEL_OPTIONS[0],
|
279 |
+
label="Base Model Name",
|
280 |
+
interactive = True,
|
281 |
+
)
|
282 |
+
finetune_button = gr.Button(
|
283 |
+
value="Finetune Pre-trained Model",
|
284 |
+
interactive=True,
|
285 |
+
variant="primary",
|
286 |
+
)
|
287 |
+
finetune_output_text = gr.Textbox(
|
288 |
+
lines=1,
|
289 |
+
max_lines=12,
|
290 |
+
label="Finetune Status",
|
291 |
+
placeholder="Finetune Status Shown Here",
|
292 |
+
)
|
293 |
+
# Tab "Finetune Pre-trained Model" actions
|
294 |
+
finetune_button.click(
|
295 |
+
fn = train_function_no_sweeps,
|
296 |
+
inputs=[base_model_name], #finetune_dataset_name],
|
297 |
+
outputs = [finetune_output_text],
|
298 |
+
)
|
299 |
+
|
300 |
+
|
301 |
demo.launch()
|