IvanStudent commited on
Commit
8b82ee3
·
verified ·
1 Parent(s): cfcecf3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -137
app.py CHANGED
@@ -1,147 +1,88 @@
1
- import gradio as gr
2
  import pandas as pd
3
- import numpy as np
4
- import plotly.graph_objects as go
5
- import joblib # Para cargar el modelo ARIMA guardado
6
- import os # Para manejar rutas de archivos
7
-
8
- # Función para cargar el modelo ARIMA guardado
9
- def load_arima_model():
10
- model = joblib.load('arima_sales_model.pkl') # Cargar el modelo ARIMA desde un archivo guardado (.pkl)
11
- return model
12
-
13
- # Cargar el modelo ARIMA al inicio
14
- arima_model = load_arima_model()
15
-
16
- # Funciones de preprocesamiento y cálculo de crecimiento de ventas
17
- def drop(dataframe):
18
- pass # Implementar según sea necesario
19
-
20
- def date_format(dataframe):
21
- pass # Implementar según sea necesario
22
-
23
- def group_to_three(dataframe):
24
- pass # Implementar según sea necesario
25
-
26
- def get_forecast_period(period):
27
- return period # Retornar el periodo de pronóstico
28
-
29
- def sales_growth(df, forecasted_series):
30
- return forecasted_series.diff() # Calcular el crecimiento de ventas
31
-
32
- def merge_forecast_data(actual, predicted, future):
33
- return pd.DataFrame({
34
- "Actual Sales": actual,
35
- "Predicted Sales": predicted,
36
- "Forecasted Future Sales": future
37
- })
38
-
39
- # Función para mostrar una alerta si el archivo no es CSV o si excede el tamaño
40
- def check_file(uploaded_file):
41
  if uploaded_file is None:
42
- return gr.Error("⚠️ No file uploaded. Please upload a CSV file.")
43
 
44
- # Verificar si el archivo es CSV
45
- if not uploaded_file.endswith('.csv'):
46
- return gr.Error("⚠️ Invalid file format. Please upload a CSV file.")
 
 
 
47
 
48
- # Verificar el tamaño del archivo (200MB)
49
- file_size = uploaded_file.size # Verificar el tamaño del archivo
50
- if file_size > 200 * 1024 * 1024: # Limitar a 200MB
51
- return gr.Error("⚠️ File size exceeds the 200MB limit. Please upload a smaller file.")
52
 
53
- return None # No hay error si el archivo es válido
 
54
 
55
- # Función principal para la carga de archivo y la predicción
56
- def upload_and_forecast(uploaded_file, period):
57
- # Verificar si el archivo cargado es válido
58
- error_message = check_file(uploaded_file)
59
- if error_message:
60
- return error_message
61
 
62
- # Leer y procesar el archivo CSV
63
- df = pd.read_csv(uploaded_file) # Leer el archivo CSV
64
- df = drop(df)
65
- df = date_format(df)
66
- series = group_to_three(df)
67
 
68
- # Realizar la predicción con el modelo ARIMA
69
- forecast_period = get_forecast_period(period)
70
- forecasted_values, confint = arima_model.predict(n_periods=forecast_period, return_conf_int=True)
71
-
72
- # Crear serie con los valores pronosticados
73
- forecasted_series = pd.Series(forecasted_values)
74
- forecasted_series.index = pd.date_range(df['Date'].iloc[-1], periods=forecast_period, freq='3D')
75
-
76
- # Calcular el crecimiento de las ventas
77
- future_sales_growth = sales_growth(df, forecasted_series)
78
-
79
- # Combinar los datos para graficar
80
- merged_data = merge_forecast_data(df['Sales'], series, forecasted_series)
81
-
82
- # Crear gráficos
83
- fig_compare = go.Figure()
84
- fig_compare.add_trace(go.Scatter(x=merged_data[merged_data.columns[0]], y=merged_data['Actual Sales'], mode='lines', name='Actual Sales'))
85
- fig_compare.add_trace(go.Scatter(x=merged_data[merged_data.columns[0]], y=merged_data['Predicted Sales'], mode='lines', name='Predicted Sales', line=dict(color='#006400')))
86
- fig_compare.update_layout(title='📊 Historical Sales Data', xaxis_title='Date', yaxis_title='Sales')
87
-
88
- fig_forecast = go.Figure()
89
- fig_forecast.add_trace(go.Scatter(x=merged_data[merged_data.columns[0]], y=merged_data['Actual Sales'], mode='lines', name='Actual Sales'))
90
- fig_forecast.add_trace(go.Scatter(x=merged_data[merged_data.columns[0]], y=forecasted_series, mode='lines', name='Forecasted Sales'))
91
- fig_forecast.update_layout(title='🔮 Forecasted Sales Data', xaxis_title='Date', yaxis_title='Sales')
92
-
93
- return fig_compare, fig_forecast, future_sales_growth
94
-
95
- # Interfaz de Gradio
96
- def create_sidebar():
97
- with gr.Column():
98
- # Personalización del componente de carga de archivos
99
- gr.Markdown("### 📂 Upload your sales data (CSV)")
100
- uploaded_file = gr.File(
101
- label="Choose your file",
102
- elem_id="file-uploader",
103
- type="filepath", # Cambiado a 'filepath' para que retorne la ruta del archivo
104
- file_count="single", # Permite solo un archivo a la vez
105
- file_types=[".csv"], # Limita solo a archivos CSV
106
- interactive=True, # Hacer interactivo el componente para arrastrar y soltar
107
- )
108
-
109
- # Botón para cargar el periodo de pronóstico
110
- gr.Markdown("### ⏳ Forecast Period (Days)")
111
- period = gr.Slider(minimum=30, maximum=90, step=1, label="Forecast period (in days)")
112
-
113
- # Ruta del archivo de ejemplo
114
- sample_file_path = "sample_data.csv" # Ruta del archivo de ejemplo
115
- # Verifica si el archivo existe, de lo contrario lo crea
116
- if not os.path.exists(sample_file_path):
117
- sample_data = pd.DataFrame({
118
- "Date": ["2023-01-01", "2023-01-02", "2023-01-03"],
119
- "Sales": [100, 200, 300]
120
- })
121
- sample_data.to_csv(sample_file_path, index=False) # Crea el archivo de ejemplo si no existe
122
 
123
- # Usamos `gr.Markdown` para crear un enlace de descarga del archivo de ejemplo
124
- gr.Markdown(f"[Download our sample CSV](./{sample_file_path})") # Enlace directo para descargar el archivo
125
-
126
- return uploaded_file, period
127
-
128
- # Crear el sidebar y la interfaz principal
129
- uploaded_file, period = create_sidebar()
130
-
131
- output_plots = [
132
- gr.Plot(label="📈 Historical vs Predicted Sales"),
133
- gr.Plot(label="🔮 Forecasted Sales Data"),
134
- gr.DataFrame(label="📊 Sales Growth")
135
- ]
136
-
137
- iface = gr.Interface(
138
- fn=upload_and_forecast,
139
- inputs=[uploaded_file, period],
140
- outputs=output_plots,
141
- live=True,
142
- title="Sales Forecasting System ✨",
143
- description="Upload your sales data to start forecasting 🚀",
144
- css=open("styles.css", "r").read() # Cargar el archivo CSS para los estilos personalizados
145
- )
 
 
 
 
 
 
146
 
147
- iface.launch() # Lanzar la interfaz
 
 
 
 
1
  import pandas as pd
2
+ import matplotlib.pyplot as plt
3
+ import joblib
4
+ import gradio as gr
5
+ from dateutil.relativedelta import relativedelta
6
+ import calendar
7
+
8
+ def load_model():
9
+ try:
10
+ model = joblib.load('arima_sales_model.pkl')
11
+ return model, None
12
+ except Exception as e:
13
+ return None, f"Failed to load model: {str(e)}"
14
+
15
+ def parse_date(date_str):
16
+ """Parse the custom date format 'Month-Year'."""
17
+ try:
18
+ date = pd.to_datetime(date_str, format="%B-%Y")
19
+ _, last_day = calendar.monthrange(date.year, date.month)
20
+ start_date = date.replace(day=1)
21
+ end_date = date.replace(day=last_day)
22
+ return start_date, end_date, None
23
+ except ValueError:
24
+ return None, None, "Date format should be 'Month-Year', e.g., 'January-2024'."
25
+
26
+ def forecast_sales(uploaded_file, start_date_str, end_date_str):
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  if uploaded_file is None:
28
+ return "No file uploaded.", None, "Please upload a file."
29
 
30
+ try:
31
+ df = pd.read_csv(uploaded_file)
32
+ if 'Date' not in df.columns or 'Sale' not in df.columns:
33
+ return None, "The uploaded file must contain 'Date' and 'Sale' columns.", "File does not have required columns."
34
+ except Exception as e:
35
+ return None, f"Failed to read the uploaded CSV file: {str(e)}", "Error reading file."
36
 
37
+ start_date, _, error = parse_date(start_date_str)
38
+ _, end_date, error_end = parse_date(end_date_str)
39
+ if error or error_end:
40
+ return None, error or error_end, "Invalid date format."
41
 
42
+ df['Date'] = pd.to_datetime(df['Date'])
43
+ df = df.rename(columns={'Date': 'ds', 'Sale': 'y'})
44
 
45
+ df_filtered = df[(df['ds'] >= start_date) & (df['ds'] <= end_date)]
 
 
 
 
 
46
 
47
+ arima_model, error = load_model()
48
+ if arima_model is None:
49
+ return None, error, "Failed to load ARIMA model."
 
 
50
 
51
+ try:
52
+ forecast = arima_model.get_forecast(steps=60)
53
+ forecast_index = pd.date_range(start=end_date, periods=61, freq='D')[1:]
54
+ forecast_df = pd.DataFrame({'Date': forecast_index, 'Sales Forecast': forecast.predicted_mean})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
+ fig, ax = plt.subplots(figsize=(10, 6))
57
+ ax.plot(df_filtered['ds'], df_filtered['y'], label='Actual Sales', color='blue')
58
+ ax.plot(forecast_df['Date'], forecast_df['Sales Forecast'], label='Sales Forecast', color='red', linestyle='--')
59
+ ax.set_xlabel('Date')
60
+ ax.set_ylabel('Sales')
61
+ ax.set_title('Sales Forecasting with ARIMA')
62
+ ax.legend()
63
+ return fig, "File loaded and processed successfully."
64
+ except Exception as e:
65
+ return None, f"Failed to generate plot: {str(e)}", "Plotting failed."
66
+
67
+ def setup_interface():
68
+ with gr.Blocks() as demo:
69
+ gr.Markdown("## MLCast v1.1 - Intelligent Sales Forecasting System")
70
+ with gr.Row():
71
+ with gr.Column(scale=1):
72
+ file_input = gr.File(label="Upload your store data")
73
+ start_date_input = gr.Textbox(label="Start Date", placeholder="January-2024")
74
+ end_date_input = gr.Textbox(label="End Date", placeholder="December-2024")
75
+ forecast_button = gr.Button("Forecast Sales")
76
+ with gr.Column(scale=2):
77
+ output_plot = gr.Plot()
78
+ output_message = gr.Textbox(label="Notifications", visible=True, lines=2)
79
+ forecast_button.click(
80
+ forecast_sales,
81
+ inputs=[file_input, start_date_input, end_date_input],
82
+ outputs=[output_plot, output_message]
83
+ )
84
+ return demo
85
 
86
+ if __name__ == "__main__":
87
+ interface = setup_interface()
88
+ interface.launch()