Spaces:
Sleeping
Sleeping
File size: 2,914 Bytes
23c88d4 0d01ec2 f05802f b277084 7185b69 fa183af 7185b69 696277a 9c77d88 696277a 1fb9abf 696277a 486a745 696277a 1238274 9c77d88 f05802f 240671f 9c77d88 fa183af 0d01ec2 240671f 9c77d88 240671f 0d01ec2 be7f66e 9c77d88 fa183af 486a745 9ff4e6c fa183af 486a745 fa183af 486a745 fa183af 486a745 fa183af a5272ad f05802f 4db4956 f05802f fa183af 4db4956 0d01ec2 f05802f fa183af 0d01ec2 a5272ad 4db4956 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import gradio as gr
import yfinance as yf
from prophet import Prophet
import pandas as pd
from datetime import datetime
import plotly.graph_objects as go
import plotly.express as px
def download_data(ticker, start_date='2010-01-01'):
    """Download daily price history for *ticker* and shape it for Prophet.

    Args:
        ticker: Symbol passed to ``yfinance.download`` (e.g. ``"AAPL"``).
        start_date: First date to request, in any form ``yfinance`` accepts.

    Returns:
        A DataFrame with two columns: ``ds`` (timezone-naive datetimes)
        and ``y`` (adjusted close prices), the layout Prophet expects.

    Raises:
        ValueError: if the download is empty or has no ``Adj Close`` column.
    """
    # Request unadjusted prices explicitly: newer yfinance releases default to
    # auto_adjust=True, which removes the 'Adj Close' column altogether.
    data = yf.download(ticker, start=start_date, auto_adjust=False)
    if data.empty:
        raise ValueError(f"No data returned for {ticker}")

    # Newer yfinance releases return (field, ticker) MultiIndex columns even
    # for a single symbol; keep only the field level so 'Adj Close' selects
    # a plain Series rather than a one-column DataFrame.
    if isinstance(data.columns, pd.MultiIndex):
        data.columns = data.columns.get_level_values(0)

    # Move the date index into a regular 'Date' column.
    data = data.reset_index()
    if 'Adj Close' not in data.columns:
        raise ValueError("Expected 'Adj Close' in columns")

    data = data[['Date', 'Adj Close']].rename(columns={'Date': 'ds', 'Adj Close': 'y'})
    ds = pd.to_datetime(data['ds'])
    # Prophet refuses timezone-aware timestamps in the 'ds' column.
    if ds.dt.tz is not None:
        ds = ds.dt.tz_localize(None)
    data['ds'] = ds
    return data
def predict_future_prices(ticker, periods=1825):
    """Fit a Prophet model to *ticker*'s price history and forecast ahead.

    Args:
        ticker: Stock symbol, downloaded via ``download_data``.
        periods: Forecast horizon in calendar days. Gradio's ``gr.Number``
            may deliver a float, so the value is coerced to ``int``.

    Returns:
        A 4-tuple ``(fig_main, forecast_table, fig_yearly, fig_weekly)``:
        a Plotly figure of actual vs. forecast prices; a DataFrame with
        columns ``ds``, ``yhat``, ``yhat_lower``, ``yhat_upper`` covering
        both the fitted history and the future weekdays; and Plotly line
        charts of the yearly and weekly seasonal components.

    Raises:
        ValueError: if *periods* is negative, or as raised by ``download_data``.
    """
    # pd.date_range (used inside make_future_dataframe) wants an integer count.
    periods = int(periods)
    if periods < 0:
        raise ValueError("Forecast period must not be negative")

    data = download_data(ticker)

    # Create and fit the Prophet model.
    model = Prophet(daily_seasonality=False, weekly_seasonality=True, yearly_seasonality=True)
    model.fit(data)

    # Build the future frame over calendar days, then keep weekdays only:
    # market data has no Saturday/Sunday rows, so the weekly component is
    # unconstrained on weekends and would yield spurious forecast values.
    future = model.make_future_dataframe(periods=periods, freq='D')
    future = future[future['ds'].dt.dayofweek < 5]
    forecast = model.predict(future)

    # Main chart: actual prices vs. the model's point forecast.
    fig_main = go.Figure()
    fig_main.add_trace(go.Scatter(x=data['ds'], y=data['y'], mode='lines', name='Actual (Black)', line=dict(color='black')))
    # Explicit colour so the trace matches its legend label.
    fig_main.add_trace(go.Scatter(x=forecast['ds'], y=forecast['yhat'], mode='lines', name='Forecast (Blue)', line=dict(color='blue')))

    # Yearly and weekly seasonal components as Plotly charts; forecast['ds']
    # is already datetime64 as returned by Prophet.predict.
    fig_yearly = px.line(x=forecast['ds'], y=forecast['yearly'], labels={'x': 'Date', 'y': 'Yearly Trend'})
    fig_weekly = px.line(x=forecast['ds'], y=forecast['weekly'], labels={'x': 'Date', 'y': 'Weekly Trend'})
    return fig_main, forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']], fig_yearly, fig_weekly
# Set up and launch the Gradio interface.
with gr.Blocks() as app:
    with gr.Row():
        ticker_input = gr.Textbox(value="AAPL", label="Enter Stock Ticker for Forecast")
        # precision=0 makes the component hand the callback an int, not a float.
        periods_input = gr.Number(value=1825, label="Forecast Period (days)", precision=0)
    forecast_button = gr.Button("Generate Forecast")
    forecast_chart = gr.Plot(label="Forecast Chart")
    forecast_data = gr.Dataframe(label="Forecast Data")
    yearly_chart = gr.Plot(label="Yearly (Monthly) Trend Chart")
    weekly_chart = gr.Plot(label="Weekly Trend Chart")

    # Output order must match the 4-tuple returned by predict_future_prices.
    forecast_button.click(
        fn=predict_future_prices,
        inputs=[ticker_input, periods_input],
        outputs=[forecast_chart, forecast_data, yearly_chart, weekly_chart],
    )

app.launch()
|