aiqtech commited on
Commit
fa183af
·
verified ·
1 Parent(s): 9c77d88

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -6
app.py CHANGED
@@ -4,6 +4,7 @@ from prophet import Prophet
4
  import pandas as pd
5
  from datetime import datetime
6
  import plotly.graph_objects as go
 
7
 
8
  def download_data(ticker, start_date='2010-01-01'):
9
  """ 주식 데이터를 다운로드하고 포맷을 조정하는 함수 """
@@ -22,7 +23,7 @@ def predict_future_prices(ticker, periods=1825):
22
  data = download_data(ticker)
23
 
24
  # Prophet 모델 생성 및 학습
25
- model = Prophet(daily_seasonality=False, weekly_seasonality=False, yearly_seasonality=True)
26
  model.fit(data)
27
 
28
  # 미래 데이터 프레임 생성 및 예측
@@ -31,11 +32,16 @@ def predict_future_prices(ticker, periods=1825):
31
 
32
  # 예측 결과 그래프 생성
33
  forecast['ds'] = forecast['ds'].dt.strftime('%Y-%m-%d')
34
- fig = go.Figure()
35
- fig.add_trace(go.Scatter(x=forecast['ds'], y=forecast['yhat'], mode='lines', name='Forecast (Blue)'))
36
- fig.add_trace(go.Scatter(x=data['ds'], y=data['y'], mode='lines', name='Actual (Black)', line=dict(color='black')))
37
 
38
- return fig, forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']]
 
 
 
 
 
39
 
40
  # Gradio 인터페이스 설정 및 실행
41
  with gr.Blocks() as app:
@@ -46,11 +52,13 @@ with gr.Blocks() as app:
46
 
47
  forecast_chart = gr.Plot(label="Forecast Chart")
48
  forecast_data = gr.Dataframe(label="Forecast Data")
 
 
49
 
50
  forecast_button.click(
51
  fn=predict_future_prices,
52
  inputs=[ticker_input, periods_input],
53
- outputs=[forecast_chart, forecast_data]
54
  )
55
 
56
  app.launch()
 
4
  import pandas as pd
5
  from datetime import datetime
6
  import plotly.graph_objects as go
7
+ import plotly.express as px
8
 
9
  def download_data(ticker, start_date='2010-01-01'):
10
  """ 주식 데이터를 다운로드하고 포맷을 조정하는 함수 """
 
23
  data = download_data(ticker)
24
 
25
  # Prophet 모델 생성 및 학습
26
+ model = Prophet(daily_seasonality=False, weekly_seasonality=True, yearly_seasonality=True)
27
  model.fit(data)
28
 
29
  # 미래 데이터 프레임 생성 및 예측
 
32
 
33
  # 예측 결과 그래프 생성
34
  forecast['ds'] = forecast['ds'].dt.strftime('%Y-%m-%d')
35
+ fig_main = go.Figure()
36
+ fig_main.add_trace(go.Scatter(x=forecast['ds'], y=forecast['yhat'], mode='lines', name='Forecast (Blue)'))
37
+ fig_main.add_trace(go.Scatter(x=data['ds'], y=data['y'], mode='lines', name='Actual (Black)', line=dict(color='black')))
38
 
39
+ # 연간 주간 계절성 그래프 생성
40
+ fig_seasonal = model.plot_components(forecast)
41
+ fig_yearly = px.line(x=pd.to_datetime(fig_seasonal[0]['ds']), y=fig_seasonal[0]['yearly'], labels={'x': 'Date', 'y': 'Yearly Trend'})
42
+ fig_weekly = px.line(x=fig_seasonal[1]['day'], y=fig_seasonal[1]['weekly'], labels={'x': 'Day of Week', 'y': 'Weekly Trend'})
43
+
44
+ return fig_main, forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']], fig_yearly, fig_weekly
45
 
46
  # Gradio 인터페이스 설정 및 실행
47
  with gr.Blocks() as app:
 
52
 
53
  forecast_chart = gr.Plot(label="Forecast Chart")
54
  forecast_data = gr.Dataframe(label="Forecast Data")
55
+ yearly_chart = gr.Plot(label="Yearly (Monthly) Trend Chart")
56
+ weekly_chart = gr.Plot(label="Weekly Trend Chart")
57
 
58
  forecast_button.click(
59
  fn=predict_future_prices,
60
  inputs=[ticker_input, periods_input],
61
+ outputs=[forecast_chart, forecast_data, yearly_chart, weekly_chart]
62
  )
63
 
64
  app.launch()