dilithjay committed on
Commit
90ca6f7
·
1 Parent(s): 779ba9f

Add multiple expression debugging support

Browse files
Files changed (1) hide show
  1. app.py +67 -14
app.py CHANGED
@@ -17,7 +17,10 @@ from indexrl.training import (
17
  from indexrl.environment import IndexRLEnv
18
  from indexrl.utils import get_n_channels, state_to_expression
19
 
 
 
20
  data_dir = "data/"
 
21
  os.makedirs(data_dir, exist_ok=True)
22
 
23
  meta_data_file = os.path.join(data_dir, "metadata.csv")
@@ -40,7 +43,31 @@ def save_dataset(name, zip):
40
  return meta_data_df, gr.Dropdown.update(choices=meta_data_df["Name"].to_list())
41
 
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  def find_expression(dataset_name: str):
 
44
  meta_data_df = pd.read_csv(meta_data_file, index_col="Name")
45
  n_channels = meta_data_df["Channels"][dataset_name]
46
  data_dir = meta_data_df["Path"][dataset_name]
@@ -49,7 +76,7 @@ def find_expression(dataset_name: str):
49
  mask_dir = os.path.join(data_dir, "masks")
50
 
51
  cache_dir = os.path.join(data_dir, "cache")
52
- logs_dir = os.path.join(data_dir, "logs")
53
  models_dir = os.path.join(data_dir, "models")
54
  for dir_name in (cache_dir, logs_dir, models_dir):
55
  Path(dir_name).mkdir(parents=True, exist_ok=True)
@@ -57,7 +84,7 @@ def find_expression(dataset_name: str):
57
  action_list = (
58
  list("()+-*/=") + ["sq", "sqrt"] + [f"c{c}" for c in range(n_channels)]
59
  )
60
- env = IndexRLEnv(action_list, 12)
61
  agent, optimizer = create_model(len(action_list))
62
  seen_path = os.path.join(cache_dir, "seen.pkl") if cache_dir else ""
63
  env.save_seen(seen_path)
@@ -76,6 +103,7 @@ def find_expression(dataset_name: str):
76
  1,
77
  logs_dir,
78
  seen_path,
 
79
  n_iters=1000,
80
  )
81
  print(
@@ -95,8 +123,7 @@ def find_expression(dataset_name: str):
95
  with open(f"{cache_dir}/data_buffer_{i_str}.pkl", "wb") as fp:
96
  pickle.dump(data_buffer, fp)
97
 
98
- with open(os.path.join(logs_dir, "tree_1.txt"), "r", encoding="utf-8") as fp:
99
- tree = fp.read()
100
 
101
  top_5 = data_buffer.get_top_n(5)
102
  top_5_str = "\n".join(
@@ -108,7 +135,9 @@ def find_expression(dataset_name: str):
108
  )
109
  )
110
 
111
- yield tree, top_5_str
 
 
112
 
113
 
114
  with gr.Blocks(title="IndexRL") as demo:
@@ -116,14 +145,26 @@ with gr.Blocks(title="IndexRL") as demo:
116
  meta_data_df = pd.read_csv(meta_data_file)
117
 
118
  with gr.Tab("Find Expressions"):
119
- select_dataset = gr.Dropdown(
120
- label="Select Dataset",
121
- choices=meta_data_df["Name"].to_list(),
122
- )
123
- find_exp_btn = gr.Button("Find Expressions")
124
- stop_btn = gr.Button("Stop")
125
- best_exps = gr.Textbox(label="Best Expressions", interactive=False)
126
- out_exp_tree = gr.Textbox(label="Latest Expression Tree", interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
  with gr.Tab("Datasets"):
129
  dataset_upload = gr.File(label="Upload Data ZIP file")
@@ -133,9 +174,21 @@ with gr.Blocks(title="IndexRL") as demo:
133
  dataset_table = gr.Dataframe(meta_data_df, label="Dataset Table")
134
 
135
  find_exp_event = find_exp_btn.click(
136
- find_expression, inputs=[select_dataset], outputs=[out_exp_tree, best_exps]
 
 
137
  )
138
  stop_btn.click(fn=None, inputs=None, outputs=None, cancels=[find_exp_event])
 
 
 
 
 
 
 
 
 
 
139
 
140
  dataset_upload.upload(
141
  lambda x: ".".join(os.path.basename(x.orig_name).split(".")[:-1]),
 
17
  from indexrl.environment import IndexRLEnv
18
  from indexrl.utils import get_n_channels, state_to_expression
19
 
20
+
21
+ max_exp_len = 12
22
  data_dir = "data/"
23
+ global_logs_dir = os.path.join(data_dir, "logs")
24
  os.makedirs(data_dir, exist_ok=True)
25
 
26
  meta_data_file = os.path.join(data_dir, "metadata.csv")
 
43
  return meta_data_df, gr.Dropdown.update(choices=meta_data_df["Name"].to_list())
44
 
45
 
46
def get_tree(exp_num: int = 1, tree_num: int = 1) -> str:
    """Return the text of a saved expression tree, or "" if it is missing.

    The tree is read from ``tree_<exp_num>_<tree_num>.txt`` inside the
    module-level ``global_logs_dir``.

    Args:
        exp_num: First number in the file name; truncated with ``int()``
            (presumably because UI sliders may pass floats -- confirm).
        tree_num: Second number in the file name; values below 1 are
            clamped to 1 and the result is truncated with ``int()``.

    Returns:
        The file contents, or an empty string when the file does not exist
        (a message is printed in that case).
    """
    tree_num = max(tree_num, 1)
    tree_path = os.path.join(
        global_logs_dir, f"tree_{int(exp_num)}_{int(tree_num)}.txt"
    )
    # Open directly rather than testing os.path.exists() first: the file can
    # disappear between the check and the open, which would otherwise raise.
    try:
        with open(tree_path, "r", encoding="utf-8") as fp:
            return fp.read()
    except FileNotFoundError:
        print(f"Tree at {tree_path} not found!")
        return ""
58
+
59
+
60
def change_expression(exp_num: int = 1, tree_num: int = 1):
    """Load a tree for the given expression number and resize the tree slider.

    Counts the files matching ``tree_<exp_num>_*.txt`` in the module-level
    ``global_logs_dir``, clamps ``tree_num`` into ``[1, count]`` and reads
    that tree with ``get_tree``.

    Args:
        exp_num: First number in the tree file names; truncated with ``int()``.
        tree_num: Requested tree number; clamped to the number of matching
            files and never below 1.

    Returns:
        A tuple ``(tree_text, slider_update)`` where ``slider_update`` is a
        ``gr.Slider.update`` carrying the clamped value and the new maximum.
    """
    paths = glob(os.path.join(global_logs_dir, f"tree_{int(exp_num)}_*.txt"))
    n_trees = len(paths)
    tree_num = max(min(n_trees, tree_num), 1)

    tree = get_tree(exp_num, tree_num)

    # Keep the maximum at 1 or above: with no matching files the original
    # passed maximum=0, which is below the value (1) it sets in the same call.
    return tree, gr.Slider.update(value=tree_num, maximum=max(n_trees, 1))
67
+
68
+
69
  def find_expression(dataset_name: str):
70
+ global global_logs_dir
71
  meta_data_df = pd.read_csv(meta_data_file, index_col="Name")
72
  n_channels = meta_data_df["Channels"][dataset_name]
73
  data_dir = meta_data_df["Path"][dataset_name]
 
76
  mask_dir = os.path.join(data_dir, "masks")
77
 
78
  cache_dir = os.path.join(data_dir, "cache")
79
+ global_logs_dir = logs_dir = os.path.join(data_dir, "logs")
80
  models_dir = os.path.join(data_dir, "models")
81
  for dir_name in (cache_dir, logs_dir, models_dir):
82
  Path(dir_name).mkdir(parents=True, exist_ok=True)
 
84
  action_list = (
85
  list("()+-*/=") + ["sq", "sqrt"] + [f"c{c}" for c in range(n_channels)]
86
  )
87
+ env = IndexRLEnv(action_list, max_exp_len)
88
  agent, optimizer = create_model(len(action_list))
89
  seen_path = os.path.join(cache_dir, "seen.pkl") if cache_dir else ""
90
  env.save_seen(seen_path)
 
103
  1,
104
  logs_dir,
105
  seen_path,
106
+ tree_prefix=f"tree_{int(i)}",
107
  n_iters=1000,
108
  )
109
  print(
 
123
  with open(f"{cache_dir}/data_buffer_{i_str}.pkl", "wb") as fp:
124
  pickle.dump(data_buffer, fp)
125
 
126
+ tree = get_tree()
 
127
 
128
  top_5 = data_buffer.get_top_n(5)
129
  top_5_str = "\n".join(
 
135
  )
136
  )
137
 
138
+ yield tree, top_5_str, gr.Slider.update(
139
+ value=i, maximum=i, interactive=True
140
+ ), gr.Slider.update(value=1, maximum=len(data[-1][0]), interactive=True)
141
 
142
 
143
  with gr.Blocks(title="IndexRL") as demo:
 
145
  meta_data_df = pd.read_csv(meta_data_file)
146
 
147
  with gr.Tab("Find Expressions"):
148
+ with gr.Row():
149
+ with gr.Column():
150
+ select_dataset = gr.Dropdown(
151
+ label="Select Dataset",
152
+ choices=meta_data_df["Name"].to_list(),
153
+ )
154
+ find_exp_btn = gr.Button("Find Expressions", variant="primary")
155
+ stop_btn = gr.Button("Stop", variant="stop")
156
+ best_exps = gr.Textbox(label="Best Expressions", interactive=False)
157
+
158
+ with gr.Column():
159
+ select_exp = gr.Slider(
160
+ value=1, label="Iteration", interactive=False, minimum=1, step=1
161
+ )
162
+ select_tree = gr.Slider(
163
+ value=1, label="Tree Number", interactive=False, minimum=1, step=1
164
+ )
165
+ out_exp_tree = gr.Textbox(
166
+ label="Latest Expression Tree", interactive=False
167
+ )
168
 
169
  with gr.Tab("Datasets"):
170
  dataset_upload = gr.File(label="Upload Data ZIP file")
 
174
  dataset_table = gr.Dataframe(meta_data_df, label="Dataset Table")
175
 
176
  find_exp_event = find_exp_btn.click(
177
+ find_expression,
178
+ inputs=[select_dataset],
179
+ outputs=[out_exp_tree, best_exps, select_exp, select_tree],
180
  )
181
  stop_btn.click(fn=None, inputs=None, outputs=None, cancels=[find_exp_event])
182
+ select_exp.change(
183
+ fn=lambda x, y: change_expression(x, y),
184
+ inputs=[select_exp, select_tree],
185
+ outputs=[out_exp_tree, select_tree],
186
+ )
187
+ select_tree.change(
188
+ fn=lambda x, y: get_tree(x, y),
189
+ inputs=[select_exp, select_tree],
190
+ outputs=out_exp_tree,
191
+ )
192
 
193
  dataset_upload.upload(
194
  lambda x: ".".join(os.path.basename(x.orig_name).split(".")[:-1]),