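# Gradio demo for IndexRL: upload image datasets as ZIP archives, then run
# the RL-driven search for index expressions over the images' channels.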
import os
import zipfile
import pickle
from glob import glob
from pathlib import Path

import pandas as pd
import gradio as gr

from indexrl.training import (
    DynamicBuffer,
    create_model,
    save_model,
    explore,
    train_iter,
)
from indexrl.environment import IndexRLEnv
from indexrl.utils import get_n_channels, state_to_expression

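# Local workspace: each uploaded dataset is extracted under `data/` and
# registered in a metadata CSV (name, channel count, extraction path).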
data_dir = "data/"
os.makedirs(data_dir, exist_ok=True)

meta_data_file = os.path.join(data_dir, "metadata.csv")
if not os.path.exists(meta_data_file):
    with open(meta_data_file, "w") as fp:
        fp.write("Name,Channels,Path\n")


def save_dataset(name, zip_file):
    with zipfile.ZipFile(zip_file.name, "r") as zip_ref:
        data_path = os.path.join(data_dir, name)
        zip_ref.extractall(data_path)
        img_path = glob(os.path.join(data_path, "images", "*.npy"))[0]
        n_channels = get_n_channels(img_path)
        with open(meta_data_file, "a") as fp:
            fp.write(f"{name},{n_channels},{data_path}\n")
    meta_data_df = pd.read_csv(meta_data_file)
    return meta_data_df


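# Generator that runs the IndexRL search on the selected dataset and yields
# (latest expression tree, top-5 expressions) after every iteration so that
# Gradio can stream intermediate results into the UI.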
def find_expression(dataset_name: str):
    meta_data_df = pd.read_csv(meta_data_file, index_col="Name")
    n_channels = meta_data_df["Channels"][dataset_name]
    dataset_dir = meta_data_df["Path"][dataset_name]

    image_dir = os.path.join(dataset_dir, "images")
    mask_dir = os.path.join(dataset_dir, "masks")
    cache_dir = os.path.join(dataset_dir, "cache")
    logs_dir = os.path.join(dataset_dir, "logs")
    models_dir = os.path.join(dataset_dir, "models")
    for dir_name in (cache_dir, logs_dir, models_dir):
        Path(dir_name).mkdir(parents=True, exist_ok=True)

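    # Action space: parentheses, arithmetic operators, square and square
    # root, plus one token per image channel (c0, c1, ...).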
    action_list = (
        list("()+-*/=") + ["sq", "sqrt"] + [f"c{c}" for c in range(n_channels)]
    )
    env = IndexRLEnv(action_list, 12)
    agent, optimizer = create_model(len(action_list))
    seen_path = os.path.join(cache_dir, "seen.pkl") if cache_dir else ""
    env.save_seen(seen_path)
    data_buffer = DynamicBuffer()

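    # Endless explore/train loop; it only stops when the UI's Stop button
    # cancels this generator.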
    i = 0
    while True:
        i += 1
        print(f"----------------\nIteration {i}")

        print("Collecting data...")
        data = explore(
            env.copy(),
            agent,
            image_dir,
            mask_dir,
            1,
            logs_dir,
            seen_path,
            n_iters=1000,
        )
        print(
            f"Data collection done. Collected {len(data)} examples. "
            f"Buffer size = {len(data_buffer)}."
        )
        data_buffer.add_data(data)
        print(f"New buffer size = {len(data_buffer)}.")

        agent, optimizer, loss = train_iter(agent, optimizer, data_buffer)

        i_str = str(i).rjust(3, "0")
        if models_dir:
            save_model(agent, f"{models_dir}/model_{i_str}_loss-{loss}.pt")
        if cache_dir:
            with open(f"{cache_dir}/data_buffer_{i_str}.pkl", "wb") as fp:
                pickle.dump(data_buffer, fp)

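        # Read back the latest search tree logged during exploration and
        # format the five best expressions found so far.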
        with open(os.path.join(logs_dir, "tree_1.txt"), "r", encoding="utf-8") as fp:
            tree = fp.read()

        top_5 = data_buffer.get_top_n(5)
        top_5_str = "\n".join(
            " ".join(state_to_expression(item[0], action_list)) + " " + str(item[1])
            for item in top_5
        )
        yield tree, top_5_str


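# Two-tab UI: "Find Expressions" runs the search and streams results;
# "Datasets" manages ZIP uploads.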
with gr.Blocks() as demo:
    gr.Markdown("# IndexRL")
    meta_data_df = pd.read_csv(meta_data_file)

    with gr.Tab("Find Expressions"):
        select_dataset = gr.Dropdown(
            label="Select Dataset",
            choices=meta_data_df["Name"].to_list(),
        )
        find_exp_btn = gr.Button("Find Expressions")
        stop_btn = gr.Button("Stop")
        out_exp_tree = gr.Textbox(label="Latest Expression Tree", interactive=False)
        best_exps = gr.Textbox(label="Best Expressions", interactive=False)

    with gr.Tab("Datasets"):
        dataset_upload = gr.File(label="Upload Data ZIP file")
        dataset_name = gr.Textbox(label="Dataset Name")
        dataset_upload_btn = gr.Button("Upload")
        dataset_table = gr.Dataframe(meta_data_df, label="Dataset Table")

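    # Event wiring: Stop cancels the streaming search, and uploading a ZIP
    # pre-fills the dataset name from the file name.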
    find_exp_event = find_exp_btn.click(
        find_expression, inputs=[select_dataset], outputs=[out_exp_tree, best_exps]
    )
    stop_btn.click(fn=None, inputs=None, outputs=None, cancels=[find_exp_event])

    dataset_upload.upload(
        lambda x: ".".join(os.path.basename(x.orig_name).split(".")[:-1]),
        inputs=dataset_upload,
        outputs=dataset_name,
    )
    dataset_upload_btn.click(
        save_dataset, inputs=[dataset_name, dataset_upload], outputs=[dataset_table]
    )

demo.queue(concurrency_count=10).launch(debug=True)