Spaces:
Build error
Build error
applying gr.State
Browse files
app.py
CHANGED
|
@@ -239,7 +239,7 @@ class Experiment(Component):
|
|
| 239 |
idx = [metric.__class__.__name__ for metric in metric_info[0]].index(metric_name)
|
| 240 |
return metric_info[1][idx]
|
| 241 |
|
| 242 |
-
def generate_record(self, data_id, metric_names):
|
| 243 |
record = {}
|
| 244 |
_base = self.experiment.run_batch([data_id], 0, 0, 0)
|
| 245 |
record['data_id'] = data_id
|
|
@@ -252,7 +252,7 @@ class Experiment(Component):
|
|
| 252 |
metrics_ids = [self.get_metric_id_by_name(metric_nm) for metric_nm in metric_names]
|
| 253 |
|
| 254 |
cnt = 0
|
| 255 |
-
for info in
|
| 256 |
if info['checked']:
|
| 257 |
base = self.experiment.run_batch([data_id], info['id'], info['pp_id'], 0)
|
| 258 |
record['explanations'].append({
|
|
@@ -334,9 +334,9 @@ class Experiment(Component):
|
|
| 334 |
plot = gr.Image(value=None, label="Blank", visible=False)
|
| 335 |
plots.append(plot)
|
| 336 |
|
| 337 |
-
def show_plots():
|
| 338 |
_plots = [gr.Textbox(label="Prediction result", visible=False)]
|
| 339 |
-
num_plots = sum([1 for info in
|
| 340 |
n_rows = num_plots // PLOT_PER_LINE
|
| 341 |
n_rows = n_rows + 1 if num_plots % PLOT_PER_LINE != 0 else n_rows
|
| 342 |
_plots += [gr.Image(value=None, label="Blank", visible=True)] * (n_rows * PLOT_PER_LINE)
|
|
@@ -344,7 +344,7 @@ class Experiment(Component):
|
|
| 344 |
return _plots
|
| 345 |
|
| 346 |
@spaces.GPU
|
| 347 |
-
def render_plots(data_id, *metric_inputs):
|
| 348 |
# Clear Cache Files
|
| 349 |
# print(f"GPU Check: {torch.cuda.is_available()}")
|
| 350 |
# print("Which GPU: ", torch.cuda.current_device())
|
|
@@ -360,12 +360,15 @@ class Experiment(Component):
|
|
| 360 |
if metric:
|
| 361 |
metric_input += metric
|
| 362 |
|
| 363 |
-
record = self.generate_record(data_id, metric_input)
|
| 364 |
|
| 365 |
pred = self.get_prediction(record)
|
| 366 |
plots = [gr.Textbox(label="Prediction result", value=pred, visible=True)]
|
| 367 |
|
| 368 |
-
|
|
|
|
|
|
|
|
|
|
| 369 |
n_rows = num_plots // PLOT_PER_LINE
|
| 370 |
n_rows = n_rows + 1 if num_plots % PLOT_PER_LINE != 0 else n_rows
|
| 371 |
|
|
@@ -383,8 +386,8 @@ class Experiment(Component):
|
|
| 383 |
|
| 384 |
return plots
|
| 385 |
|
| 386 |
-
bttn.click(show_plots, outputs=plots)
|
| 387 |
-
bttn.click(render_plots, inputs=[data_id] + metric_inputs, outputs=plots)
|
| 388 |
|
| 389 |
|
| 390 |
|
|
@@ -397,30 +400,33 @@ class ExplainerCheckboxGroup(Component):
|
|
| 397 |
self.gallery = gallery
|
| 398 |
explainers, exp_ids = self.experiment.manager.get_explainers()
|
| 399 |
|
| 400 |
-
|
| 401 |
for exp, exp_id in zip(explainers, exp_ids):
|
| 402 |
exp_nm = exp.__class__.__name__
|
| 403 |
if exp_nm in DEFAULT_EXPLAINER:
|
| 404 |
checked = True
|
| 405 |
else:
|
| 406 |
checked = False
|
| 407 |
-
|
|
|
|
|
|
|
| 408 |
|
| 409 |
-
def update_check(self, exp_id, val=None):
|
| 410 |
-
for info in
|
| 411 |
if info['id'] == exp_id:
|
| 412 |
if val is not None:
|
| 413 |
info['checked'] = val
|
| 414 |
else:
|
| 415 |
info['checked'] = not info['checked']
|
|
|
|
| 416 |
|
| 417 |
-
def insert_check(self, exp_nm, exp_id, pp_id):
|
| 418 |
-
if exp_id in [info['id'] for info in
|
| 419 |
return
|
|
|
|
|
|
|
| 420 |
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
def update_gallery_change(self):
|
| 424 |
checkboxes = []
|
| 425 |
bttns = []
|
| 426 |
for exp in self.explainer_objs:
|
|
@@ -431,10 +437,10 @@ class ExplainerCheckboxGroup(Component):
|
|
| 431 |
|
| 432 |
for exp in self.explainer_objs:
|
| 433 |
val = exp.explainer_name in DEFAULT_EXPLAINER
|
| 434 |
-
self.update_check(exp.default_exp_id, val)
|
| 435 |
if hasattr(exp, "optimal_exp_id"):
|
| 436 |
-
self.update_check(exp.optimal_exp_id, False)
|
| 437 |
-
return checkboxes + bttns
|
| 438 |
|
| 439 |
def get_checkboxes(self):
|
| 440 |
checkboxes = []
|
|
@@ -447,11 +453,10 @@ class ExplainerCheckboxGroup(Component):
|
|
| 447 |
|
| 448 |
def show(self):
|
| 449 |
cnt = 0
|
| 450 |
-
sorted_info = sorted(self.info, key=lambda x: (x['nm'] not in DEFAULT_EXPLAINER, x['nm']))
|
| 451 |
with gr.Accordion("Explainers", open=True):
|
| 452 |
while cnt * PLOT_PER_LINE < len(self.explainer_names):
|
| 453 |
with gr.Row():
|
| 454 |
-
for info in
|
| 455 |
explainer_obj = ExplainerCheckbox(info['nm'], self, self.experiment, self.gallery)
|
| 456 |
self.explainer_objs.append(explainer_obj)
|
| 457 |
explainer_obj.show()
|
|
@@ -461,7 +466,8 @@ class ExplainerCheckboxGroup(Component):
|
|
| 461 |
bttns = self.get_bttns()
|
| 462 |
self.gallery.gallery_obj.select(
|
| 463 |
fn=self.update_gallery_change,
|
| 464 |
-
|
|
|
|
| 465 |
)
|
| 466 |
|
| 467 |
|
|
@@ -488,28 +494,31 @@ class ExplainerCheckbox(Component):
|
|
| 488 |
def get_str_ppid(self, pp_obj):
|
| 489 |
return pp_obj.pooling_fn.__class__.__name__ + pp_obj.normalization_fn.__class__.__name__
|
| 490 |
|
| 491 |
-
def default_on_select(self, evt: gr.EventData):
|
| 492 |
-
self.groups.update_check(self.default_exp_id, evt._data['value'])
|
|
|
|
| 493 |
|
| 494 |
-
def optimal_on_select(self, evt: gr.EventData):
|
| 495 |
if hasattr(self, "optimal_exp_id"):
|
| 496 |
-
self.groups.update_check(self.optimal_exp_id, evt._data['value'])
|
| 497 |
else:
|
| 498 |
raise ValueError("Optimal explainer id is not found.")
|
|
|
|
| 499 |
|
| 500 |
def show(self):
|
| 501 |
val = self.explainer_name in DEFAULT_EXPLAINER
|
| 502 |
with gr.Accordion(self.explainer_name, open=val):
|
| 503 |
-
checked = next(filter(lambda x: x['nm'] == self.explainer_name, self.groups.
|
| 504 |
self.default_check = gr.Checkbox(label="Default Parameter", value=checked, interactive=True)
|
| 505 |
self.opt_check = gr.Checkbox(label="Optimized Parameter (Not Optimal)", interactive=False)
|
| 506 |
|
| 507 |
-
self.default_check.select(self.default_on_select)
|
| 508 |
-
self.opt_check.select(self.optimal_on_select)
|
| 509 |
|
| 510 |
self.bttn = gr.Button(value="Optimize", size="sm", variant="primary")
|
| 511 |
|
| 512 |
-
|
|
|
|
| 513 |
data_id = self.gallery.selected_index
|
| 514 |
|
| 515 |
opt_output = self.experiment.optimize(
|
|
@@ -527,18 +536,18 @@ class ExplainerCheckbox(Component):
|
|
| 527 |
opt_postprocessor_id = pp_id
|
| 528 |
break
|
| 529 |
|
| 530 |
-
opt_explainer_id = max([x['id'] for x in
|
| 531 |
opt_output.explainer.model = self.experiment.model
|
| 532 |
self.experiment.manager._explainers.append(opt_output.explainer)
|
| 533 |
self.experiment.manager._explainer_ids.append(opt_explainer_id)
|
| 534 |
-
self.groups.insert_check(self.explainer_name, opt_explainer_id, opt_postprocessor_id)
|
| 535 |
self.optimal_exp_id = opt_explainer_id
|
| 536 |
checkbox = gr.update(label="Optimized Parameter (Optimal)", interactive=True)
|
| 537 |
bttn = gr.update(value="Optimized", variant="secondary")
|
| 538 |
|
| 539 |
-
return [checkbox, bttn]
|
| 540 |
|
| 541 |
-
self.bttn.click(optimize, outputs=[self.opt_check, self.bttn], queue=True, concurrency_limit=1)
|
| 542 |
|
| 543 |
|
| 544 |
class ExpRes(Component):
|
|
@@ -692,3 +701,4 @@ experiments['experiment2'] = {
|
|
| 692 |
app = ImageClsApp(experiments)
|
| 693 |
demo = app.launch()
|
| 694 |
demo.launch(favicon_path=f"static/XAI-Top-PnP.svg")
|
|
|
|
|
|
| 239 |
idx = [metric.__class__.__name__ for metric in metric_info[0]].index(metric_name)
|
| 240 |
return metric_info[1][idx]
|
| 241 |
|
| 242 |
+
def generate_record(self, checkbox_group_info, data_id, metric_names):
|
| 243 |
record = {}
|
| 244 |
_base = self.experiment.run_batch([data_id], 0, 0, 0)
|
| 245 |
record['data_id'] = data_id
|
|
|
|
| 252 |
metrics_ids = [self.get_metric_id_by_name(metric_nm) for metric_nm in metric_names]
|
| 253 |
|
| 254 |
cnt = 0
|
| 255 |
+
for info in checkbox_group_info:
|
| 256 |
if info['checked']:
|
| 257 |
base = self.experiment.run_batch([data_id], info['id'], info['pp_id'], 0)
|
| 258 |
record['explanations'].append({
|
|
|
|
| 334 |
plot = gr.Image(value=None, label="Blank", visible=False)
|
| 335 |
plots.append(plot)
|
| 336 |
|
| 337 |
+
def show_plots(checkbox_group_info):
|
| 338 |
_plots = [gr.Textbox(label="Prediction result", visible=False)]
|
| 339 |
+
num_plots = sum([1 for info in checkbox_group_info if info['checked']])
|
| 340 |
n_rows = num_plots // PLOT_PER_LINE
|
| 341 |
n_rows = n_rows + 1 if num_plots % PLOT_PER_LINE != 0 else n_rows
|
| 342 |
_plots += [gr.Image(value=None, label="Blank", visible=True)] * (n_rows * PLOT_PER_LINE)
|
|
|
|
| 344 |
return _plots
|
| 345 |
|
| 346 |
@spaces.GPU
|
| 347 |
+
def render_plots(data_id, checkbox_group_info, *metric_inputs):
|
| 348 |
# Clear Cache Files
|
| 349 |
# print(f"GPU Check: {torch.cuda.is_available()}")
|
| 350 |
# print("Which GPU: ", torch.cuda.current_device())
|
|
|
|
| 360 |
if metric:
|
| 361 |
metric_input += metric
|
| 362 |
|
| 363 |
+
record = self.generate_record(checkbox_group_info, data_id, metric_input)
|
| 364 |
|
| 365 |
pred = self.get_prediction(record)
|
| 366 |
plots = [gr.Textbox(label="Prediction result", value=pred, visible=True)]
|
| 367 |
|
| 368 |
+
# for info in checkbox_group_info:
|
| 369 |
+
# if info['checked']:
|
| 370 |
+
# print(info)
|
| 371 |
+
num_plots = sum([1 for info in checkbox_group_info if info['checked']])
|
| 372 |
n_rows = num_plots // PLOT_PER_LINE
|
| 373 |
n_rows = n_rows + 1 if num_plots % PLOT_PER_LINE != 0 else n_rows
|
| 374 |
|
|
|
|
| 386 |
|
| 387 |
return plots
|
| 388 |
|
| 389 |
+
bttn.click(show_plots, inputs=[self.explainer_checkbox_group.info], outputs=plots)
|
| 390 |
+
bttn.click(render_plots, inputs=[data_id, self.explainer_checkbox_group.info] + metric_inputs, outputs=plots)
|
| 391 |
|
| 392 |
|
| 393 |
|
|
|
|
| 400 |
self.gallery = gallery
|
| 401 |
explainers, exp_ids = self.experiment.manager.get_explainers()
|
| 402 |
|
| 403 |
+
info = []
|
| 404 |
for exp, exp_id in zip(explainers, exp_ids):
|
| 405 |
exp_nm = exp.__class__.__name__
|
| 406 |
if exp_nm in DEFAULT_EXPLAINER:
|
| 407 |
checked = True
|
| 408 |
else:
|
| 409 |
checked = False
|
| 410 |
+
info.append({'nm': exp_nm, 'id': exp_id, 'pp_id' : 0, 'mode': 'default', 'checked': checked})
|
| 411 |
+
self.static_info = sorted(info, key=lambda x: (x['nm'] not in DEFAULT_EXPLAINER, x['nm']))
|
| 412 |
+
self.info = gr.State(info)
|
| 413 |
|
| 414 |
+
def update_check(self, checkbox_group_info, exp_id, val=None):
|
| 415 |
+
for info in checkbox_group_info:
|
| 416 |
if info['id'] == exp_id:
|
| 417 |
if val is not None:
|
| 418 |
info['checked'] = val
|
| 419 |
else:
|
| 420 |
info['checked'] = not info['checked']
|
| 421 |
+
return checkbox_group_info
|
| 422 |
|
| 423 |
+
def insert_check(self, checkbox_group_info, exp_nm, exp_id, pp_id):
|
| 424 |
+
if exp_id in [info['id'] for info in checkbox_group_info]:
|
| 425 |
return
|
| 426 |
+
checkbox_group_info.append({'nm': exp_nm, 'id': exp_id, 'pp_id' : pp_id, 'mode': 'optimal', 'checked': False})
|
| 427 |
+
return checkbox_group_info
|
| 428 |
|
| 429 |
+
def update_gallery_change(self, checkbox_group_info):
|
|
|
|
|
|
|
| 430 |
checkboxes = []
|
| 431 |
bttns = []
|
| 432 |
for exp in self.explainer_objs:
|
|
|
|
| 437 |
|
| 438 |
for exp in self.explainer_objs:
|
| 439 |
val = exp.explainer_name in DEFAULT_EXPLAINER
|
| 440 |
+
checkbox_group_info = self.update_check(checkbox_group_info, exp.default_exp_id, val)
|
| 441 |
if hasattr(exp, "optimal_exp_id"):
|
| 442 |
+
checkbox_group_info = self.update_check(checkbox_group_info, exp.optimal_exp_id, False)
|
| 443 |
+
return checkboxes + bttns + [checkbox_group_info]
|
| 444 |
|
| 445 |
def get_checkboxes(self):
|
| 446 |
checkboxes = []
|
|
|
|
| 453 |
|
| 454 |
def show(self):
|
| 455 |
cnt = 0
|
|
|
|
| 456 |
with gr.Accordion("Explainers", open=True):
|
| 457 |
while cnt * PLOT_PER_LINE < len(self.explainer_names):
|
| 458 |
with gr.Row():
|
| 459 |
+
for info in self.static_info[cnt*PLOT_PER_LINE:(cnt+1)*PLOT_PER_LINE]:
|
| 460 |
explainer_obj = ExplainerCheckbox(info['nm'], self, self.experiment, self.gallery)
|
| 461 |
self.explainer_objs.append(explainer_obj)
|
| 462 |
explainer_obj.show()
|
|
|
|
| 466 |
bttns = self.get_bttns()
|
| 467 |
self.gallery.gallery_obj.select(
|
| 468 |
fn=self.update_gallery_change,
|
| 469 |
+
inputs=self.info,
|
| 470 |
+
outputs=checkboxes + bttns + [self.info],
|
| 471 |
)
|
| 472 |
|
| 473 |
|
|
|
|
| 494 |
def get_str_ppid(self, pp_obj):
|
| 495 |
return pp_obj.pooling_fn.__class__.__name__ + pp_obj.normalization_fn.__class__.__name__
|
| 496 |
|
| 497 |
+
def default_on_select(self, evt: gr.EventData, checkbox_group_info):
|
| 498 |
+
checkbox_group_info = self.groups.update_check(checkbox_group_info, self.default_exp_id, evt._data['value'])
|
| 499 |
+
return checkbox_group_info
|
| 500 |
|
| 501 |
+
def optimal_on_select(self, evt: gr.EventData, checkbox_group_info):
|
| 502 |
if hasattr(self, "optimal_exp_id"):
|
| 503 |
+
checkbox_group_info = self.groups.update_check(checkbox_group_info, self.optimal_exp_id, evt._data['value'])
|
| 504 |
else:
|
| 505 |
raise ValueError("Optimal explainer id is not found.")
|
| 506 |
+
return checkbox_group_info
|
| 507 |
|
| 508 |
def show(self):
|
| 509 |
val = self.explainer_name in DEFAULT_EXPLAINER
|
| 510 |
with gr.Accordion(self.explainer_name, open=val):
|
| 511 |
+
checked = next(filter(lambda x: x['nm'] == self.explainer_name, self.groups.static_info))['checked']
|
| 512 |
self.default_check = gr.Checkbox(label="Default Parameter", value=checked, interactive=True)
|
| 513 |
self.opt_check = gr.Checkbox(label="Optimized Parameter (Not Optimal)", interactive=False)
|
| 514 |
|
| 515 |
+
self.default_check.select(self.default_on_select, self.groups.info, self.groups.info)
|
| 516 |
+
self.opt_check.select(self.optimal_on_select, self.groups.info, self.groups.info)
|
| 517 |
|
| 518 |
self.bttn = gr.Button(value="Optimize", size="sm", variant="primary")
|
| 519 |
|
| 520 |
+
@spaces.GPU
|
| 521 |
+
def optimize(checkbox_group_info):
|
| 522 |
data_id = self.gallery.selected_index
|
| 523 |
|
| 524 |
opt_output = self.experiment.optimize(
|
|
|
|
| 536 |
opt_postprocessor_id = pp_id
|
| 537 |
break
|
| 538 |
|
| 539 |
+
opt_explainer_id = max([x['id'] for x in checkbox_group_info]) + 1
|
| 540 |
opt_output.explainer.model = self.experiment.model
|
| 541 |
self.experiment.manager._explainers.append(opt_output.explainer)
|
| 542 |
self.experiment.manager._explainer_ids.append(opt_explainer_id)
|
| 543 |
+
self.groups.insert_check(checkbox_group_info, self.explainer_name, opt_explainer_id, opt_postprocessor_id)
|
| 544 |
self.optimal_exp_id = opt_explainer_id
|
| 545 |
checkbox = gr.update(label="Optimized Parameter (Optimal)", interactive=True)
|
| 546 |
bttn = gr.update(value="Optimized", variant="secondary")
|
| 547 |
|
| 548 |
+
return [checkbox_group_info, checkbox, bttn]
|
| 549 |
|
| 550 |
+
self.bttn.click(optimize, inputs=[self.groups.info], outputs=[self.groups.info, self.opt_check, self.bttn], queue=True, concurrency_limit=1)
|
| 551 |
|
| 552 |
|
| 553 |
class ExpRes(Component):
|
|
|
|
| 701 |
app = ImageClsApp(experiments)
|
| 702 |
demo = app.launch()
|
| 703 |
demo.launch(favicon_path=f"static/XAI-Top-PnP.svg")
|
| 704 |
+
# demo.launch(favicon_path=f"static/XAI-Top-PnP.svg", share=True)
|