Update app.py
Browse files
app.py
CHANGED
@@ -1,147 +1,88 @@
|
|
1 |
-
import gradio as gr
|
2 |
import pandas as pd
|
3 |
-
import
|
4 |
-
import
|
5 |
-
import
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
def
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
def sales_growth(df, forecasted_series):
|
30 |
-
return forecasted_series.diff() # Calcular el crecimiento de ventas
|
31 |
-
|
32 |
-
def merge_forecast_data(actual, predicted, future):
|
33 |
-
return pd.DataFrame({
|
34 |
-
"Actual Sales": actual,
|
35 |
-
"Predicted Sales": predicted,
|
36 |
-
"Forecasted Future Sales": future
|
37 |
-
})
|
38 |
-
|
39 |
-
# Función para mostrar una alerta si el archivo no es CSV o si excede el tamaño
|
40 |
-
def check_file(uploaded_file):
|
41 |
if uploaded_file is None:
|
42 |
-
return
|
43 |
|
44 |
-
|
45 |
-
|
46 |
-
|
|
|
|
|
|
|
47 |
|
48 |
-
|
49 |
-
|
50 |
-
if
|
51 |
-
return
|
52 |
|
53 |
-
|
|
|
54 |
|
55 |
-
|
56 |
-
def upload_and_forecast(uploaded_file, period):
|
57 |
-
# Verificar si el archivo cargado es válido
|
58 |
-
error_message = check_file(uploaded_file)
|
59 |
-
if error_message:
|
60 |
-
return error_message
|
61 |
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
df = date_format(df)
|
66 |
-
series = group_to_three(df)
|
67 |
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
# Crear serie con los valores pronosticados
|
73 |
-
forecasted_series = pd.Series(forecasted_values)
|
74 |
-
forecasted_series.index = pd.date_range(df['Date'].iloc[-1], periods=forecast_period, freq='3D')
|
75 |
-
|
76 |
-
# Calcular el crecimiento de las ventas
|
77 |
-
future_sales_growth = sales_growth(df, forecasted_series)
|
78 |
-
|
79 |
-
# Combinar los datos para graficar
|
80 |
-
merged_data = merge_forecast_data(df['Sales'], series, forecasted_series)
|
81 |
-
|
82 |
-
# Crear gráficos
|
83 |
-
fig_compare = go.Figure()
|
84 |
-
fig_compare.add_trace(go.Scatter(x=merged_data[merged_data.columns[0]], y=merged_data['Actual Sales'], mode='lines', name='Actual Sales'))
|
85 |
-
fig_compare.add_trace(go.Scatter(x=merged_data[merged_data.columns[0]], y=merged_data['Predicted Sales'], mode='lines', name='Predicted Sales', line=dict(color='#006400')))
|
86 |
-
fig_compare.update_layout(title='📊 Historical Sales Data', xaxis_title='Date', yaxis_title='Sales')
|
87 |
-
|
88 |
-
fig_forecast = go.Figure()
|
89 |
-
fig_forecast.add_trace(go.Scatter(x=merged_data[merged_data.columns[0]], y=merged_data['Actual Sales'], mode='lines', name='Actual Sales'))
|
90 |
-
fig_forecast.add_trace(go.Scatter(x=merged_data[merged_data.columns[0]], y=forecasted_series, mode='lines', name='Forecasted Sales'))
|
91 |
-
fig_forecast.update_layout(title='🔮 Forecasted Sales Data', xaxis_title='Date', yaxis_title='Sales')
|
92 |
-
|
93 |
-
return fig_compare, fig_forecast, future_sales_growth
|
94 |
-
|
95 |
-
# Interfaz de Gradio
|
96 |
-
def create_sidebar():
|
97 |
-
with gr.Column():
|
98 |
-
# Personalización del componente de carga de archivos
|
99 |
-
gr.Markdown("### 📂 Upload your sales data (CSV)")
|
100 |
-
uploaded_file = gr.File(
|
101 |
-
label="Choose your file",
|
102 |
-
elem_id="file-uploader",
|
103 |
-
type="filepath", # Cambiado a 'filepath' para que retorne la ruta del archivo
|
104 |
-
file_count="single", # Permite solo un archivo a la vez
|
105 |
-
file_types=[".csv"], # Limita solo a archivos CSV
|
106 |
-
interactive=True, # Hacer interactivo el componente para arrastrar y soltar
|
107 |
-
)
|
108 |
-
|
109 |
-
# Botón para cargar el periodo de pronóstico
|
110 |
-
gr.Markdown("### ⏳ Forecast Period (Days)")
|
111 |
-
period = gr.Slider(minimum=30, maximum=90, step=1, label="Forecast period (in days)")
|
112 |
-
|
113 |
-
# Ruta del archivo de ejemplo
|
114 |
-
sample_file_path = "sample_data.csv" # Ruta del archivo de ejemplo
|
115 |
-
# Verifica si el archivo existe, de lo contrario lo crea
|
116 |
-
if not os.path.exists(sample_file_path):
|
117 |
-
sample_data = pd.DataFrame({
|
118 |
-
"Date": ["2023-01-01", "2023-01-02", "2023-01-03"],
|
119 |
-
"Sales": [100, 200, 300]
|
120 |
-
})
|
121 |
-
sample_data.to_csv(sample_file_path, index=False) # Crea el archivo de ejemplo si no existe
|
122 |
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
146 |
|
147 |
-
|
|
|
|
|
|
|
|
1 |
import pandas as pd
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
import joblib
|
4 |
+
import gradio as gr
|
5 |
+
from dateutil.relativedelta import relativedelta
|
6 |
+
import calendar
|
7 |
+
|
8 |
+
def load_model():
|
9 |
+
try:
|
10 |
+
model = joblib.load('arima_sales_model.pkl')
|
11 |
+
return model, None
|
12 |
+
except Exception as e:
|
13 |
+
return None, f"Failed to load model: {str(e)}"
|
14 |
+
|
15 |
+
def parse_date(date_str):
|
16 |
+
"""Parse the custom date format 'Month-Year'."""
|
17 |
+
try:
|
18 |
+
date = pd.to_datetime(date_str, format="%B-%Y")
|
19 |
+
_, last_day = calendar.monthrange(date.year, date.month)
|
20 |
+
start_date = date.replace(day=1)
|
21 |
+
end_date = date.replace(day=last_day)
|
22 |
+
return start_date, end_date, None
|
23 |
+
except ValueError:
|
24 |
+
return None, None, "Date format should be 'Month-Year', e.g., 'January-2024'."
|
25 |
+
|
26 |
+
def forecast_sales(uploaded_file, start_date_str, end_date_str):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
if uploaded_file is None:
|
28 |
+
return "No file uploaded.", None, "Please upload a file."
|
29 |
|
30 |
+
try:
|
31 |
+
df = pd.read_csv(uploaded_file)
|
32 |
+
if 'Date' not in df.columns or 'Sale' not in df.columns:
|
33 |
+
return None, "The uploaded file must contain 'Date' and 'Sale' columns.", "File does not have required columns."
|
34 |
+
except Exception as e:
|
35 |
+
return None, f"Failed to read the uploaded CSV file: {str(e)}", "Error reading file."
|
36 |
|
37 |
+
start_date, _, error = parse_date(start_date_str)
|
38 |
+
_, end_date, error_end = parse_date(end_date_str)
|
39 |
+
if error or error_end:
|
40 |
+
return None, error or error_end, "Invalid date format."
|
41 |
|
42 |
+
df['Date'] = pd.to_datetime(df['Date'])
|
43 |
+
df = df.rename(columns={'Date': 'ds', 'Sale': 'y'})
|
44 |
|
45 |
+
df_filtered = df[(df['ds'] >= start_date) & (df['ds'] <= end_date)]
|
|
|
|
|
|
|
|
|
|
|
46 |
|
47 |
+
arima_model, error = load_model()
|
48 |
+
if arima_model is None:
|
49 |
+
return None, error, "Failed to load ARIMA model."
|
|
|
|
|
50 |
|
51 |
+
try:
|
52 |
+
forecast = arima_model.get_forecast(steps=60)
|
53 |
+
forecast_index = pd.date_range(start=end_date, periods=61, freq='D')[1:]
|
54 |
+
forecast_df = pd.DataFrame({'Date': forecast_index, 'Sales Forecast': forecast.predicted_mean})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
+
fig, ax = plt.subplots(figsize=(10, 6))
|
57 |
+
ax.plot(df_filtered['ds'], df_filtered['y'], label='Actual Sales', color='blue')
|
58 |
+
ax.plot(forecast_df['Date'], forecast_df['Sales Forecast'], label='Sales Forecast', color='red', linestyle='--')
|
59 |
+
ax.set_xlabel('Date')
|
60 |
+
ax.set_ylabel('Sales')
|
61 |
+
ax.set_title('Sales Forecasting with ARIMA')
|
62 |
+
ax.legend()
|
63 |
+
return fig, "File loaded and processed successfully."
|
64 |
+
except Exception as e:
|
65 |
+
return None, f"Failed to generate plot: {str(e)}", "Plotting failed."
|
66 |
+
|
67 |
+
def setup_interface():
|
68 |
+
with gr.Blocks() as demo:
|
69 |
+
gr.Markdown("## MLCast v1.1 - Intelligent Sales Forecasting System")
|
70 |
+
with gr.Row():
|
71 |
+
with gr.Column(scale=1):
|
72 |
+
file_input = gr.File(label="Upload your store data")
|
73 |
+
start_date_input = gr.Textbox(label="Start Date", placeholder="January-2024")
|
74 |
+
end_date_input = gr.Textbox(label="End Date", placeholder="December-2024")
|
75 |
+
forecast_button = gr.Button("Forecast Sales")
|
76 |
+
with gr.Column(scale=2):
|
77 |
+
output_plot = gr.Plot()
|
78 |
+
output_message = gr.Textbox(label="Notifications", visible=True, lines=2)
|
79 |
+
forecast_button.click(
|
80 |
+
forecast_sales,
|
81 |
+
inputs=[file_input, start_date_input, end_date_input],
|
82 |
+
outputs=[output_plot, output_message]
|
83 |
+
)
|
84 |
+
return demo
|
85 |
|
86 |
+
if __name__ == "__main__":
|
87 |
+
interface = setup_interface()
|
88 |
+
interface.launch()
|