dilithjay committed on
Commit 2dff12e · 1 Parent(s): 0c4c4e8

Add app.py and requirements.txt

Files changed (2)
  1. app.py +148 -0
  2. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,148 @@
+ import os
+ import zipfile
+ import pickle
+ from glob import glob
+ from pathlib import Path
+
+ import pandas as pd
+ import gradio as gr
+
+ from indexrl.training import (
+     DynamicBuffer,
+     create_model,
+     save_model,
+     explore,
+     train_iter,
+ )
+ from indexrl.environment import IndexRLEnv
+ from indexrl.utils import get_n_channels, state_to_expression
+
+ data_dir = "data/"
+ os.makedirs(data_dir, exist_ok=True)
+
+ meta_data_file = os.path.join(data_dir, "metadata.csv")
+ if not os.path.exists(meta_data_file):
+     with open(meta_data_file, "w") as fp:
+         fp.write("Name,Channels,Path\n")
+
+
+ def save_dataset(name, zip_file):
+     with zipfile.ZipFile(zip_file.name, "r") as zip_ref:
+         data_path = os.path.join(data_dir, name)
+         zip_ref.extractall(data_path)
+
+     img_path = glob(os.path.join(data_path, "images", "*.npy"))[0]
+     n_channels = get_n_channels(img_path)
+
+     with open(meta_data_file, "a") as fp:
+         fp.write(f"{name},{n_channels},{data_path}\n")
+     meta_data_df = pd.read_csv(meta_data_file)
+     return meta_data_df
+
+
+ def find_expression(dataset_name: str):
+     meta_data_df = pd.read_csv(meta_data_file, index_col="Name")
+     n_channels = meta_data_df["Channels"][dataset_name]
+     data_dir = meta_data_df["Path"][dataset_name]
+
+     image_dir = os.path.join(data_dir, "images")
+     mask_dir = os.path.join(data_dir, "masks")
+
+     cache_dir = os.path.join(data_dir, "cache")
+     logs_dir = os.path.join(data_dir, "logs")
+     models_dir = os.path.join(data_dir, "models")
+     for dir_name in (cache_dir, logs_dir, models_dir):
+         Path(dir_name).mkdir(parents=True, exist_ok=True)
+
+     action_list = (
+         list("()+-*/=") + ["sq", "sqrt"] + [f"c{c}" for c in range(n_channels)]
+     )
+     env = IndexRLEnv(action_list, 12)
+     agent, optimizer = create_model(len(action_list))
+     seen_path = os.path.join(cache_dir, "seen.pkl") if cache_dir else ""
+     env.save_seen(seen_path)
+     data_buffer = DynamicBuffer()
+
+     i = 0
+     while True:
+         i += 1
+         print(f"----------------\nIteration {i}")
+         print("Collecting data...")
+         data = explore(
+             env.copy(),
+             agent,
+             image_dir,
+             mask_dir,
+             1,
+             logs_dir,
+             seen_path,
+             n_iters=1000,
+         )
+         print(
+             f"Data collection done. Collected {len(data)} examples. Buffer size = {len(data_buffer)}."
+         )
+
+         data_buffer.add_data(data)
+         print(f"Buffer size new = {len(data_buffer)}.")
+
+         agent, optimizer, loss = train_iter(agent, optimizer, data_buffer)
+
+         i_str = str(i).rjust(3, "0")
+         if models_dir:
+             save_model(agent, f"{models_dir}/model_{i_str}_loss-{loss}.pt")
+         if cache_dir:
+             with open(f"{cache_dir}/data_buffer_{i_str}.pkl", "wb") as fp:
+                 pickle.dump(data_buffer, fp)
+
+         with open(os.path.join(logs_dir, "tree_1.txt"), "r", encoding="utf-8") as fp:
+             tree = fp.read()
+
+         top_5 = data_buffer.get_top_n(5)
+         top_5_str = "\n".join(
+             map(
+                 lambda x: " ".join(state_to_expression(x[0], action_list))
+                 + " "
+                 + str(x[1]),
+                 top_5,
+             )
+         )
+
+         yield tree, top_5_str
+
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# IndexRL")
+     meta_data_df = pd.read_csv(meta_data_file)
+
+     with gr.Tab("Find Expressions"):
+         select_dataset = gr.Dropdown(
+             label="Select Dataset",
+             choices=meta_data_df["Name"].to_list(),
+         )
+         find_exp_btn = gr.Button("Find Expressions")
+         stop_btn = gr.Button("Stop")
+         out_exp_tree = gr.Textbox(label="Latest Expression Tree", interactive=False)
+         best_exps = gr.Textbox(label="Best Expressions", interactive=False)
+
+     with gr.Tab("Datasets"):
+         dataset_upload = gr.File(label="Upload Data ZIP file")
+         dataset_name = gr.Textbox(label="Dataset Name")
+         dataset_upload_btn = gr.Button("Upload")
+
+         dataset_table = gr.Dataframe(meta_data_df, label="Dataset Table")
+
+     find_exp_event = find_exp_btn.click(
+         find_expression, inputs=[select_dataset], outputs=[out_exp_tree, best_exps]
+     )
+     stop_btn.click(fn=None, inputs=None, outputs=None, cancels=[find_exp_event])
+
+     dataset_upload.upload(
+         lambda x: ".".join(os.path.basename(x.orig_name).split(".")[:-1]),
+         inputs=dataset_upload,
+         outputs=dataset_name,
+     )
+     dataset_upload_btn.click(
+         save_dataset, inputs=[dataset_name, dataset_upload], outputs=[dataset_table]
+     )
+
+ demo.queue(concurrency_count=10).launch(debug=True)
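For context, `save_dataset` above expects the uploaded ZIP to contain an `images/` folder of `.npy` arrays (it infers the channel count from the first one via `get_n_channels`), and `find_expression` additionally reads a parallel `masks/` folder. A minimal sketch of packaging a compatible dataset; the array shapes and file names here are illustrative assumptions, not part of the commit:

import io
import zipfile

import numpy as np

# Hypothetical 4-channel image and binary mask; only the images/*.npy
# and masks/*.npy layout is what app.py actually relies on.
img = np.random.rand(4, 64, 64).astype(np.float32)
mask = (np.random.rand(64, 64) > 0.5).astype(np.uint8)

with zipfile.ZipFile("my_dataset.zip", "w") as zf:
    for name, arr in (("images/sample_0.npy", img), ("masks/sample_0.npy", mask)):
        buf = io.BytesIO()
        np.save(buf, arr)  # serialize the array to an in-memory buffer
        zf.writestr(name, buf.getvalue())  # store it under the expected path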
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ indexrl==0.1.1
+ gradio==3.34.0
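The `gradio==3.34.0` pin matters: `app.py` relies on two Gradio 3.x behaviors, generator event handlers (each `yield` in `find_expression` streams a fresh expression tree and top-5 list to the UI) and `cancels=`, which lets the Stop button abort the running generator. A stripped-down sketch of that pattern, with illustrative component names:

import time

import gradio as gr

def count_forever():
    # Generator handler: every yield pushes an update to the output
    # component until the event is cancelled.
    i = 0
    while True:
        i += 1
        yield str(i)
        time.sleep(1)

with gr.Blocks() as sketch:
    out = gr.Textbox(label="Ticks")
    start = gr.Button("Start")
    stop = gr.Button("Stop")
    run_event = start.click(count_forever, inputs=None, outputs=out)
    # Passing the click event to cancels= aborts the pending generator,
    # which is how the Stop button in app.py halts find_expression.
    stop.click(fn=None, inputs=None, outputs=None, cancels=[run_event])

# Generator handlers require the queue to be enabled, as app.py also does.
sketch.queue().launch()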