aiqtech commited on
Commit
bc387e2
·
verified ·
1 Parent(s): 486a745

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -22
app.py CHANGED
@@ -5,9 +5,10 @@ import pandas as pd
5
  from datetime import datetime
6
  import plotly.graph_objects as go
7
  import plotly.express as px
 
 
8
 
9
  def download_data(ticker, start_date='2010-01-01'):
10
- """ 주식 데이터를 다운로드하고 포맷을 조정하는 함수 """
11
  data = yf.download(ticker, start=start_date)
12
  if data.empty:
13
  raise ValueError(f"No data returned for {ticker}")
@@ -15,38 +16,30 @@ def download_data(ticker, start_date='2010-01-01'):
15
  if 'Adj Close' in data.columns:
16
  data = data[['Date', 'Adj Close']].copy()
17
  data.rename(columns={'Date': 'ds', 'Adj Close': 'y'}, inplace=True)
18
- data['ds'] = pd.to_datetime(data['ds']) # Ensure 'ds' is datetime type
19
  else:
20
  raise ValueError("Expected 'Adj Close' in columns")
21
  return data
22
 
 
 
 
 
 
 
23
  def predict_future_prices(ticker, periods=1825):
24
  data = download_data(ticker)
25
-
26
- # Prophet 모델 생성 및 학습
27
  model = Prophet(daily_seasonality=False, weekly_seasonality=True, yearly_seasonality=True)
28
  model.fit(data)
29
-
30
- # 미래 데이터 프레임 생성 및 예측
31
  future = model.make_future_dataframe(periods=periods, freq='D')
32
  forecast = model.predict(future)
33
-
34
- # 예측 결과 그래프 생성
35
  fig_main = go.Figure()
36
  fig_main.add_trace(go.Scatter(x=data['ds'], y=data['y'], mode='lines', name='Actual (Black)', line=dict(color='black')))
37
  fig_main.add_trace(go.Scatter(x=forecast['ds'], y=forecast['yhat'], mode='lines', name='Forecast (Blue)'))
38
-
39
- # 연간 주간 계절성 그래프 생성
40
- # 임시로 matplotlib 플롯 반환을 비활성화하고 결과만 반환
41
- fig_seasonal = model.plot_components(forecast)
42
- forecast['ds'] = pd.to_datetime(forecast['ds']) # Revert to datetime type to avoid AttributeError
43
-
44
- fig_yearly = px.line(x=pd.to_datetime(forecast['ds']), y=forecast['yearly'], labels={'x': 'Date', 'y': 'Yearly Trend'})
45
- fig_weekly = px.line(x=pd.to_datetime(forecast['ds']), y=forecast['weekly'], labels={'x': 'Date', 'y': 'Weekly Trend'})
46
-
47
- return fig_main, forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']], fig_yearly, fig_weekly
48
 
49
- # Gradio 인터페이스 설정 및 실행
50
  with gr.Blocks() as app:
51
  with gr.Row():
52
  ticker_input = gr.Textbox(value="AAPL", label="Enter Stock Ticker for Forecast")
@@ -55,13 +48,11 @@ with gr.Blocks() as app:
55
 
56
  forecast_chart = gr.Plot(label="Forecast Chart")
57
  forecast_data = gr.Dataframe(label="Forecast Data")
58
- yearly_chart = gr.Plot(label="Yearly (Monthly) Trend Chart")
59
- weekly_chart = gr.Plot(label="Weekly Trend Chart")
60
 
61
  forecast_button.click(
62
  fn=predict_future_prices,
63
  inputs=[ticker_input, periods_input],
64
- outputs=[forecast_chart, forecast_data, yearly_chart, weekly_chart]
65
  )
66
 
67
  app.launch()
 
5
  from datetime import datetime
6
  import plotly.graph_objects as go
7
  import plotly.express as px
8
+ import json
9
+ import numpy as np
10
 
11
  def download_data(ticker, start_date='2010-01-01'):
 
12
  data = yf.download(ticker, start=start_date)
13
  if data.empty:
14
  raise ValueError(f"No data returned for {ticker}")
 
16
  if 'Adj Close' in data.columns:
17
  data = data[['Date', 'Adj Close']].copy()
18
  data.rename(columns={'Date': 'ds', 'Adj Close': 'y'}, inplace=True)
19
+ data['ds'] = pd.to_datetime(data['ds'])
20
  else:
21
  raise ValueError("Expected 'Adj Close' in columns")
22
  return data
23
 
24
+ def json_serial(obj):
25
+ """JSON serializer for objects not serializable by default json code"""
26
+ if isinstance(obj, (datetime, np.datetime64)):
27
+ return obj.isoformat()
28
+ raise TypeError("Type not serializable")
29
+
30
  def predict_future_prices(ticker, periods=1825):
31
  data = download_data(ticker)
 
 
32
  model = Prophet(daily_seasonality=False, weekly_seasonality=True, yearly_seasonality=True)
33
  model.fit(data)
 
 
34
  future = model.make_future_dataframe(periods=periods, freq='D')
35
  forecast = model.predict(future)
36
+ forecast['ds'] = forecast['ds'].apply(pd.to_datetime)
 
37
  fig_main = go.Figure()
38
  fig_main.add_trace(go.Scatter(x=data['ds'], y=data['y'], mode='lines', name='Actual (Black)', line=dict(color='black')))
39
  fig_main.add_trace(go.Scatter(x=forecast['ds'], y=forecast['yhat'], mode='lines', name='Forecast (Blue)'))
40
+ fig_main.update_layout(title="Forecast vs Actual Stock Prices")
41
+ return json.dumps({"figure": fig_main.to_json(), "forecast": forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']].to_dict('records')}, default=json_serial)
 
 
 
 
 
 
 
 
42
 
 
43
  with gr.Blocks() as app:
44
  with gr.Row():
45
  ticker_input = gr.Textbox(value="AAPL", label="Enter Stock Ticker for Forecast")
 
48
 
49
  forecast_chart = gr.Plot(label="Forecast Chart")
50
  forecast_data = gr.Dataframe(label="Forecast Data")
 
 
51
 
52
  forecast_button.click(
53
  fn=predict_future_prices,
54
  inputs=[ticker_input, periods_input],
55
+ outputs=[forecast_chart, forecast_data]
56
  )
57
 
58
  app.launch()