stockport9 / app.py
aiqtech's picture
Update app.py
cc4bd74 verified
raw
history blame
7.16 kB
import gradio as gr
import yfinance as yf
from prophet import Prophet
from sklearn.linear_model import LinearRegression, BayesianRidge
from sklearn.svm import SVR
from sklearn.preprocessing import MinMaxScaler
from statsmodels.tsa.arima.model import ARIMA
from xgboost import XGBRegressor
import pandas as pd
import numpy as np
from datetime import datetime
import plotly.graph_objects as go
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense
def download_data(ticker, start_date='2010-01-01'):
    """Download daily price history for *ticker* and reshape it for Prophet.

    Args:
        ticker: Symbol passed to ``yfinance.download`` (e.g. ``"AAPL"``).
        start_date: First date to request, in any form ``yfinance.download``
            accepts for ``start``.

    Returns:
        DataFrame with exactly two columns: ``ds`` (the date taken from the
        downloaded index) and ``y`` (the 'Adj Close' price).

    Raises:
        ValueError: if the download is empty or has no 'Adj Close' column.
    """
    # Newer yfinance releases default to auto_adjust=True, which folds the
    # adjustment into 'Close' and drops 'Adj Close'; request it explicitly.
    data = yf.download(ticker, start=start_date, auto_adjust=False)
    if data.empty:
        raise ValueError(f"No data returned for {ticker}")
    # Recent yfinance versions return (field, ticker) MultiIndex columns even
    # for a single ticker; keep only the field level so 'Adj Close' is a plain
    # column name. This must happen before reset_index() so the index becomes
    # a flat 'Date' column rather than ('Date', '').
    if isinstance(data.columns, pd.MultiIndex):
        data.columns = data.columns.get_level_values(0)
    data = data.reset_index()
    if 'Adj Close' not in data.columns:
        raise ValueError("Expected 'Adj Close' in columns")
    # Build a new frame instead of renaming a slice in place (avoids
    # pandas' SettingWithCopy warning).
    return data[['Date', 'Adj Close']].rename(columns={'Date': 'ds', 'Adj Close': 'y'})
def _fit_regressor_forecast(model, X_hist, y, X_future, future_dates):
    """Fit *model* on the integer time index and extrapolate it over *X_future*.

    Returns a DataFrame with columns ``ds`` (from *future_dates*) and ``yhat``.
    """
    model.fit(X_hist, y)
    return pd.DataFrame({'ds': future_dates, 'yhat': model.predict(X_future)})


def _arima_forecast(y, future_dates, periods):
    """Fit ARIMA(1, 1, 1) to the price series *y* and forecast *periods* steps."""
    fitted = ARIMA(y, order=(1, 1, 1)).fit()
    # forecast() returns a Series indexed after the training rows; take its raw
    # values so it lines up positionally with future_dates instead of relying
    # on index alignment.
    forecast = np.asarray(fitted.forecast(steps=periods))
    return pd.DataFrame({'ds': future_dates, 'yhat': forecast})


def _lstm_forecast(y, future_dates, periods, lookback=60):
    """Train a two-layer LSTM on sliding windows of *y* and roll it forward.

    Each window holds the previous *lookback* MinMax-scaled prices and the
    target is the next price. The forecast is built recursively: every
    prediction is appended to the window and fed back for the next step.

    Raises:
        ValueError: if *y* has no more than *lookback* observations (no
            training windows could be formed).
    """
    scaler = MinMaxScaler(feature_range=(0, 1))
    scaled = scaler.fit_transform(np.asarray(y, dtype=float).reshape(-1, 1))
    if len(scaled) <= lookback:
        raise ValueError(
            f"LSTM needs more than {lookback} observations, got {len(scaled)}"
        )

    X_train = np.array([scaled[i - lookback:i, 0] for i in range(lookback, len(scaled))])
    X_train = X_train.reshape(X_train.shape[0], lookback, 1)
    y_train = scaled[lookback:, 0]

    model = Sequential()
    model.add(LSTM(units=50, return_sequences=True, input_shape=(lookback, 1)))
    model.add(LSTM(units=50))
    model.add(Dense(1))
    model.compile(loss='mean_squared_error', optimizer='adam')
    model.fit(X_train, y_train, epochs=10, batch_size=32)

    window = scaled[-lookback:, 0].copy()
    preds = np.empty(periods, dtype=float)
    for step in range(periods):
        # For a single sample, calling the model directly is far cheaper than
        # model.predict(), whose per-call overhead dominates over ~1825 steps.
        next_val = float(model(window.reshape(1, lookback, 1), training=False).numpy()[0, 0])
        preds[step] = next_val
        window = np.append(window[1:], next_val)

    yhat = scaler.inverse_transform(preds.reshape(-1, 1)).ravel()
    return pd.DataFrame({'ds': future_dates, 'yhat': yhat})


def _forecast_figure(data, forecast_prophet, model_traces):
    """Plot Prophet, every other model forecast, and the actual prices together.

    *model_traces* is a sequence of ``(frame, legend_name, line_color)``
    tuples, where each frame has ``ds`` and ``yhat`` columns.
    """
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=forecast_prophet['ds'], y=forecast_prophet['yhat'], mode='lines',
                             name='Prophet Forecast (Blue)', line=dict(color='blue')))
    for frame, name, color in model_traces:
        fig.add_trace(go.Scatter(x=frame['ds'], y=frame['yhat'], mode='lines',
                                 name=name, line=dict(color=color)))
    fig.add_trace(go.Scatter(x=data['ds'], y=data['y'], mode='lines',
                             name='Actual (Black)', line=dict(color='black')))
    return fig


def predict_future_prices(ticker, periods=1825):
    """Forecast *ticker*'s price with several models and chart them together.

    Models: Prophet, ARIMA(1, 1, 1), an LSTM, and four regressors
    (Linear Regression, XGBoost, SVR, Bayesian Ridge). The four regressors
    use only the row position of each observation as their feature.

    Args:
        ticker: Stock symbol passed to ``download_data``.
        periods: Forecast horizon in calendar days. Floats such as ``1825.0``
            (what a Gradio Number may deliver) are accepted and truncated
            to int.

    Returns:
        An 8-tuple: the plotly Figure; the Prophet frame (``ds``, ``yhat``,
        ``yhat_lower``, ``yhat_upper``, covering history plus horizon); then
        ``ds``/``yhat`` frames for Linear Regression, ARIMA, LSTM, XGBoost,
        SVR and Bayesian Ridge, in that order.

    Raises:
        ValueError: if *periods* is missing or below 1, or if it is raised by
            ``download_data`` or by the LSTM step when history is too short.
    """
    if periods is None:
        raise ValueError("Forecast period (days) is required")
    # Gradio's Number component may deliver a float; range() and ARIMA's
    # `steps` need an int.
    periods = int(periods)
    if periods < 1:
        raise ValueError("Forecast period must be at least 1 day")

    data = download_data(ticker)

    # Prophet: fit on the full history; predict() covers history + horizon.
    model_prophet = Prophet(daily_seasonality=False, weekly_seasonality=False, yearly_seasonality=True)
    model_prophet.fit(data)
    future = model_prophet.make_future_dataframe(periods=periods, freq='D')
    forecast_prophet = model_prophet.predict(future)
    forecast_prophet['ds'] = forecast_prophet['ds'].dt.strftime('%Y-%m-%d')

    # Calendar days after the last observed date. These are already
    # 'YYYY-MM-DD' strings, so no further .dt formatting is needed.
    future_dates = pd.date_range(start=data['ds'].iloc[-1], periods=periods + 1,
                                 freq='D')[1:].strftime('%Y-%m-%d')

    # Row positions: 0..n-1 for history, n..n+periods-1 for the horizon.
    n_obs = len(data)
    X_hist = np.arange(n_obs).reshape(-1, 1)
    X_future = np.arange(n_obs, n_obs + periods).reshape(-1, 1)
    y = data['y'].values

    future_lr = _fit_regressor_forecast(LinearRegression(), X_hist, y, X_future, future_dates)
    future_arima = _arima_forecast(data['y'], future_dates, periods)
    future_lstm = _lstm_forecast(y, future_dates, periods)
    future_xgb = _fit_regressor_forecast(XGBRegressor(n_estimators=100, learning_rate=0.1),
                                         X_hist, y, X_future, future_dates)
    future_svr = _fit_regressor_forecast(SVR(kernel='rbf', C=1e3, gamma=0.1),
                                         X_hist, y, X_future, future_dates)
    future_bayes = _fit_regressor_forecast(BayesianRidge(), X_hist, y, X_future, future_dates)

    fig = _forecast_figure(data, forecast_prophet, [
        (future_lr, 'Linear Regression Forecast (Red)', 'red'),
        (future_arima, 'ARIMA Forecast (Green)', 'green'),
        (future_lstm, 'LSTM Forecast (Orange)', 'orange'),
        (future_xgb, 'XGBoost Forecast (Purple)', 'purple'),
        (future_svr, 'SVR Forecast (Brown)', 'brown'),
        (future_bayes, 'Bayesian Regression Forecast (Pink)', 'pink'),
    ])

    return (
        fig,
        forecast_prophet[['ds', 'yhat', 'yhat_lower', 'yhat_upper']],
        future_lr[['ds', 'yhat']],
        future_arima[['ds', 'yhat']],
        future_lstm[['ds', 'yhat']],
        future_xgb[['ds', 'yhat']],
        future_svr[['ds', 'yhat']],
        future_bayes[['ds', 'yhat']],
    )
# Gradio interface: ticker and horizon inputs, one chart, and one table per model.
with gr.Blocks() as app:
    with gr.Row():
        ticker_input = gr.Textbox(value="AAPL", label="Enter Stock Ticker for Forecast")
        # precision=0 makes Gradio round the value and pass it as an int; the
        # forecast code uses it with range() and as ARIMA's step count.
        periods_input = gr.Number(value=1825, label="Forecast Period (days)", precision=0)
    forecast_button = gr.Button("Generate Forecast")
    forecast_chart = gr.Plot(label="Forecast Chart")
    forecast_data_prophet = gr.Dataframe(label="Prophet Forecast Data")
    forecast_data_lr = gr.Dataframe(label="Linear Regression Forecast Data")
    forecast_data_arima = gr.Dataframe(label="ARIMA Forecast Data")
    forecast_data_lstm = gr.Dataframe(label="LSTM Forecast Data")
    forecast_data_xgb = gr.Dataframe(label="XGBoost Forecast Data")
    forecast_data_svr = gr.Dataframe(label="SVR Forecast Data")
    forecast_data_bayes = gr.Dataframe(label="Bayesian Regression Forecast Data")
    # Output order must match the 8-tuple returned by predict_future_prices.
    forecast_button.click(
        fn=predict_future_prices,
        inputs=[ticker_input, periods_input],
        outputs=[forecast_chart, forecast_data_prophet, forecast_data_lr, forecast_data_arima,
                 forecast_data_lstm, forecast_data_xgb, forecast_data_svr, forecast_data_bayes],
    )
app.launch()