| from typing import List, Tuple | |
| import matplotlib.pyplot as plt | |
| from matplotlib.axes import Axes | |
| import pandas as pd | |
def _drop_na_points(x, y):
    """Return ``(xs, ys)`` lists keeping only the points whose y value is not NA.

    Always returns two lists (possibly empty), so a series that is entirely NA
    produces an empty line instead of an unpacking error.
    """
    xs, ys = [], []
    for x_val, y_val in zip(x, y):
        if pd.notna(y_val):
            xs.append(x_val)
            ys.append(y_val)
    return xs, ys


def _plot_metric_group(ax: Axes, metrics, remove_na: bool) -> int:
    """Draw every metric in ``metrics`` on ``ax``, recursing into nested lists.

    Each tuple entry is ``(x, y, legend)`` or ``(x, y, legend, color)``; any
    non-tuple entry is treated as a nested group drawn on the same ``ax``.
    Returns the number of lines drawn so the caller knows whether a legend
    is needed.
    """
    drawn = 0
    for metric in metrics:
        if not isinstance(metric, tuple):
            # Nested group: all of its metrics share this same axes.
            drawn += _plot_metric_group(ax, metric, remove_na)
            continue
        x, y, legend = metric[0:3]
        color = metric[3] if len(metric) > 3 else 'blue'
        if remove_na:
            x, y = _drop_na_points(x, y)
        ax.plot(x, y, color=color, label=legend)
        drawn += 1
    return drawn


def plot_metrics(
        metrics: List[Tuple[pd.Series, pd.Series, str]] | List[List[Tuple[pd.Series, pd.Series, str]]],
        remove_na: bool = False,
        subAxes: Axes | None = None,
        title: str | None = None,
        xlabel: str | None = None,
        ylabel: str | None = None,
        figsize: Tuple[float, float] = (8, 6)):
    """Plot one or more metric series with matplotlib.

    Each entry of ``metrics`` is either:

    * a tuple ``(x, y, legend)`` or ``(x, y, legend, color)``, drawn as a
      single line (color defaults to ``'blue'``); or
    * a list of such entries (nesting allowed), all drawn on the same axes.

    Args:
        metrics: The entries to plot, as described above.
        remove_na: If True, drop points whose y value is NA before plotting.
        subAxes: Existing axes to draw every entry on. If None, a new figure
            is created with one subplot row per entry of ``metrics``.
        title: Title for ``subAxes`` or, for a new figure, its last (bottom)
            subplot. None leaves the existing title untouched.
        xlabel: X-axis label, applied to the same axes as ``title``.
        ylabel: Y-axis label, applied to the same axes as ``title``.
        figsize: ``(width, height)`` of ONE subplot. The figure height is
            scaled by the number of rows. Ignored when ``subAxes`` is given.
    """
    if subAxes is None:
        # At least one row: plt.subplots rejects 0 rows, and an empty
        # metrics list still gets a (blank) figure carrying the labels.
        n_rows = max(len(metrics), 1)
        _, axes = plt.subplots(
            n_rows, 1,
            figsize=(figsize[0], figsize[1] * n_rows),
            squeeze=False)
        # squeeze=False always returns a 2-D array, so one row and many rows
        # need no special-casing.
        row_axes = list(axes[:, 0])
        # The last subplot is the axes that pyplot's title/xlabel/ylabel
        # targeted here before, so labels keep landing there.
        label_ax = row_axes[-1]
    else:
        row_axes = [subAxes] * len(metrics)
        label_ax = subAxes

    for ax, metric in zip(row_axes, metrics):
        # Build the legend once per entry, and only if something was drawn.
        if _plot_metric_group(ax, [metric], remove_na):
            ax.legend()

    # Label the axes explicitly rather than via pyplot's implicit "current
    # axes", which need not be the axes this function drew on.
    if title is not None:
        label_ax.set_title(title)
    if xlabel is not None:
        label_ax.set_xlabel(xlabel)
    if ylabel is not None:
        label_ax.set_ylabel(ylabel)