File size: 2,053 Bytes
23c88d4
 
0d01ec2
f05802f
b277084
7185b69
 
696277a
 
 
 
 
 
 
 
 
 
 
 
1238274
240671f
f05802f
240671f
f05802f
240671f
0d01ec2
240671f
f05802f
240671f
0d01ec2
be7f66e
9ff4e6c
 
 
f05802f
240671f
9ff4e6c
a5272ad
f05802f
4db4956
 
f05802f
 
 
 
 
 
4db4956
0d01ec2
 
f05802f
0d01ec2
 
a5272ad
4db4956
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import gradio as gr
import yfinance as yf
from prophet import Prophet
import pandas as pd
from datetime import datetime
import plotly.graph_objects as go

def download_data(ticker, start_date='2010-01-01'):
    """Download daily price history for ``ticker`` and format it for Prophet.

    Parameters
    ----------
    ticker : str
        Symbol passed straight to ``yfinance.download``.
    start_date : str, optional
        First date to request, forwarded as ``yfinance.download``'s ``start``
        argument. Defaults to ``'2010-01-01'``.

    Returns
    -------
    pandas.DataFrame
        Two columns as Prophet expects: ``ds`` (timezone-naive dates) and
        ``y`` (the ``'Adj Close'`` price column).

    Raises
    ------
    ValueError
        If the download returns no rows, or the result has no
        ``'Adj Close'`` column.
    """
    # Ask for the separate 'Adj Close' column explicitly: newer yfinance
    # releases default to auto_adjust=True, which omits 'Adj Close' and would
    # make every call fail the column check below.
    data = yf.download(ticker, start=start_date, auto_adjust=False)
    if data.empty:
        raise ValueError(f"No data returned for {ticker}")

    # Newer yfinance returns (Price, Ticker) MultiIndex columns even for a
    # single symbol; keep only the price-field level so plain names work.
    if isinstance(data.columns, pd.MultiIndex):
        data.columns = data.columns.get_level_values(0)

    data = data.reset_index()
    if 'Adj Close' not in data.columns:
        raise ValueError("Expected 'Adj Close' in columns")

    # Build a new frame rather than renaming a column slice in place, which
    # can trigger pandas' SettingWithCopyWarning.
    prepared = data[['Date', 'Adj Close']].rename(
        columns={'Date': 'ds', 'Adj Close': 'y'}
    )

    # Prophet rejects timezone-aware 'ds' values, so drop any tz information.
    if getattr(prepared['ds'].dt, 'tz', None) is not None:
        prepared['ds'] = prepared['ds'].dt.tz_localize(None)
    return prepared

def predict_future_prices(ticker, periods=1825):  # default: about 5 years of daily forecasts
    """Fit a Prophet model to ``ticker``'s price history and forecast ahead.

    Parameters
    ----------
    ticker : str
        Stock symbol; surrounding whitespace is ignored.
    periods : int or float, optional
        Number of calendar days to forecast past the last observed date.
        The Gradio ``Number`` component can deliver a float, so the value is
        converted with ``int()`` (fractional parts are truncated).
        Defaults to 1825.

    Returns
    -------
    tuple
        ``(fig, table)``. ``fig`` is a Plotly figure of ``yhat`` against
        ``ds``. ``table`` holds ``ds`` (as ``YYYY-MM-DD`` strings), ``yhat``,
        ``yhat_lower`` and ``yhat_upper`` for every row Prophet predicted
        (the fitted history plus the forecast horizon).

    Raises
    ------
    ValueError
        If ``ticker`` is blank, ``periods`` is not a non-negative number, or
        ``download_data`` reports a problem with the downloaded data.
    """
    ticker = (ticker or '').strip()
    if not ticker:
        raise ValueError("Please enter a stock ticker")

    # pandas' date_range (used by make_future_dataframe) deprecates
    # non-integer period counts, so normalise to int up front.
    try:
        n_periods = int(periods)
    except (TypeError, ValueError, OverflowError) as err:
        raise ValueError(
            f"Forecast period must be a whole number of days, got {periods!r}"
        ) from err
    if n_periods < 0:
        raise ValueError(f"Forecast period must not be negative, got {n_periods}")

    data = download_data(ticker)

    # Create and train the Prophet model (yearly seasonality only).
    model = Prophet(daily_seasonality=False, weekly_seasonality=False, yearly_seasonality=True)
    model.fit(data)

    # Build the future dataframe (calendar-day frequency) and predict.
    future = model.make_future_dataframe(periods=n_periods, freq='D')
    forecast = model.predict(future)

    # Convert ds to strings for display in the output table.
    forecast['ds'] = forecast['ds'].dt.strftime('%Y-%m-%d')

    # Plot the predicted values as a line chart.
    fig = go.Figure(data=[go.Scatter(x=forecast['ds'], y=forecast['yhat'], mode='lines', name='Forecast')])
    return fig, forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']]

# Gradio ์ธํ„ฐํŽ˜์ด์Šค ์„ค์ • ๋ฐ ์‹คํ–‰
with gr.Blocks() as app:
    with gr.Row():
        ticker_input = gr.Textbox(value="AAPL", label="Enter Stock Ticker for Forecast")
        periods_input = gr.Number(value=1825, label="Forecast Period (days)")
        forecast_button = gr.Button("Generate Forecast")
        
    forecast_chart = gr.Plot(label="Forecast Chart")
    forecast_data = gr.Dataframe(label="Forecast Data")
    
    forecast_button.click(
        fn=predict_future_prices,
        inputs=[ticker_input, periods_input],
        outputs=[forecast_chart, forecast_data]
    )

app.launch()