File size: 2,914 Bytes
23c88d4
 
0d01ec2
f05802f
b277084
7185b69
fa183af
7185b69
696277a
9c77d88
696277a
 
 
 
 
1fb9abf
696277a
486a745
696277a
 
 
1238274
9c77d88
f05802f
240671f
9c77d88
fa183af
0d01ec2
240671f
9c77d88
240671f
0d01ec2
be7f66e
9c77d88
fa183af
 
486a745
9ff4e6c
fa183af
486a745
fa183af
486a745
fa183af
486a745
 
 
fa183af
a5272ad
f05802f
4db4956
 
f05802f
 
 
 
 
 
fa183af
 
4db4956
0d01ec2
 
f05802f
fa183af
0d01ec2
a5272ad
4db4956
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import gradio as gr
import yfinance as yf
from prophet import Prophet
import pandas as pd
from datetime import datetime
import plotly.graph_objects as go
import plotly.express as px

def download_data(ticker, start_date='2010-01-01'):
    """Download price history for ``ticker`` and shape it for Prophet.

    Args:
        ticker: Symbol handed to ``yf.download``.
        start_date: First date to request, in a form ``yf.download`` accepts.

    Returns:
        A DataFrame with exactly two columns: ``ds`` (timezone-naive
        datetime) and ``y`` (the 'Adj Close' price).

    Raises:
        ValueError: If no rows come back for the ticker, or the result has
            no 'Adj Close' column.
    """
    # Ask for unadjusted columns explicitly: recent yfinance releases default
    # to auto_adjust=True, which removes the 'Adj Close' column entirely.
    data = yf.download(ticker, start=start_date, auto_adjust=False)
    if data is None or data.empty:
        raise ValueError(f"No data returned for {ticker}")

    # Recent yfinance releases can return (field, ticker) MultiIndex columns
    # even for a single ticker; keep only the field level so plain names work.
    if isinstance(data.columns, pd.MultiIndex):
        data.columns = data.columns.get_level_values(0)

    data = data.reset_index()
    if 'Adj Close' not in data.columns:
        raise ValueError("Expected 'Adj Close' in columns")

    data = data[['Date', 'Adj Close']].copy()
    data.columns = ['ds', 'y']

    # Ensure 'ds' is a datetime column; Prophet rejects timezone-aware
    # timestamps, so drop any timezone that came with the index.
    data['ds'] = pd.to_datetime(data['ds'])
    if data['ds'].dt.tz is not None:
        data['ds'] = data['ds'].dt.tz_localize(None)
    return data

def predict_future_prices(ticker, periods=1825):
    """Fit a Prophet model on a ticker's history and forecast future prices.

    Args:
        ticker: Symbol to forecast; forwarded to ``download_data``.
        periods: Number of daily steps to forecast past the last observed
            date. May arrive as a float (e.g. from a numeric UI field) and
            is truncated to an int.

    Returns:
        A 4-tuple ``(fig_main, forecast_table, fig_yearly, fig_weekly)``:
        a Plotly figure of actual vs. forecast prices, a DataFrame with the
        ``ds``, ``yhat``, ``yhat_lower`` and ``yhat_upper`` columns, and
        Plotly line charts of the yearly and weekly seasonal components.

    Raises:
        ValueError: If the ticker is blank, ``periods`` is missing or
            negative, or ``download_data`` cannot produce usable data.
    """
    if isinstance(ticker, str):
        ticker = ticker.strip()
    if not ticker:
        raise ValueError("A stock ticker is required")

    if periods is None:
        raise ValueError("Forecast period (days) is required")
    # A numeric UI field can hand over 1825.0 rather than 1825; the future
    # date range needs a whole number of periods.
    periods = int(periods)
    if periods < 0:
        raise ValueError("Forecast period (days) must not be negative")

    data = download_data(ticker)

    # Create and fit the Prophet model.
    model = Prophet(daily_seasonality=False, weekly_seasonality=True, yearly_seasonality=True)
    model.fit(data)

    # Build the future frame and predict over history + horizon.
    future = model.make_future_dataframe(periods=periods, freq='D')
    forecast = model.predict(future)

    # Main chart: observed prices against the fitted/forecast curve. The
    # forecast colour is pinned so the legend label stays accurate.
    fig_main = go.Figure()
    fig_main.add_trace(go.Scatter(x=data['ds'], y=data['y'], mode='lines', name='Actual (Black)', line=dict(color='black')))
    fig_main.add_trace(go.Scatter(x=forecast['ds'], y=forecast['yhat'], mode='lines', name='Forecast (Blue)', line=dict(color='blue')))

    # Seasonal components are drawn with Plotly only; no matplotlib figure is
    # created here, so nothing is left open between requests.
    fig_yearly = px.line(x=forecast['ds'], y=forecast['yearly'], labels={'x': 'Date', 'y': 'Yearly Trend'})
    fig_weekly = px.line(x=forecast['ds'], y=forecast['weekly'], labels={'x': 'Date', 'y': 'Weekly Trend'})

    return fig_main, forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']], fig_yearly, fig_weekly

# Gradio interface setup and launch.
# Components are created in display order: an input row on top, then the
# forecast chart, the forecast table, and the two seasonality charts.
with gr.Blocks() as app:
    with gr.Row():
        ticker_input = gr.Textbox(value="AAPL", label="Enter Stock Ticker for Forecast")
        # precision=0 makes the component deliver an int day count to the
        # callback instead of a float.
        periods_input = gr.Number(value=1825, label="Forecast Period (days)", precision=0)
        forecast_button = gr.Button("Generate Forecast")

    forecast_chart = gr.Plot(label="Forecast Chart")
    forecast_data = gr.Dataframe(label="Forecast Data")
    yearly_chart = gr.Plot(label="Yearly (Monthly) Trend Chart")
    weekly_chart = gr.Plot(label="Weekly Trend Chart")

    # The outputs list mirrors the order of predict_future_prices' return tuple.
    forecast_button.click(
        fn=predict_future_prices,
        inputs=[ticker_input, periods_input],
        outputs=[forecast_chart, forecast_data, yearly_chart, weekly_chart]
    )

app.launch()