aiqtech commited on
Commit
9c77d88
·
verified ·
1 Parent(s): 9ff4e6c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -7
app.py CHANGED
@@ -6,7 +6,7 @@ from datetime import datetime
6
  import plotly.graph_objects as go
7
 
8
  def download_data(ticker, start_date='2010-01-01'):
9
- """ 데이터를 다운로드하고 적절히 포맷팅하는 함수 """
10
  data = yf.download(ticker, start=start_date)
11
  if data.empty:
12
  raise ValueError(f"No data returned for {ticker}")
@@ -18,22 +18,23 @@ def download_data(ticker, start_date='2010-01-01'):
18
  raise ValueError("Expected 'Adj Close' in columns")
19
  return data
20
 
21
- def predict_future_prices(ticker, periods=1825): # 5년간의 데이터 예측
22
  data = download_data(ticker)
23
 
24
- # Prophet 모델 생성 및 훈련
25
  model = Prophet(daily_seasonality=False, weekly_seasonality=False, yearly_seasonality=True)
26
  model.fit(data)
27
 
28
- # 미래 데이터프레임 생성 및 예측
29
  future = model.make_future_dataframe(periods=periods, freq='D')
30
  forecast = model.predict(future)
31
 
32
- # ds를 문자열로 변환
33
  forecast['ds'] = forecast['ds'].dt.strftime('%Y-%m-%d')
 
 
 
34
 
35
- # 예측 결과를 그래프로 표현
36
- fig = go.Figure(data=[go.Scatter(x=forecast['ds'], y=forecast['yhat'], mode='lines', name='Forecast')])
37
  return fig, forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']]
38
 
39
  # Gradio 인터페이스 설정 및 실행
 
6
  import plotly.graph_objects as go
7
 
8
  def download_data(ticker, start_date='2010-01-01'):
9
+ """ 주식 데이터를 다운로드하고 포맷을 조정하는 함수 """
10
  data = yf.download(ticker, start=start_date)
11
  if data.empty:
12
  raise ValueError(f"No data returned for {ticker}")
 
18
  raise ValueError("Expected 'Adj Close' in columns")
19
  return data
20
 
21
+ def predict_future_prices(ticker, periods=1825):
22
  data = download_data(ticker)
23
 
24
+ # Prophet 모델 생성 및 학습
25
  model = Prophet(daily_seasonality=False, weekly_seasonality=False, yearly_seasonality=True)
26
  model.fit(data)
27
 
28
+ # 미래 데이터 프레임 생성 및 예측
29
  future = model.make_future_dataframe(periods=periods, freq='D')
30
  forecast = model.predict(future)
31
 
32
+ # 예측 결과 그래프 생성
33
  forecast['ds'] = forecast['ds'].dt.strftime('%Y-%m-%d')
34
+ fig = go.Figure()
35
+ fig.add_trace(go.Scatter(x=forecast['ds'], y=forecast['yhat'], mode='lines', name='Forecast (Blue)'))
36
+ fig.add_trace(go.Scatter(x=data['ds'], y=data['y'], mode='lines', name='Actual (Black)', line=dict(color='black')))
37
 
 
 
38
  return fig, forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']]
39
 
40
  # Gradio 인터페이스 설정 및 실행