"""Gradio front end for IndexRL.

Provides two tabs: one to upload/register datasets (ZIP archives that contain
an ``images`` directory of ``.npy`` files) and one to run the IndexRL
explore/train loop on a registered dataset while streaming the best
expressions found so far.
"""

import csv
import os
import pickle
import zipfile
from glob import glob
from pathlib import Path

import gradio as gr
import pandas as pd

from indexrl.environment import IndexRLEnv
from indexrl.training import (
    DynamicBuffer,
    create_model,
    explore,
    save_model,
    train_iter,
)
from indexrl.utils import get_n_channels, state_to_expression

# Maximum expression length handed to IndexRLEnv.
max_exp_len = 12

data_dir = "data/"
# Directory the tree viewers read from; find_expression() re-points it at the
# logs directory of the dataset that is currently being processed.
global_logs_dir = os.path.join(data_dir, "logs")
os.makedirs(data_dir, exist_ok=True)

# Registry of uploaded datasets, one row per dataset.
meta_data_file = os.path.join(data_dir, "metadata.csv")
if not os.path.exists(meta_data_file):
    # UTF-8 explicitly: pandas.read_csv decodes as UTF-8 by default.
    with open(meta_data_file, "w", encoding="utf-8") as fp:
        fp.write("Name,Channels,Path\n")


def save_dataset(name, zip):
    """Extract an uploaded dataset archive and register it in the metadata CSV.

    Args:
        name: Dataset name; also used as the directory name under ``data_dir``.
        zip: Uploaded file object exposing a ``.name`` path to the ZIP archive
            (the parameter name shadows the builtin but is kept for
            backward compatibility).

    Returns:
        A tuple of the refreshed metadata DataFrame and a Dropdown update
        listing all registered dataset names.

    Raises:
        ValueError: If no name or file was supplied, or if the archive has no
            ``images/*.npy`` file to read the channel count from.
    """
    if not name or not name.strip():
        raise ValueError("Dataset name must not be empty.")
    if zip is None:
        raise ValueError("No ZIP file was uploaded.")

    # NOTE(review): `name` is used verbatim as a directory name; confirm that
    # callers cannot supply path separators or ".." here.
    data_path = os.path.join(data_dir, name)
    with zipfile.ZipFile(zip.name, "r") as zip_ref:
        zip_ref.extractall(data_path)

    image_paths = glob(os.path.join(data_path, "images", "*.npy"))
    if not image_paths:
        raise ValueError(
            f"No .npy files found in {os.path.join(data_path, 'images')}."
        )
    n_channels = get_n_channels(image_paths[0])

    # csv.writer quotes fields that contain commas, so such names/paths no
    # longer corrupt the registry; plain values are written exactly as before.
    with open(meta_data_file, "a", encoding="utf-8") as fp:
        csv.writer(fp, lineterminator="\n").writerow([name, n_channels, data_path])

    meta_data_df = pd.read_csv(meta_data_file)
    return meta_data_df, gr.Dropdown.update(choices=meta_data_df["Name"].to_list())


def get_tree(exp_num: int = 1, tree_num: int = 1):
    """Return the text of the tree file ``tree_<exp_num>_<tree_num>.txt``.

    The file is looked up in ``global_logs_dir``. ``tree_num`` is clamped to a
    minimum of 1. Returns an empty string (and prints a notice) when the file
    does not exist.
    """
    tree_num = max(tree_num, 1)
    tree_path = os.path.join(
        global_logs_dir, f"tree_{int(exp_num)}_{int(tree_num)}.txt"
    )
    try:
        with open(tree_path, "r", encoding="utf-8") as fp:
            return fp.read()
    except FileNotFoundError:
        print(f"Tree at {tree_path} not found!")
        return ""


def change_expression(exp_num: int = 1, tree_num: int = 1):
    """Load a tree for iteration ``exp_num`` and resize the tree slider.

    Returns:
        A tuple of the tree text and a Slider update whose maximum is the
        number of tree files found for ``exp_num`` (never below 1, the
        slider's minimum). If ``exp_num`` cannot be converted with ``int``
        (e.g. it is ``None``), returns an empty string and a no-op update.
    """
    try:
        exp_id = int(exp_num)
    except TypeError:
        return "", gr.Slider.update()

    paths = glob(os.path.join(global_logs_dir, f"tree_{exp_id}_*.txt"))
    n_trees = len(paths)
    tree_num = max(min(n_trees, tree_num), 1)
    tree = get_tree(exp_num, tree_num)
    # Keep maximum >= minimum (1) even when no tree files exist yet.
    return tree, gr.Slider.update(
        value=tree_num, maximum=max(n_trees, 1), interactive=True
    )


def find_expression(dataset_name: str):
    """Run the explore/train loop on a dataset, streaming the top expressions.

    This is a generator used as a Gradio event handler. After every iteration
    it yields a tuple of (text listing the top five buffer entries with their
    second element appended, Slider update pointing at the latest iteration).
    The loop runs until the event is cancelled.

    Side effects: rebinds the module-level ``global_logs_dir`` to the
    dataset's ``logs`` directory, and writes a model checkpoint and a pickled
    data buffer on every iteration.

    Args:
        dataset_name: Name of a dataset registered in the metadata CSV. An
            empty or missing selection resets the outputs and stops.
    """
    if not dataset_name:
        # Inside a generator a `return <value>` is never delivered to the UI,
        # so the reset state has to be yielded before stopping.
        yield "", gr.Slider.update(value=1, interactive=False)
        return

    global global_logs_dir

    meta_data_df = pd.read_csv(meta_data_file, index_col="Name")
    n_channels = meta_data_df["Channels"][dataset_name]
    dataset_dir = meta_data_df["Path"][dataset_name]

    image_dir = os.path.join(dataset_dir, "images")
    mask_dir = os.path.join(dataset_dir, "masks")
    cache_dir = os.path.join(dataset_dir, "cache")
    global_logs_dir = logs_dir = os.path.join(dataset_dir, "logs")
    models_dir = os.path.join(dataset_dir, "models")
    for dir_name in (cache_dir, logs_dir, models_dir):
        Path(dir_name).mkdir(parents=True, exist_ok=True)

    action_list = (
        list("()+-*/=") + ["sq", "sqrt"] + [f"c{c}" for c in range(n_channels)]
    )
    env = IndexRLEnv(action_list, max_exp_len)
    agent, optimizer = create_model(len(action_list))
    seen_path = os.path.join(cache_dir, "seen.pkl")
    env.save_seen(seen_path)
    data_buffer = DynamicBuffer()

    i = 0
    while True:
        i += 1
        print(f"----------------\nIteration {i}")
        print("Collecting data...")
        data = explore(
            env.copy(),
            agent,
            image_dir,
            mask_dir,
            1,
            logs_dir,
            seen_path,
            tree_prefix=f"tree_{int(i)}",
            n_iters=1000,
        )
        print(
            f"Data collection done. Collected {len(data)} examples. Buffer size = {len(data_buffer)}."
        )

        data_buffer.add_data(data)
        print(f"Buffer size new = {len(data_buffer)}.")

        agent, optimizer, loss = train_iter(agent, optimizer, data_buffer)
        print("Loss:", loss)

        # Zero-padded iteration number used in checkpoint/cache file names.
        i_str = f"{i:03d}"
        save_model(agent, f"{models_dir}/model_{i_str}_loss-{loss}.pt")
        with open(f"{cache_dir}/data_buffer_{i_str}.pkl", "wb") as fp:
            pickle.dump(data_buffer, fp)

        top_5 = data_buffer.get_top_n(5)
        top_5_str = "\n".join(
            " ".join(state_to_expression(entry[0], action_list)) + " " + str(entry[1])
            for entry in top_5
        )

        yield top_5_str, gr.Slider.update(value=i, maximum=i, interactive=True)


with gr.Blocks(title="IndexRL") as demo:
    gr.Markdown("# IndexRL")
    meta_data_df = pd.read_csv(meta_data_file)

    with gr.Tab("Find Expressions"):
        with gr.Row():
            with gr.Column():
                select_dataset = gr.Dropdown(
                    label="Select Dataset",
                    choices=meta_data_df["Name"].to_list(),
                )
                find_exp_btn = gr.Button("Find Expressions", variant="primary")
                stop_btn = gr.Button("Stop", variant="stop")
                best_exps = gr.Textbox(label="Best Expressions", interactive=False)

            with gr.Column():
                select_exp = gr.Slider(
                    value=1, label="Iteration", interactive=False, minimum=1, step=1
                )
                select_tree = gr.Slider(
                    value=1, label="Tree Number", interactive=False, minimum=1, step=1
                )
                out_exp_tree = gr.Textbox(
                    label="Latest Expression Tree", interactive=False
                )

    with gr.Tab("Datasets"):
        dataset_upload = gr.File(label="Upload Data ZIP file")
        dataset_name = gr.Textbox(label="Dataset Name")
        dataset_upload_btn = gr.Button("Upload")
        dataset_table = gr.Dataframe(meta_data_df, label="Dataset Table")

    # Start the (endless) search; the Stop button cancels this event.
    find_exp_event = find_exp_btn.click(
        find_expression,
        inputs=[select_dataset],
        outputs=[best_exps, select_exp],
    )
    stop_btn.click(fn=None, inputs=None, outputs=None, cancels=[find_exp_event])

    # Changing the iteration reloads the tree and resizes the tree slider;
    # changing the tree number only reloads the tree text.
    select_exp.change(
        fn=lambda x, y: change_expression(x, y),
        inputs=[select_exp, select_tree],
        outputs=[out_exp_tree, select_tree],
    )
    select_tree.change(
        fn=lambda x, y: get_tree(x, y),
        inputs=[select_exp, select_tree],
        outputs=out_exp_tree,
    )

    # Pre-fill the dataset name with the uploaded file's name minus its last
    # extension.
    dataset_upload.upload(
        lambda x: ".".join(os.path.basename(x.orig_name).split(".")[:-1]),
        inputs=dataset_upload,
        outputs=dataset_name,
    )
    dataset_upload_btn.click(
        save_dataset,
        inputs=[dataset_name, dataset_upload],
        outputs=[dataset_table, select_dataset],
    )

demo.queue(concurrency_count=10).launch(debug=True)