|
import io
import os
import tempfile

import numpy as np
from matplotlib import pyplot as plt
from PIL import Image
from PIL.ImageFile import ImageFile
from smolagents.tools import Tool
|
|
|
|
|
def _plot_line_diagram(
    x_values: list,
    y_values_list: list[list[int | float]],
    labels: list[str] | None = None,
    title: str = 'Bar Diagram',
    xlabel: str = 'X-axis',
    ylabel: str = 'Y-axis'
) -> str:
    """
    Plot a line diagram with one or more y-series and save it as a PNG in a temporary file.

    The temporary file is created with ``delete=False``, so it outlives this call;
    the caller owns the returned path and is responsible for removing the file.

    :param x_values: List of x-values shared by every line.
    :param y_values_list: List of lists containing y-values.
        Each inner list represents a separate line plotted against ``x_values``.
    :param labels: Legend labels, one per line (optional). Lines without a
        matching label are named ``'Line <n>'`` (1-based).
    :param title: Title of the plot (default: 'Bar Diagram' -- a historical default
        kept for backward compatibility; pass a title explicitly for line plots).
    :param xlabel: Label for the X-axis (default: 'X-axis').
    :param ylabel: Label for the Y-axis (default: 'Y-axis').
    :return: Path to the saved PNG image file.
    """
    fig = plt.figure(figsize=(10, 6))
    try:
        for i, y_values in enumerate(y_values_list):
            label = labels[i] if labels and i < len(labels) else f'Line {i+1}'
            plt.plot(x_values, y_values, label=label)

        plt.title(title)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        if y_values_list:
            # With no lines there are no labelled artists; skip the legend to
            # avoid matplotlib's "No artists with labels" warning.
            plt.legend()
        plt.grid(True)

        # The file must survive this call (delete=False), but its handle must be
        # closed before matplotlib writes to the same name: an open handle blocks
        # reopening on Windows and otherwise stays open until garbage collection.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
            image_path = temp_file.name
        fig.savefig(image_path)
    finally:
        # Release the figure even if plotting raised, so pyplot does not
        # accumulate open figures across calls.
        plt.close(fig)

    return image_path
|
|
|
|
|
def _plot_bar_diagram(
    x_values: list,
    y_values_list: list[list[int | float]],
    labels: list[str] | None = None,
    title: str = 'Bar Diagram',
    xlabel: str = 'X-axis',
    ylabel: str = 'Y-axis'
) -> str:
    """
    Plot a grouped bar diagram with one or more y-series and save it as a PNG in a temporary file.

    Series are drawn side by side within each x group; the x tick for a group is
    centred under its bars. The temporary file is created with ``delete=False``,
    so the caller owns the returned path and is responsible for removing it.

    :param x_values: List of x-values; used as tick labels, one group per value.
    :param y_values_list: List of lists containing y-values.
        Each inner list represents a separate set of bars (one bar per x-value).
    :param labels: Legend labels, one per set of bars (optional). Sets without a
        matching label are named ``'Set <n>'`` (1-based).
    :param title: Title of the plot (default: 'Bar Diagram').
    :param xlabel: Label for the X-axis (default: 'X-axis').
    :param ylabel: Label for the Y-axis (default: 'Y-axis').
    :return: Path to the saved PNG image file.
    """
    n_series = len(y_values_list)
    # Groups sit at integer positions. Keep the total group width at most 0.8
    # so groups never touch or overlap; for up to 4 series this is the
    # original fixed 0.2 width.
    bar_width = min(0.2, 0.8 / n_series) if n_series else 0.2
    group_positions = np.arange(len(x_values))

    fig = plt.figure(figsize=(10, 6))
    try:
        for i, y_values in enumerate(y_values_list):
            label = labels[i] if labels and i < len(labels) else f'Set {i + 1}'
            # Series i is shifted right by i bar widths inside each group.
            plt.bar(group_positions + i * bar_width, y_values, width=bar_width, label=label)

        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.title(title)
        # Centre each tick under its group of bars.
        tick_offset = bar_width * max(n_series - 1, 0) / 2
        plt.xticks(
            group_positions + tick_offset,
            x_values,
            rotation=45,
            ha='right'
        )
        if y_values_list:
            # With no bars there are no labelled artists; skip the legend to
            # avoid matplotlib's "No artists with labels" warning.
            plt.legend()
        plt.grid(True, axis='y')
        plt.tight_layout()

        # The file must survive this call (delete=False), but its handle must be
        # closed before matplotlib writes to the same name: an open handle blocks
        # reopening on Windows and otherwise stays open until garbage collection.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
            image_path = temp_file.name
        fig.savefig(image_path)
    finally:
        # Release the figure even if plotting raised, so pyplot does not
        # accumulate open figures across calls.
        plt.close(fig)

    return image_path
|
|
|
|
|
class PlotTool(Tool):
    """
    A tool to plot bar and line diagrams and return the images.
    """

    name = 'plot_bar_line_diagrams'
    description = (
        'Plot a bar or line diagram with one or more y-values and save the image.'
        ' Return the saved image file as `ImageFile`.'
        ' An agent must take this `ImageFile` and display the image.'
    )
    inputs = {
        'plot_type': {
            'type': 'string',
            'description': 'The type of plot. Only two values are valid: `bar` and `line`'
        },
        'x_values': {'type': 'array', 'description': 'List of x-values.'},
        'y_values_list': {
            'type': 'array',
            'description': (
                'A list of lists containing y-values (numbers).'
                ' Each inner list represents a separate line (for `line`) or set of bars (for `bar`)'
                ' and must have the same length as `x_values`.'
                ' The input type is list[list[int | float]]'
            )
        },
        'labels': {
            'type': 'array',
            'nullable': True,
            'description': (
                'A list of labels, one for each inner list of `y_values_list` (optional).'
                ' Defaults to `None`.'
                ' If provided, the length of `labels` should be equal to the length of `y_values_list`.'
            )
        },
        'title': {
            'type': 'string',
            'nullable': True,
            'description': 'Title of the plot (default: "Bar Diagram")'
        },
        'xlabel': {
            'type': 'string',
            'nullable': True,
            'description': 'Label for the X-axis (default: "X-axis")'
        },
        'ylabel': {
            'type': 'string',
            'nullable': True,
            'description': 'Label for the Y-axis (default: "Y-axis").'
        },
    }
    output_type = 'image'

    def __init__(self, **kwargs):
        # Extra keyword arguments are accepted but not forwarded to Tool.
        super().__init__()

    def forward(
        self,
        plot_type: str,
        x_values: list,
        y_values_list: list[list[int | float]],
        labels: list[str] | None = None,
        title: str = 'Bar Diagram',
        xlabel: str = 'X-axis',
        ylabel: str = 'Y-axis'
    ) -> ImageFile:
        """
        Render the requested diagram and return it as an in-memory PIL image.

        :param plot_type: ``'bar'`` or ``'line'``; case and surrounding whitespace are ignored.
        :param x_values: List of x-values.
        :param y_values_list: List of y-series, one per line / set of bars.
        :param labels: Optional legend labels, one per y-series.
        :param title: Plot title; ``None`` falls back to 'Bar Diagram'.
        :param xlabel: X-axis label; ``None`` falls back to 'X-axis'.
        :param ylabel: Y-axis label; ``None`` falls back to 'Y-axis'.
        :return: The rendered image, or ``None`` if ``plot_type`` is not recognised.
        """
        plotters = {'bar': _plot_bar_diagram, 'line': _plot_line_diagram}
        plotter = plotters.get(plot_type.strip().lower()) if isinstance(plot_type, str) else None
        if plotter is None:
            # Unrecognised plot types are reported as None (unchanged behaviour).
            return None

        # These inputs are declared nullable; an explicit None from the agent
        # should mean "use the documented default", not "draw an empty label".
        img_file_name = plotter(
            x_values=x_values,
            y_values_list=y_values_list,
            labels=labels,
            xlabel='X-axis' if xlabel is None else xlabel,
            ylabel='Y-axis' if ylabel is None else ylabel,
            title='Bar Diagram' if title is None else title,
        )

        try:
            with open(img_file_name, 'rb') as in_file:
                # Copy the bytes into memory so the image no longer depends on
                # the temporary file, which is removed below.
                return Image.open(io.BytesIO(in_file.read()))
        finally:
            # The helpers create the PNG with delete=False; remove it so each
            # call does not leave a stray file in the temp directory.
            try:
                os.remove(img_file_name)
            except OSError:
                # Best-effort cleanup: a failed delete must not hide the image.
                pass
|
|