|
import matplotlib.pyplot as plt |
|
import pandas as pd |
|
import os |
|
import logging |
|
import time |
|
|
|
class ChartGenerator:
    """Render line or bar charts from a pandas DataFrame and save them as a PNG.

    Charts are written to ``static/images/chart.png`` under the parent of this
    module's directory; ``generate_chart`` returns that location as the
    relative path ``static/images/chart.png``.
    """

    def __init__(self, data=None):
        """Store the DataFrame to plot.

        Args:
            data: DataFrame to plot. When ``None``, ``data/sample_data.xlsx``
                (resolved under the parent of this module's directory) is
                loaded with ``pd.read_excel``.
        """
        logging.info("Initializing ChartGenerator")
        if data is None:
            base_dir = os.path.dirname(os.path.dirname(__file__))
            data = pd.read_excel(os.path.join(base_dir, 'data', 'sample_data.xlsx'))
        self.data = data

    def generate_chart(self, plot_args):
        """Draw the chart described by *plot_args* and save it as a PNG.

        Args:
            plot_args: Mapping with keys
                - ``'x'`` (required): column plotted on the x axis.
                - ``'y'`` (required): a column name or an iterable of column
                  names; each becomes one series.
                - ``'chart_type'`` (optional): ``'bar'`` draws bars; any other
                  value (default ``'line'``) draws lines with circle markers.
                - ``'color'`` (optional): color passed to every series.

        Returns:
            The relative path ``static/images/chart.png``.

        Raises:
            KeyError: if ``'x'`` or ``'y'`` is missing from *plot_args*.
            ValueError: if any requested column is absent from ``self.data``.
            FileNotFoundError: if the PNG does not exist after saving.
        """
        start_time = time.perf_counter()
        logging.info(f"Generating chart with arguments: {plot_args}")

        x_col = plot_args['x']
        y_cols = plot_args['y']
        # A bare string would otherwise be iterated character by character.
        y_cols = [y_cols] if isinstance(y_cols, str) else list(y_cols)
        chart_type = plot_args.get('chart_type', 'line')
        color = plot_args.get('color', None)

        self._validate_columns(x_col, y_cols)

        chart_filename = 'chart.png'

        # Discard any pyplot figures left over from earlier calls.
        plt.close('all')
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            self._draw(ax, x_col, y_cols, chart_type, color)
            full_path = self._save(fig, chart_filename)
        finally:
            # Release the figure even when drawing or saving raises.
            plt.close(fig)

        if not os.path.exists(full_path):
            logging.error(f"Failed to create chart file at {full_path}")
            raise FileNotFoundError(f"Chart file was not created at {full_path}")

        file_size = os.path.getsize(full_path)
        elapsed = time.perf_counter() - start_time
        logging.info(
            f"Chart generated and saved to {full_path} "
            f"(size: {file_size} bytes) in {elapsed:.2f}s"
        )
        return os.path.join('static', 'images', chart_filename)

    def _validate_columns(self, x_col, y_cols):
        """Raise ValueError listing every requested column missing from ``self.data``."""
        columns = self.data.columns
        missing_cols = [col for col in [x_col, *y_cols] if col not in columns]
        if missing_cols:
            logging.error(f"Missing columns in data: {missing_cols}")
            logging.info(f"Available columns: {list(columns)}")
            raise ValueError(f"Missing columns in data: {missing_cols}")

    def _draw(self, ax, x_col, y_cols, chart_type, color):
        """Plot each y column against *x_col* on *ax*, then add labels, legend and grid."""
        x_values = self.data[x_col]
        for y in y_cols:
            if chart_type == 'bar':
                ax.bar(x_values, self.data[y], label=y, color=color)
            else:
                ax.plot(x_values, self.data[y], label=y, color=color, marker='o')

        ax.set_xlabel(x_col)
        ax.set_ylabel('Value')
        ax.set_title(f'{chart_type.title()} Chart')
        # legend() warns when no labelled artists exist, so only add it for real series.
        if y_cols:
            ax.legend()
        ax.grid(True, alpha=0.3)

        # Rotate x tick labels when there are more than 5 rows, to limit overlap.
        if len(x_values) > 5:
            ax.tick_params(axis='x', labelrotation=45)

    def _save(self, fig, chart_filename):
        """Write *fig* to ``static/images/<chart_filename>`` and return its absolute path."""
        output_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'static', 'images')
        if not os.path.isdir(output_dir):
            # exist_ok avoids failing if the directory appears between check and create.
            os.makedirs(output_dir, exist_ok=True)
            logging.info(f"Created output directory: {output_dir}")

        full_path = os.path.join(output_dir, chart_filename)
        # Remove the previous chart so the caller's existence check can only
        # pass if this call actually wrote the file.
        if os.path.exists(full_path):
            os.remove(full_path)
            logging.info(f"Removed existing chart file: {full_path}")

        fig.savefig(full_path, dpi=300, bbox_inches='tight', facecolor='white')
        return full_path