# stockport9 / app.py  (author: aiqtech)
# Last commit: 486a745 "Update app.py" (verified), 2.91 kB
import gradio as gr
import yfinance as yf
from prophet import Prophet
import pandas as pd
from datetime import datetime
import plotly.graph_objects as go
import plotly.express as px
def download_data(ticker, start_date='2010-01-01'):
    """Download price history for *ticker* and reshape it for Prophet.

    Args:
        ticker: Symbol passed to ``yfinance.download`` (e.g. ``"AAPL"``).
        start_date: First date of history to request, ``YYYY-MM-DD``.

    Returns:
        A new DataFrame with exactly two columns, as Prophet expects:
        ``ds`` (tz-naive datetimes, from the ``Date`` index) and
        ``y`` (the ``Adj Close`` column).

    Raises:
        ValueError: If no rows are returned, if the download resolves to more
            than one ticker, or if no ``Adj Close`` column is present.
    """
    # Newer yfinance releases default to auto_adjust=True, which removes the
    # 'Adj Close' column; request it explicitly so the column always exists.
    data = yf.download(ticker, start=start_date, auto_adjust=False)
    if data.empty:
        raise ValueError(f"No data returned for {ticker}")

    if isinstance(data.columns, pd.MultiIndex):
        # Newer yfinance returns (field, ticker) column pairs even for a
        # single symbol; keep only the field names ('Adj Close', ...).
        data.columns = data.columns.get_level_values(0)
        if data.columns.duplicated().any():
            # Duplicate field names mean several tickers were downloaded
            # (e.g. the user typed "AAPL MSFT"); Prophet needs one series.
            raise ValueError(f"Expected a single ticker, got {ticker!r}")

    data = data.reset_index()
    if 'Adj Close' not in data.columns:
        raise ValueError("Expected 'Adj Close' in columns")

    prices = data[['Date', 'Adj Close']].copy()
    prices.rename(columns={'Date': 'ds', 'Adj Close': 'y'}, inplace=True)

    ds = pd.to_datetime(prices['ds'])
    if ds.dt.tz is not None:
        # Prophet refuses timezone-aware 'ds' values; keep wall-clock dates.
        ds = ds.dt.tz_localize(None)
    prices['ds'] = ds
    return prices
def predict_future_prices(ticker, periods=1825):
    """Fit a Prophet model on *ticker*'s history and forecast ahead.

    Args:
        ticker: Stock symbol forwarded to ``download_data``.
        periods: Number of days (``freq='D'``) to forecast beyond the last
            observed date. Gradio's ``gr.Number`` may deliver a float, so the
            value is coerced to ``int`` before reaching Prophet/pandas.

    Returns:
        A 4-tuple ``(fig_main, forecast_table, fig_yearly, fig_weekly)``:

        - ``fig_main``: Plotly figure with the actual series (black) and the
          Prophet ``yhat`` forecast.
        - ``forecast_table``: DataFrame with ``ds``, ``yhat``, ``yhat_lower``
          and ``yhat_upper``.
        - ``fig_yearly`` / ``fig_weekly``: Plotly line charts of the fitted
          ``yearly`` and ``weekly`` seasonal components over the forecast dates.

    Raises:
        ValueError: If *periods* is negative (or NaN); also propagated from
            ``download_data``.
        TypeError: If *periods* is ``None`` (e.g. an empty number box).
    """
    periods = int(periods)
    if periods < 0:
        raise ValueError(f"periods must be non-negative, got {periods}")

    data = download_data(ticker)

    # Create and fit the Prophet model (weekly + yearly seasonality only).
    model = Prophet(daily_seasonality=False, weekly_seasonality=True, yearly_seasonality=True)
    model.fit(data)

    # Build the future date frame and predict over it.
    future = model.make_future_dataframe(periods=periods, freq='D')
    forecast = model.predict(future)

    # Main chart: actual prices vs. forecast.
    fig_main = go.Figure()
    fig_main.add_trace(go.Scatter(x=data['ds'], y=data['y'], mode='lines', name='Actual (Black)', line=dict(color='black')))
    fig_main.add_trace(go.Scatter(x=forecast['ds'], y=forecast['yhat'], mode='lines', name='Forecast (Blue)'))

    # Yearly and weekly seasonality charts, built with Plotly only. No
    # matplotlib figure is created here, so nothing accumulates in pyplot's
    # figure registry across repeated requests.
    fig_yearly = px.line(x=forecast['ds'], y=forecast['yearly'], labels={'x': 'Date', 'y': 'Yearly Trend'})
    fig_weekly = px.line(x=forecast['ds'], y=forecast['weekly'], labels={'x': 'Date', 'y': 'Weekly Trend'})

    return fig_main, forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']], fig_yearly, fig_weekly
# Gradio interface: ticker + horizon inputs, a button, and four outputs.
with gr.Blocks() as app:
    with gr.Row():
        ticker_input = gr.Textbox(value="AAPL", label="Enter Stock Ticker for Forecast")
        # precision=0 makes Gradio round the value and pass an int (not a
        # float) to the handler's `periods` argument.
        periods_input = gr.Number(value=1825, label="Forecast Period (days)", precision=0)
    forecast_button = gr.Button("Generate Forecast")
    forecast_chart = gr.Plot(label="Forecast Chart")
    forecast_data = gr.Dataframe(label="Forecast Data")
    yearly_chart = gr.Plot(label="Yearly (Monthly) Trend Chart")
    weekly_chart = gr.Plot(label="Weekly Trend Chart")
    forecast_button.click(
        fn=predict_future_prices,
        inputs=[ticker_input, periods_input],
        outputs=[forecast_chart, forecast_data, yearly_chart, weekly_chart]
    )

# Start the server only when run as a script, so importing this module
# (e.g. from tests) does not block on launch().
if __name__ == "__main__":
    app.launch()