import io
import os
import tempfile

import numpy as np
from PIL import Image
from PIL.ImageFile import ImageFile
from matplotlib import pyplot as plt
from smolagents.tools import Tool


def _save_figure_to_temp_png(fig) -> str:
    """
    Save a matplotlib figure to a newly created temporary PNG file.

    The temporary file is created with ``delete=False`` and its handle is closed
    before matplotlib writes to the path, so no file descriptor is leaked and the
    path can be reopened on platforms that forbid opening a file twice.

    :param fig: The matplotlib figure to save.
    :return: Path to the saved image file. The caller owns (and should delete) it.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
        path = temp_file.name

    try:
        fig.savefig(path)
    except Exception:
        # Do not leave an empty/partial file behind when rendering fails.
        try:
            os.remove(path)
        except OSError:
            pass
        raise

    return path


def _plot_line_diagram(
        x_values: list,
        y_values_list: list[list[int | float]],
        labels: list[str] | None = None,
        title: str = 'Line Diagram',
        xlabel: str = 'X-axis',
        ylabel: str = 'Y-axis'
) -> str:
    """
    Plot a line diagram with one or more y-values and save the image to a temporary file.
    Return the path to the saved image file.

    :param x_values: List of x-values.
    :param y_values_list: List of lists containing y-values. Each inner list represents a separate line.
    :param labels: List of labels for each line (optional). Lines without a matching
     label are named 'Line <n>'.
    :param title: Title of the plot (default: 'Line Diagram').
    :param xlabel: Label for the X-axis (default: 'X-axis').
    :param ylabel: Label for the Y-axis (default: 'Y-axis').
    :return: Path to the saved image file.
    """

    fig = plt.figure(figsize=(10, 6))
    try:
        for i, y_values in enumerate(y_values_list):
            label = labels[i] if labels and i < len(labels) else f'Line {i+1}'
            plt.plot(x_values, y_values, label=label)

        plt.title(title)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.legend()
        plt.grid(True)

        return _save_figure_to_temp_png(fig)
    finally:
        # Always release the figure, even when plotting or saving raises.
        plt.close(fig)


def _plot_bar_diagram(
        x_values: list,
        y_values_list: list[list[int | float]],
        labels: list[str] | None = None,
        title: str = 'Bar Diagram',
        xlabel: str = 'X-axis',
        ylabel: str = 'Y-axis'
) -> str:
    """
    Plot a bar diagram with one or more y-values and save the image to a temporary file.
    Return the path to the saved image file.

    :param x_values: List of x-values.
    :param y_values_list: List of lists containing y-values. Each inner list represents a separate
     set of bars.
    :param labels: List of labels for each set of bars (optional). Sets without a matching
     label are named 'Set <n>'.
    :param title: Title of the plot (default: 'Bar Diagram').
    :param xlabel: Label for the X-axis (default: 'X-axis').
    :param ylabel: Label for the Y-axis (default: 'Y-axis').
    :return: Path to the saved image file.
    """

    max_bar_width = 0.2
    max_group_width = 0.8  # x positions are 1.0 apart; keep a gap between groups
    n = len(y_values_list)

    # Use the usual 0.2 width for up to four sets; shrink it for more sets so that
    # one group of bars never runs into the next x position.
    bar_width = min(max_bar_width, max_group_width / max(n, 1))

    # Left-most set sits on the integer positions; each further set is shifted
    # right by one bar width.
    base_positions = np.arange(len(x_values))

    fig = plt.figure(figsize=(10, 6))
    try:
        for i, y_values in enumerate(y_values_list):
            label = labels[i] if labels and i < len(labels) else f'Set {i + 1}'
            plt.bar(base_positions + i * bar_width, y_values, width=bar_width, label=label)

        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.title(title)
        # Centre each tick under its group of bars.
        plt.xticks(
            base_positions + bar_width * (n - 1) / 2,
            x_values,
            rotation=45,
            ha='right'
        )
        plt.legend()
        plt.grid(True, axis='y')
        plt.tight_layout()

        return _save_figure_to_temp_png(fig)
    finally:
        # Always release the figure, even when plotting or saving raises.
        plt.close(fig)


class PlotTool(Tool):
    """
    A tool to plot bar and line diagrams and return the images.
    """

    name = 'plot_bar_line_diagrams'
    description = (
        'Plot a bar or line diagram with one or more y-values and save the image.'
        ' Return the saved image file as `ImageFile`.'
        ' An agent must take this `ImageFile` and display the image.'
    )
    inputs = {
        'plot_type': {
            'type': 'string',
            'description': 'The type of plot. Only two values are valid: `bar` and `line`'
        },
        'x_values': {'type': 'array', 'description': 'List of x-values.'},
        'y_values_list': {
            'type': 'array',
            'description': (
                'A list of lists containing y-values (numbers).'
                ' Each inner list represents a separate set of bars.'
                ' The input type is list[list[int | float]]'
            )
        },
        'labels': {
            'type': 'array',
            'nullable': True,
            'description': (
                'A list of labels for each set of bars or each line (optional). Defaults to `None`.'
                ' If provided, the length of `labels` should be equal to the length of'
                ' `y_values_list` (one label per inner list).'
            )
        },
        'title': {
            'type': 'string',
            'nullable': True,
            'description': 'Title of the plot (default: "Bar Diagram")'
        },
        'xlabel': {
            'type': 'string',
            'nullable': True,
            'description': 'Label for the X-axis (default: "X-axis")'
        },
        'ylabel': {
            'type': 'string',
            'nullable': True,
            'description': 'Label for the Y-axis (default: "Y-axis").'
        },
    }
    output_type = 'image'

    def __init__(self, **kwargs):
        super().__init__()

    def forward(
            self,
            plot_type: str,
            x_values: list,
            y_values_list: list[list[int | float]],
            labels: list[str] | None = None,
            title: str = 'Bar Diagram',
            xlabel: str = 'X-axis',
            ylabel: str = 'Y-axis'
    ) -> ImageFile:
        """
        Render the requested diagram and return it as an in-memory image.

        :param plot_type: Either 'bar' or 'line'.
        :param x_values: List of x-values.
        :param y_values_list: List of lists containing y-values, one inner list per series.
        :param labels: Optional list with one label per series.
        :param title: Title of the plot; `None` falls back to 'Bar Diagram'.
        :param xlabel: Label for the X-axis; `None` falls back to 'X-axis'.
        :param ylabel: Label for the Y-axis; `None` falls back to 'Y-axis'.
        :return: The rendered plot, backed by an in-memory PNG buffer.
        :raises ValueError: If `plot_type` is neither 'bar' nor 'line'.
        """
        plotters = {'bar': _plot_bar_diagram, 'line': _plot_line_diagram}
        plotter = plotters.get(plot_type)
        if plotter is None:
            raise ValueError(
                f'Invalid plot_type {plot_type!r}. Only two values are valid: `bar` and `line`'
            )

        # The inputs are declared nullable, so a caller may pass an explicit None;
        # honour the documented defaults in that case.
        title = 'Bar Diagram' if title is None else title
        xlabel = 'X-axis' if xlabel is None else xlabel
        ylabel = 'Y-axis' if ylabel is None else ylabel

        img_file_name = plotter(
            x_values=x_values, y_values_list=y_values_list, labels=labels,
            xlabel=xlabel, ylabel=ylabel, title=title
        )

        try:
            with open(img_file_name, 'rb') as in_file:
                png_bytes = in_file.read()
        finally:
            # The image is served from memory, so the temporary file is no longer
            # needed; removal is best-effort.
            try:
                os.remove(img_file_name)
            except OSError:
                pass

        return Image.open(io.BytesIO(png_bytes))