bndl committed on
Commit
d7010e9
·
1 Parent(s): 5fd08c5

Upload 2 files

Browse files
Files changed (2) hide show
  1. gradio_active_learning.py +231 -0
  2. gradio_utils.py +66 -0
gradio_active_learning.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sklearn import ensemble
2
+ import gradio as gr
3
+ import pandas as pd
4
+ import os
5
+ import matplotlib.pyplot as plt
6
+ import cv2
7
+
8
+ from train_model_main import prepare_data, train_model
9
+ from sklearn.metrics import mean_absolute_percentage_error, mean_absolute_error
10
+ from utils import scale_numerical, unpickle_file
11
+ import numpy as np
12
+ from gradio_utils import load_theme
13
+ from train_ensemble_models_main import run_ensemble_models_training
14
+ from inference_model_main import predict_from_ensemble_model, get_test_inference
15
+ import preprocess_data_main
16
+
17
+
18
def get_training_data(n_iteration, main_folder, data_name, new_df):
    """
    Concatenates dataframes from the previous iterations with the new dataframe to run the model training

    The new dataframe is also written to <main_folder>/<n_iteration>/<data_name> so that later
    iterations can load it back.

    Args:
        n_iteration: index of the current iteration; the folders "0" to str(n_iteration - 1) are read back
        main_folder: root folder containing one sub-folder per iteration
        data_name: name of the ";"-separated csv file stored in each iteration folder
        new_df: dataframe holding the data of the current iteration

    Returns:
        A dataframe with the rows of new_df first, followed by the rows of every previous iteration,
        with a fresh index

    Raises:
        FileNotFoundError: if one of the previous iteration folders has no data_name file
    """
    df_list = [new_df]
    for i in range(n_iteration):
        previous_path = os.path.join(main_folder, str(i), data_name)
        df_list.append(pd.read_csv(previous_path, sep=";"))
    training_df = pd.concat(df_list, ignore_index=True)

    # Store the new dataframe passed for later runs
    # makedirs also creates main_folder when it is missing and does not fail if the folder already exists
    new_folder = os.path.join(main_folder, str(n_iteration))
    os.makedirs(new_folder, exist_ok=True)
    new_df.to_csv(os.path.join(new_folder, data_name), sep=";", index=False)
    return training_df
34
+
35
+
36
def upload_csv(x):
    """
    Loads the uploaded csv file and exposes its columns as selectable choices

    x is the uploaded file object (the csv is read from its .name path), or None when no file is set.
    Returns the loaded dataframe (None without a file) and a gradio update carrying the column names as choices.
    """
    if x is None:
        return None, gr.update(choices=[])

    print(x)
    csv_path = x.name
    print(csv_path)

    df = pd.read_csv(csv_path, sep=";")
    # Only one column was parsed with ";": read the file again with "," as separator
    if df.shape[1] == 1:
        df = pd.read_csv(csv_path, sep=",")

    print("Input dataframe")
    print(df.shape)
    return df, gr.update(choices=list(df.columns))
48
+
49
+
50
def train_al_model(x, target_cols, n_iteration):
    """
    Runs one active learning iteration: trains the ensemble models on the aggregated data,
    then suggests the next experiments based on the prediction uncertainty

    Args:
        x: the input dataframe
        target_cols: the target columns selected
        n_iteration: number of the previous iteration; it is incremented by one before being used

    Returns:
        A tuple (mape, scatter_plot, df_suggestions, suggestions_path, iteration_update) where
        scatter_plot is the performance plot image loaded from the seed0 folder (None if it cannot be read),
        df_suggestions holds the rows of the inference dataset with the highest uncertainty,
        suggestions_path is the csv file these rows were written to and
        iteration_update is a gradio update carrying the incremented iteration number
    """
    print("Training data")
    print(x.shape)
    print("Target columns")
    print(target_cols)

    print("Iteration number")
    print(n_iteration)
    n_iteration = int(n_iteration)
    n_iteration += 1
    print(n_iteration)

    main_folder = "gradio_models/hardness"
    model_name = "model_hardness.h5"
    ensemble_model_name = f"ensemble_{model_name.split('.')[0]}.pkl"
    iteration_folder = os.path.join(main_folder, str(n_iteration))
    seed_folder = os.path.join(iteration_folder, "seed0")

    # Aggregate the new data with the previous data to improve the model
    print(x.shape)
    new_training_df = get_training_data(n_iteration, main_folder, "training_data.csv", x)
    print(new_training_df.shape)
    print("Training data aggregated")

    # Run the data preprocessing
    preprocessing_fn = preprocess_data_main.alloy_preprocessing
    df_preprocessed = preprocessing_fn(new_training_df)
    print("Preprocessing done")

    print(df_preprocessed.shape)
    print(df_preprocessed)

    columns_numerical = [col for col in df_preprocessed.columns if col not in target_cols]
    # First train the ML models that can compute the uncertainty
    run_ensemble_models_training(
        df_preprocessed,
        columns_numerical,
        target_cols,
        iteration_folder,
        model_name,
        lr=0.01,
        n_models=3,
        save_explainer_single=False,
        save_explainer_ensemble=False,
        data_type="dataframe",
    )
    # Must get as outputs the scatter plot (can be loaded from the folder), and the metrics
    # Difficult since the train/test split is changed for every seed model
    # So for now only computes the inference with one model
    metrics = get_test_inference(
        seed_folder,
        columns_numerical,
        target_cols,
        model_name,
        "X_test_data.pickle",
    )

    # NOTE(review): the value shown to the user is the "mape" metric plus a fixed 0.02 offset;
    # the reason for this offset is not documented here - confirm it is intended
    mape = metrics["mape"] + 0.02

    scatter_plot = cv2.imread(os.path.join(seed_folder, "plot_performance_test.png"))
    # cv2.imread returns the channels in BGR order (or None when the file cannot be read),
    # while the gradio image output displays arrays as RGB
    if scatter_plot is not None:
        scatter_plot = cv2.cvtColor(scatter_plot, cv2.COLOR_BGR2RGB)

    # Second, compute inference and uncertainty on a newly generated dataset
    # For the demo the dataset is preloaded from a specific location
    # For the default pipeline the dataset should be generated according to the original distribution
    df_for_predict = pd.read_csv(os.path.join(main_folder, "inference_data.csv"), sep=";")
    print("DF for predict uncertainty")
    print(df_for_predict.shape)

    df_for_predict_physics = preprocessing_fn(df_for_predict)
    print(df_for_predict_physics.shape)

    df_for_predict_physics.drop(columns=target_cols, inplace=True)
    print(df_for_predict_physics.shape)

    # Reuse the input scaler saved in the seed0 folder, without fitting it again
    scaler_path = os.path.join(seed_folder, "minmax_scaler_inputs.pickle")
    minmax_scaler_inputs = unpickle_file(scaler_path)
    print(scaler_path)
    print(minmax_scaler_inputs)
    df_for_predict_scaled = scale_numerical(
        df_for_predict_physics, minmax_scaler_inputs.feature_names_in_, scaler=minmax_scaler_inputs, fit=False
    )

    predictions, uncertainty = predict_from_ensemble_model(
        os.path.join(iteration_folder, ensemble_model_name),
        df_for_predict_scaled,
        explainer=None,
        uncertainty_type="std",
    )

    # Return top uncertainty suggestions
    # TODO: link to the sampling code
    num_suggestions = 5
    df_for_predict["uncertainty"] = uncertainty
    most_uncertain = df_for_predict.sort_values(by=["uncertainty"], ascending=[False]).iloc[:num_suggestions]
    # The uncertainty and the listed property columns are not part of the suggested experiments
    df_suggestions = most_uncertain.drop(
        columns=[
            "uncertainty",
            "density",
            "young_modulus",
            "configuration_entropy",
            "valence_electron_concentration",
            "electronegativity",
        ]
    )
    suggestions_path = os.path.join(iteration_folder, "suggested_experiments.csv")
    df_suggestions.to_csv(suggestions_path, sep=",", index=False)
    return mape, scatter_plot, df_suggestions, suggestions_path, gr.update(value=n_iteration)
159
+
160
+
161
def create_gradio():
    """
    Builds the gradio interface used to upload new data, retrain the model and get the suggested experiments

    Returns:
        The gr.Blocks application, not launched
    """
    osium_theme, css_styling = load_theme()
    page_title = "Update your model"

    def reset_components():
        # One value per output of the clear button, in the same order as its outputs list.
        # Takes no argument since the clear event is registered without inputs.
        # The target choices are emptied as well, like upload_csv does when no file is set.
        return [None, None, gr.update(choices=[], value=[]), None, None, None, None]

    # NOTE(review): the nesting of the rows and columns below was reconstructed from a source
    # without indentation - confirm it matches the intended page layout
    with gr.Blocks(css=css_styling, title=page_title, theme=osium_theme) as demo:
        gr.Markdown("# <p style='text-align: center;'>Adapt your AI models</p>")
        gr.Markdown("Easily adapt your AI models with your new experimental data")
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Your input files")
                input_file = gr.File(label="Your input files", file_count="single", elem_id="input_files")
                with gr.Row():
                    clear_button = gr.Button("Clear")
                    train_button = gr.Button("Train model", elem_id="submit")
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### Your input csv")
                        input_csv = gr.DataFrame(elem_classes="input-csv")
                    with gr.Column():
                        gr.Markdown("### Choose your target properties")
                        target_columns = gr.CheckboxGroup(choices=[], interactive=True, label="Target alloy properties")

            with gr.Column():
                gr.Markdown("### Your model adaptation")
                output_mape = gr.Number(label="Training results - average percentage error", precision=3)
                output_scatter = gr.Image(label="Predictions vs. ground truth", elem_classes="output-image")
                output_next_experiments = gr.DataFrame(label="Suggested experiments to improve performance")
                # Hidden counter holding the iteration number between two trainings
                num_iteration_hidden = gr.Number(visible=False, value=0, precision=0)
                output_experiments_file = gr.File()

        # Load the csv and refresh the selectable target columns whenever the input file changes
        input_file.change(
            fn=upload_csv,
            inputs=[input_file],
            outputs=[input_csv, target_columns],
            show_progress=True,
        )

        train_button.click(
            fn=train_al_model,
            inputs=[input_csv, target_columns, num_iteration_hidden],
            outputs=[
                output_mape,
                output_scatter,
                output_next_experiments,
                output_experiments_file,
                num_iteration_hidden,
            ],
            show_progress=True,
        )

        clear_button.click(
            fn=reset_components,
            inputs=[],
            outputs=[
                input_file,
                input_csv,
                target_columns,
                output_mape,
                output_scatter,
                output_next_experiments,
                output_experiments_file,
            ],
        )

    return demo
227
+
228
+
229
if __name__ == "__main__":
    # Build the interface and serve it when the file is run as a script
    demo = create_gradio()
    demo.launch()
gradio_utils.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+
4
def add_gradio_component(config_dict, component_key):
    """
    Builds the gradio component described by config_dict[component_key]

    The "comp_type" entry selects the component class; "label", "precision" and "cat_values"
    are read only for the component types that use them.
    Returns the new component, or None (after printing a message) when the type is not supported.
    """
    spec = config_dict[component_key]
    comp_type = spec["comp_type"]

    if comp_type == "Text":
        # The label is also used as placeholder text
        return gr.Text(label=spec["label"], placeholder=spec["label"])
    if comp_type == "Number":
        return gr.Number(label=spec["label"], precision=spec["precision"])
    if comp_type == "Dropdown":
        return gr.Dropdown(label=spec["label"], choices=spec["cat_values"])
    if comp_type == "Image":
        return gr.Image(elem_classes="image-preview")
    if comp_type == "CheckboxGroup":
        return gr.CheckboxGroup(label=spec["label"], choices=spec["cat_values"])
    if comp_type == "Plot":
        return gr.Plot(label=spec["label"], type="matplotlib")
    if comp_type == "Dataframe":
        return gr.Dataframe(wrap=True, type="pandas")

    print(f"Found component type {comp_type} for {component_key}, which is not supported")
    return None
37
+
38
+
39
def load_theme():
    """
    Builds the Osium AI gradio theme together with its css overrides

    Returns the theme object and the css string, in this order.
    """
    css_styling = """#submit {background: #1eccd8}
    #submit:hover {background: #a2f1f6}
    .output-image, .input-image, .image-preview {height: 350px !important}
    .output-plot {height: 250px !important}
    #interpretation {height: 250px !important}"""

    # Custom neutral palette; each shade is annotated with where it shows up in the interface
    neutral_palette = gr.themes.Color(
        c50="#e4f3fa",  # Dataframe background cell content - light mode only
        c100="#e4f3fa",  # Top corner of clear button in light mode + markdown text in dark mode
        c200="#a1c6db",  # Component borders
        c300="#FFFFFF",
        c400="#e4f3fa",  # Footer text
        c500="#0c1538",  # Text of component headers in light mode only
        c600="#a1c6db",  # Top corner of button in dark mode
        c700="#475383",  # Button text in light mode + component borders in dark mode
        c800="#0c1538",  # Markdown text in light mode
        c900="#a1c6db",  # Background of dataframe - dark mode
        c950="#0c1538",  # Background in dark mode only
    )

    # secondary color used for highlight box content when typing in light mode, and download option in dark mode
    # primary color used for login button in dark mode
    osium_theme = gr.themes.Default(primary_hue="cyan", secondary_hue="cyan", neutral_hue=neutral_palette)

    return osium_theme, css_styling