Spaces:
Build error
Build error
| import functools | |
| import traceback | |
| from copy import deepcopy | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from matplotlib.widgets import Button | |
| from omegaconf import OmegaConf | |
| from ..datasets.base_dataset import collate | |
| # from ..eval.export_predictions import load_predictions | |
| from ..models.cache_loader import CacheLoader | |
| from .tools import RadioHideTool | |
| class GlobalFrame: | |
| default_conf = { | |
| "x": "???", | |
| "y": "???", | |
| "diff": False, | |
| "child": {}, | |
| "remove_outliers": False, | |
| } | |
| child_frame = None # MatchFrame | |
| childs = [] | |
| lines = [] | |
| scatters = {} | |
| def __init__( | |
| self, conf, results, loader, predictions, title=None, child_frame=None | |
| ): | |
| self.child_frame = child_frame | |
| if self.child_frame is not None: | |
| # We do NOT merge inside the child frame to keep settings across figs | |
| self.default_conf["child"] = self.child_frame.default_conf | |
| self.conf = OmegaConf.merge(self.default_conf, conf) | |
| self.results = results | |
| self.loader = loader | |
| self.predictions = predictions | |
| self.metrics = set() | |
| for k, v in results.items(): | |
| self.metrics.update(v.keys()) | |
| self.metrics = sorted(list(self.metrics)) | |
| self.conf.x = conf["x"] if conf["x"] else self.metrics[0] | |
| self.conf.y = conf["y"] if conf["y"] else self.metrics[1] | |
| assert self.conf.x in self.metrics | |
| assert self.conf.y in self.metrics | |
| self.names = list(results) | |
| self.fig, self.axes = self.init_frame() | |
| if title is not None: | |
| self.fig.canvas.manager.set_window_title(title) | |
| self.xradios = self.fig.canvas.manager.toolmanager.add_tool( | |
| "x", | |
| RadioHideTool, | |
| options=self.metrics, | |
| callback_fn=self.update_x, | |
| active=self.conf.x, | |
| keymap="x", | |
| ) | |
| self.yradios = self.fig.canvas.manager.toolmanager.add_tool( | |
| "y", | |
| RadioHideTool, | |
| options=self.metrics, | |
| callback_fn=self.update_y, | |
| active=self.conf.y, | |
| keymap="y", | |
| ) | |
| if self.fig.canvas.manager.toolbar is not None: | |
| self.fig.canvas.manager.toolbar.add_tool("x", "navigation") | |
| self.fig.canvas.manager.toolbar.add_tool("y", "navigation") | |
| def init_frame(self): | |
| """initialize frame""" | |
| fig, ax = plt.subplots() | |
| ax.set_title("click on points") | |
| diffb_ax = fig.add_axes([0.01, 0.02, 0.12, 0.06]) | |
| self.diffb = Button(diffb_ax, label="diff_only") | |
| self.diffb.on_clicked(self.diff_clicked) | |
| fig.canvas.mpl_connect("pick_event", self.on_scatter_pick) | |
| fig.canvas.mpl_connect("motion_notify_event", self.hover) | |
| return fig, ax | |
| def draw(self): | |
| """redraw content in frame""" | |
| self.scatters = {} | |
| self.axes.clear() | |
| self.axes.set_xlabel(self.conf.x) | |
| self.axes.set_ylabel(self.conf.y) | |
| refx = 0.0 | |
| refy = 0.0 | |
| x_cat = isinstance(self.results[self.names[0]][self.conf.x][0], (bytes, str)) | |
| y_cat = isinstance(self.results[self.names[0]][self.conf.y][0], (bytes, str)) | |
| if self.conf.diff: | |
| if not x_cat: | |
| refx = np.array(self.results[self.names[0]][self.conf.x]) | |
| if not y_cat: | |
| refy = np.array(self.results[self.names[0]][self.conf.y]) | |
| for name in list(self.results.keys()): | |
| x = np.array(self.results[name][self.conf.x]) | |
| y = np.array(self.results[name][self.conf.y]) | |
| if x_cat and np.char.isdigit(x.astype(str)).all(): | |
| x = x.astype(int) | |
| if y_cat and np.char.isdigit(y.astype(str)).all(): | |
| y = y.astype(int) | |
| x = x if x_cat else x - refx | |
| y = y if y_cat else y - refy | |
| (s,) = self.axes.plot( | |
| x, y, "o", markersize=3, label=name, picker=True, pickradius=5 | |
| ) | |
| self.scatters[name] = s | |
| if x_cat and not y_cat: | |
| xunique, ind, xinv, xbin = np.unique( | |
| x, return_inverse=True, return_counts=True, return_index=True | |
| ) | |
| ybin = np.bincount(xinv, weights=y) | |
| sort_ax = np.argsort(ind) | |
| self.axes.step( | |
| xunique[sort_ax], | |
| (ybin / xbin)[sort_ax], | |
| where="mid", | |
| color=s.get_color(), | |
| ) | |
| if not x_cat: | |
| xavg = np.nan_to_num(x).mean() | |
| self.axes.axvline(xavg, c=s.get_color(), zorder=1, alpha=1.0) | |
| xmed = np.median(x - refx) | |
| self.axes.axvline( | |
| xmed, | |
| c=s.get_color(), | |
| zorder=0, | |
| alpha=0.5, | |
| linestyle="dashed", | |
| visible=False, | |
| ) | |
| if not y_cat: | |
| yavg = np.nan_to_num(y).mean() | |
| self.axes.axhline(yavg, c=s.get_color(), zorder=1, alpha=0.5) | |
| ymed = np.median(y - refy) | |
| self.axes.axhline( | |
| ymed, | |
| c=s.get_color(), | |
| zorder=0, | |
| alpha=0.5, | |
| linestyle="dashed", | |
| visible=False, | |
| ) | |
| if x_cat and x.dtype == object and xunique.shape[0] > 5: | |
| self.axes.set_xticklabels(xunique[sort_ax], rotation=90) | |
| self.axes.legend() | |
| def on_scatter_pick(self, handle): | |
| try: | |
| art = handle.artist | |
| try: | |
| event = handle.mouseevent.button.value | |
| except AttributeError: | |
| return | |
| name = art.get_label() | |
| ind = handle.ind[0] | |
| # draw lines | |
| self.spawn_child(name, ind, event=event) | |
| except Exception: | |
| traceback.print_exc() | |
| exit(0) | |
| def spawn_child(self, model_name, ind, event=None): | |
| [line.remove() for line in self.lines] | |
| self.lines = [] | |
| x_source = self.scatters[model_name].get_xdata()[ind] | |
| y_source = self.scatters[model_name].get_ydata()[ind] | |
| for oname in self.names: | |
| xn = self.scatters[oname].get_xdata()[ind] | |
| yn = self.scatters[oname].get_ydata()[ind] | |
| (ln,) = self.axes.plot([x_source, xn], [y_source, yn], "r") | |
| self.lines.append(ln) | |
| self.fig.canvas.draw_idle() | |
| if self.child_frame is None: | |
| return | |
| data = collate([self.loader.dataset[ind]]) | |
| preds = {} | |
| for name, pfile in self.predictions.items(): | |
| preds[name] = CacheLoader({"path": str(pfile), "add_data_path": False})( | |
| data | |
| ) | |
| summaries_i = { | |
| name: {k: v[ind] for k, v in res.items() if k != "names"} | |
| for name, res in self.results.items() | |
| } | |
| frame = self.child_frame( | |
| self.conf.child, | |
| deepcopy(data), | |
| preds, | |
| title=str(data["name"][0]), | |
| event=event, | |
| summaries=summaries_i, | |
| ) | |
| frame.fig.canvas.mpl_connect( | |
| "key_press_event", | |
| functools.partial( | |
| self.on_childframe_key_event, frame=frame, ind=ind, event=event | |
| ), | |
| ) | |
| self.childs.append(frame) | |
| # if plt.rcParams['backend'] == 'webagg': | |
| # self.fig.canvas.manager_class.refresh_all() | |
| self.childs[-1].fig.show() | |
| def hover(self, event): | |
| if event.inaxes == self.axes: | |
| for _, s in self.scatters.items(): | |
| cont, ind = s.contains(event) | |
| if cont: | |
| ind = ind["ind"][0] | |
| xdata, ydata = s.get_data() | |
| [line.remove() for line in self.lines] | |
| self.lines = [] | |
| for oname in self.names: | |
| xn = self.scatters[oname].get_xdata()[ind] | |
| yn = self.scatters[oname].get_ydata()[ind] | |
| (ln,) = self.axes.plot( | |
| [xdata[ind], xn], | |
| [ydata[ind], yn], | |
| "black", | |
| zorder=0, | |
| alpha=0.5, | |
| ) | |
| self.lines.append(ln) | |
| self.fig.canvas.draw_idle() | |
| break | |
| def diff_clicked(self, args): | |
| self.conf.diff = not self.conf.diff | |
| self.draw() | |
| self.fig.canvas.draw_idle() | |
| def update_x(self, x): | |
| self.conf.x = x | |
| self.draw() | |
| def update_y(self, y): | |
| self.conf.y = y | |
| self.draw() | |
| def on_childframe_key_event(self, key_event, frame, ind, event): | |
| if key_event.key == "delete": | |
| plt.close(frame.fig) | |
| self.childs.remove(frame) | |
| elif key_event.key in ["left", "right", "shift+left", "shift+right"]: | |
| key = key_event.key | |
| if key.startswith("shift+"): | |
| key = key.replace("shift+", "") | |
| else: | |
| plt.close(frame.fig) | |
| self.childs.remove(frame) | |
| new_ind = ind + 1 if key_event.key == "right" else ind - 1 | |
| self.spawn_child( | |
| self.names[0], | |
| new_ind % len(self.loader), | |
| event=event, | |
| ) | |