aiqtech commited on
Commit
a7925b2
·
verified ·
1 Parent(s): a8076e2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -5
app.py CHANGED
@@ -2,35 +2,42 @@ import gradio as gr
2
  import yfinance as yf
3
  from prophet import Prophet
4
  import pandas as pd
 
5
  import plotly.graph_objects as go
6
 
7
  def download_data(ticker, start_date='2010-01-01'):
 
8
  data = yf.download(ticker, start=start_date)
9
  if data.empty:
10
  raise ValueError(f"No data returned for {ticker}")
11
  data.reset_index(inplace=True)
12
  if 'Adj Close' in data.columns:
13
- data = data[['Date', 'Adj Close']].copy()
14
  data.rename(columns={'Date': 'ds', 'Adj Close': 'y'}, inplace=True)
15
  else:
16
  raise ValueError("Expected 'Adj Close' in columns")
17
- data['ds'] = pd.to_datetime(data['ds'])
18
  return data
19
 
20
  def predict_future_prices(ticker, periods=1825):
21
  data = download_data(ticker)
22
- model = Prophet(daily_seasonality=False, weekly_seasonality=True, yearly_seasonality=True)
 
 
23
  model.fit(data)
 
 
24
  future = model.make_future_dataframe(periods=periods, freq='D')
25
  forecast = model.predict(future)
26
 
 
 
27
  fig = go.Figure()
28
- fig.add_trace(go.Scatter(x=data['ds'], y=data['y'], mode='lines', name='Actual (Black)', line=dict(color='black')))
29
  fig.add_trace(go.Scatter(x=forecast['ds'], y=forecast['yhat'], mode='lines', name='Forecast (Blue)'))
 
30
 
31
- # Make sure to return both the figure and the forecast data
32
  return fig, forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']]
33
 
 
34
  with gr.Blocks() as app:
35
  with gr.Row():
36
  ticker_input = gr.Textbox(value="AAPL", label="Enter Stock Ticker for Forecast")
 
2
  import yfinance as yf
3
  from prophet import Prophet
4
  import pandas as pd
5
+ from datetime import datetime
6
  import plotly.graph_objects as go
7
 
8
  def download_data(ticker, start_date='2010-01-01'):
9
+ """ 주식 데이터를 다운로드하고 포맷을 조정하는 함수 """
10
  data = yf.download(ticker, start=start_date)
11
  if data.empty:
12
  raise ValueError(f"No data returned for {ticker}")
13
  data.reset_index(inplace=True)
14
  if 'Adj Close' in data.columns:
15
+ data = data[['Date', 'Adj Close']]
16
  data.rename(columns={'Date': 'ds', 'Adj Close': 'y'}, inplace=True)
17
  else:
18
  raise ValueError("Expected 'Adj Close' in columns")
 
19
  return data
20
 
21
  def predict_future_prices(ticker, periods=1825):
22
  data = download_data(ticker)
23
+
24
+ # Prophet 모델 생성 및 학습
25
+ model = Prophet(daily_seasonality=False, weekly_seasonality=False, yearly_seasonality=True)
26
  model.fit(data)
27
+
28
+ # 미래 데이터 프레임 생성 및 예측
29
  future = model.make_future_dataframe(periods=periods, freq='D')
30
  forecast = model.predict(future)
31
 
32
+ # 예측 결과 그래프 생성
33
+ forecast['ds'] = forecast['ds'].dt.strftime('%Y-%m-%d')
34
  fig = go.Figure()
 
35
  fig.add_trace(go.Scatter(x=forecast['ds'], y=forecast['yhat'], mode='lines', name='Forecast (Blue)'))
36
+ fig.add_trace(go.Scatter(x=data['ds'], y=data['y'], mode='lines', name='Actual (Black)', line=dict(color='black')))
37
 
 
38
  return fig, forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']]
39
 
40
+ # Gradio 인터페이스 설정 및 실행
41
  with gr.Blocks() as app:
42
  with gr.Row():
43
  ticker_input = gr.Textbox(value="AAPL", label="Enter Stock Ticker for Forecast")