Spaces:
Running
Running
| """Plotting utilities.""" | |
from typing import Tuple

import numpy as np
from bokeh.embed import components
from bokeh.layouts import column
from bokeh.models import Column, CustomJS, Slider
from bokeh.plotting import ColumnDataSource, Figure, figure
def barplot(
    attended: np.ndarray,
    weights: np.ndarray,
    *,
    initial_k: int = 4,
) -> Column:
    """
    Bokeh barplot showing the top k attention weights.

    Entities are ranked by descending weight before slicing, so the bars are
    the k largest weights regardless of input order (ties keep their input
    order, so already-ranked input is shown unchanged). k is interactively
    changeable via a slider.

    Args:
        attended (np.ndarray): Names of the attended entities. Each name is
            converted to ``str`` because categorical axis factors must be
            strings.
        weights (np.ndarray): Attention weights, aligned with ``attended``.
        initial_k (int): Number of bars shown before the slider is moved.
            Clamped to ``[1, len(attended)]``. Defaults to 4.

    Returns:
        bokeh.models.Column: The slider stacked above the figure. Can be
        visualized for debugging via bokeh.plotting (i.e. output_file, show)
        or embedded via bokeh.embed.components.

    Raises:
        ValueError: If ``attended`` and ``weights`` have different lengths.
    """
    names = np.asarray(attended)
    values = np.asarray(weights, dtype=float)
    if len(names) != len(values):
        raise ValueError(
            "attended and weights must have the same length, "
            f"got {len(names)} and {len(values)}"
        )

    # Stable descending sort: already-ranked input keeps its exact order,
    # and NaN weights (argsort places them last) never rank as "top".
    order = np.argsort(-values, kind="stable")
    ranked_names = [str(name) for name in names[order]]
    ranked_weights = values[order].tolist()

    n_items = len(ranked_names)
    # The slider starts at 1, so the initial value must lie in [1, end].
    k = max(1, min(initial_k, n_items))

    # Start with exactly the top-k rows so the initial view (bars, x factors
    # and auto-ranged y axis) matches what the slider callback produces.
    source = ColumnDataSource(
        data=dict(attended=ranked_names[:k], weights=ranked_weights[:k]),
    )
    top_k_slider = Slider(
        start=1, end=max(n_items, 1), value=k, step=1, title="k"
    )
    p = figure(
        x_range=ranked_names[:k],  # updated by the slider callback
        # `height` is accepted by Bokeh 2.x and 3.x; `plot_height` is not.
        height=350,
        title="Top k Gene Attention Weights",
        toolbar_location="below",
        tools="pan,wheel_zoom,box_zoom,save,reset",
    )
    p.vbar(x="attended", top="weights", source=source, width=0.9)
    p.xgrid.grid_line_color = None
    p.y_range.start = 0

    # Client-side callback: re-slice the ranked data whenever k changes.
    callback = CustomJS(
        args=dict(
            source=source,
            xrange=p.x_range,
            yrange=p.y_range,
            attended=ranked_names,
            weights=ranked_weights,
            top_k=top_k_slider,
        ),
        code="""
        const k = top_k.value;
        const top_attended = attended.slice(0, k);
        const top_weights = weights.slice(0, k);
        // Update the factors before the data so the renderer never sees
        // bars whose category is missing from the x range.
        xrange.factors = top_attended;
        // Assigning source.data emits the change signal by itself.
        source.data = {attended: top_attended, weights: top_weights};
        // The weights are ranked in descending order, so the first one is
        // the maximum (avoids spreading a large array into Math.max).
        if (top_weights.length > 0) {
            yrange.end = top_weights[0] * 1.05;
        }
        """,
    )
    top_k_slider.js_on_change("value", callback)
    return column(top_k_slider, p)
def embed_barplot(attended: np.ndarray, weights: np.ndarray) -> Tuple[str, str]:
    """Build the interactive top-k barplot and return its embeddable parts.

    The plot is created by ``barplot`` and converted for embedding with
    ``bokeh.embed.components``.

    Args:
        attended (np.ndarray): Names of the attended entities.
        weights (np.ndarray): Attention weights.

    Returns:
        Tuple[str, str]: The script (javascript) and div (html) snippets
        returned by ``bokeh.embed.components``.
    """
    plot_layout = barplot(attended, weights)
    script, div = components(plot_layout)
    return script, div