"""Fit a SARIMAX expense model on the bundled sample data and forecast from it.

Importing this module reads ``sample_data_extended.csv`` from the directory
containing this file, preprocesses it, and fits the model; the results are
exposed as the module-level names ``df`` and ``model_fit``.
"""
import os

import matplotlib.pyplot as plt
import pandas as pd
from statsmodels.tsa.statespace.sarimax import SARIMAX

script_dir = os.path.dirname(os.path.abspath(__file__))
csv_path = os.path.join(script_dir, 'sample_data_extended.csv')

# Incomes strictly below this threshold bypass the model in
# ``forecast_expense``; the forecast is then this fixed fraction of income.
LOW_INCOME_THRESHOLD = 5000000
LOW_INCOME_EXPENSE_RATIO = 0.78492

df = pd.read_csv(csv_path)


def _parse_amount(column):
    """Return *column* as floats with comma thousands separators removed."""
    # Cast to str first: read_csv infers a numeric dtype when no value in the
    # column contains a comma, and the ``.str`` accessor raises AttributeError
    # on non-string data.
    return column.astype(str).str.replace(',', '', regex=False).astype(float)


def preprocess_data(df: pd.DataFrame) -> pd.DataFrame:
    """Clean the raw CSV frame and index it by time.

    Rows containing any missing value are dropped, the ``Time`` column is
    parsed to datetimes and becomes the index, and the two money columns are
    converted to floats.

    Args:
        df: Raw frame with at least the columns ``Time``,
            ``Total expense (VND)`` and ``Income (VND)``.

    Returns:
        A new, cleaned frame; the frame passed in is not modified.

    Raises:
        KeyError: If one of the required columns is missing.
    """
    # Explicit copy so the assignments below act on our own frame.
    df = df.dropna().copy()

    # Set the time column as the index
    df['Time'] = pd.to_datetime(df['Time'])
    df = df.set_index('Time')

    # Convert the money columns to float, removing commas
    for column in ('Total expense (VND)', 'Income (VND)'):
        df[column] = _parse_amount(df[column])

    return df


def plot_data(df: pd.DataFrame) -> None:
    """Plot total expense and income against the frame's index and show it."""
    plt.figure(figsize=(10, 6))
    plt.plot(df['Total expense (VND)'], label='Total expense (VND)')
    plt.plot(df['Income (VND)'], label='Income (VND)')
    plt.legend(loc='best')
    plt.show()


def fit_model(df: pd.DataFrame, order=(1, 1, 1), seasonal_order=(1, 1, 1, 12)):
    """Fit a SARIMAX model of total expense on the exogenous columns.

    Args:
        df: Preprocessed frame containing ``Total expense (VND)`` plus the
            exogenous columns ``Income (VND)``, ``Interest rate (%)``,
            ``Inflation rate (%)`` and ``Holidays``.
        order: ``order`` argument passed to ``SARIMAX``.
        seasonal_order: ``seasonal_order`` argument passed to ``SARIMAX``.

    Returns:
        The fitted results object returned by ``SARIMAX(...).fit()``.
    """
    # Define the output and exogenous variables
    y = df['Total expense (VND)']
    exog = df[['Income (VND)', 'Interest rate (%)', 'Inflation rate (%)', 'Holidays']]

    # Fit the model
    model = SARIMAX(y, exog=exog, order=order, seasonal_order=seasonal_order)
    return model.fit()


def get_input_data(income, interest_rate, inflation_rate, holidays) -> pd.DataFrame:
    """Build the one-row exogenous frame used for a single forecast.

    Args:
        income: Income value; converted with ``float``.
        interest_rate: Interest rate value; converted with ``float``.
        inflation_rate: Inflation rate value; converted with ``float``.
        holidays: Holidays value; converted with ``int``.

    Returns:
        A single-row frame whose columns have the same names and order as the
        exogenous columns used in ``fit_model``.

    Raises:
        ValueError: If an argument cannot be converted to the numeric type
            listed above.
    """
    return pd.DataFrame({
        'Income (VND)': [float(income)],
        'Interest rate (%)': [float(interest_rate)],
        'Inflation rate (%)': [float(inflation_rate)],
        'Holidays': [int(holidays)],
    })


def forecast_expense(model_fit, input_data: pd.DataFrame, df: pd.DataFrame):
    """Forecast the total expense for the step right after the training data.

    Args:
        model_fit: Fitted results object as returned by ``fit_model``.
        input_data: One-row exogenous frame as built by ``get_input_data``.
        df: The frame the model was fitted on; only its length is used, to
            address the first out-of-sample step.

    Returns:
        The forecast expense. When the income in ``input_data`` is below
        ``LOW_INCOME_THRESHOLD`` the model is not consulted and the result is
        ``income * LOW_INCOME_EXPENSE_RATIO``.
    """
    income = input_data['Income (VND)'].iloc[0]
    if income < LOW_INCOME_THRESHOLD:
        return income * LOW_INCOME_EXPENSE_RATIO

    # start == end == len(df) requests exactly one step: the first
    # observation after the end of the training sample.
    next_step = len(df)
    forecast = model_fit.predict(start=next_step, end=next_step, exog=input_data)
    return forecast.iloc[0]


df = preprocess_data(df)
model_fit = fit_model(df)

# Example usage. ``df`` and ``model_fit`` are already prepared above at import
# time, so ``df`` must not be passed through ``preprocess_data`` a second time
# (its 'Time' column has already become the index).
#
# def main():
#     input_data = get_input_data(10000000, 5, 3, 0)
#     forecast = forecast_expense(model_fit, input_data, df)
#     print(forecast)
#
# if __name__ == '__main__':
#     main()