|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
from statsmodels.tsa.arima.model import ARIMA |
|
import pickle |
|
import gradio as gr |
|
|
|
def load_model(model_path='arima_sales_model.pkl'):
    """Load the pickled ARIMA model from disk.

    Args:
        model_path: Path of the pickle file to read. Defaults to
            ``'arima_sales_model.pkl'``, resolved against the current
            working directory.

    Returns:
        A ``(model, error)`` 2-tuple. On success ``model`` is the unpickled
        object and ``error`` is ``None``; on failure ``model`` is ``None`` and
        ``error`` is a human-readable message. Both paths return a 2-tuple so
        callers can always unpack the result.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load model files produced by a trusted training pipeline; never
    # point this at user-supplied uploads.
    try:
        with open(model_path, 'rb') as f:
            return pickle.load(f), None
    except Exception as e:
        # Broad catch on purpose: missing files, truncated pickles, and
        # missing model classes all surface as different exception types,
        # and the UI only needs a message.
        return None, f"Failed to load model: {str(e)}"
|
|
|
def forecast_sales(uploaded_file, forecast_period=30):
    """Forecast daily sales with the saved ARIMA model and plot the result.

    Args:
        uploaded_file: The value produced by the Gradio ``File`` component.
            This may be a filesystem path, an object exposing its path as
            ``.name`` (older Gradio versions hand over a tempfile wrapper), or
            any file-like object ``pd.read_csv`` accepts. ``None`` when
            nothing was uploaded.
        forecast_period: Number of daily steps to forecast.

    Returns:
        A ``(message, figure)`` tuple. On success ``message`` is ``None`` and
        ``figure`` is a matplotlib ``Figure`` showing the historical series
        and the forecast. On any failure ``message`` describes the problem and
        ``figure`` is ``None``.
    """
    if uploaded_file is None:
        return "No file uploaded.", None

    # Prefer the on-disk path when the upload is a wrapper object; reading the
    # wrapper itself can fail if its handle has already been closed.
    source = getattr(uploaded_file, 'name', uploaded_file)

    try:
        df = pd.read_csv(source)
    except Exception as e:
        return f"Failed to read the uploaded CSV file: {str(e)}", None

    # Also accept a 'Sales' column as an alias for 'Sale'.
    if 'Sale' not in df.columns and 'Sales' in df.columns:
        df = df.rename(columns={'Sales': 'Sale'})

    if 'Date' not in df.columns or 'Sale' not in df.columns:
        return "The uploaded file must contain 'Date' and 'Sale' columns.", None

    if df.empty:
        return "The uploaded file contains no data rows.", None

    try:
        df['Date'] = pd.to_datetime(df['Date'])
        # Sort chronologically so unsorted uploads plot as a line, not a zigzag.
        df = (
            df.rename(columns={'Date': 'ds', 'Sale': 'y'})
            .sort_values('ds')
            .reset_index(drop=True)
        )

        arima_model, error = load_model()
        if arima_model is None:
            return error, None

        # NOTE(review): get_forecast continues from the end of the model's own
        # training sample. The uploaded rows are only used for the plot and to
        # date-stamp the forecast. Confirm this is the intended behavior.
        forecast = arima_model.get_forecast(steps=forecast_period)

        # Daily dates starting the day after the last observed date.
        forecast_index = pd.date_range(
            df['ds'].max(), periods=forecast_period + 1, freq='D'
        )[1:]
        forecast_df = pd.DataFrame({
            'Date': forecast_index,
            # Take the values positionally so any index the forecast may
            # carry does not leak into the frame.
            'Sales Forecast': pd.Series(forecast.predicted_mean).to_numpy(),
        })
    except Exception as e:
        return f"Failed during forecasting: {str(e)}", None

    try:
        # Build the figure without pyplot. pyplot keeps every figure it
        # creates alive in global state, so a long-running server would
        # accumulate one figure per request.
        from matplotlib.figure import Figure

        fig = Figure(figsize=(10, 6))
        ax = fig.subplots()
        ax.plot(df['ds'], df['y'], label='Historical Sales', color='blue')
        ax.plot(
            forecast_df['Date'],
            forecast_df['Sales Forecast'],
            label='Sales Forecast',
            color='red',
            linestyle='--',
        )
        ax.set_xlabel('Date')
        ax.set_ylabel('Sales')
        ax.set_title('Sales Forecasting with ARIMA')
        ax.legend()
    except Exception as e:
        return f"Failed to generate plot: {str(e)}", None

    return None, fig
|
|
|
def setup_interface():
    """Build the Gradio UI for the sales forecaster.

    Returns:
        A ``gr.Blocks`` app containing a file upload, a "Forecast Sales"
        button wired to ``forecast_sales``, and a textbox plus plot for its
        ``(message, figure)`` outputs.
    """
    with gr.Blocks() as demo:
        gr.Markdown("## MLCast v1.1 - Intelligent Sales Forecasting System")

        # The label names the exact column forecast_sales checks for ('Sale').
        file_input = gr.File(label="Upload your store data here (must contain Date and Sale columns)")
        forecast_button = gr.Button("Forecast Sales")
        output_plot = gr.Plot()
        output_text = gr.Textbox()

        # forecast_sales returns (message, figure); outputs follow that order.
        forecast_button.click(forecast_sales, inputs=[file_input], outputs=[output_text, output_plot])

    return demo
|
|
|
if __name__ == "__main__": |
|
interface = setup_interface() |
|
interface.launch() |
|
|