aiqtech commited on
Commit
f05802f
·
verified ·
1 Parent(s): f448ee1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -20
app.py CHANGED
@@ -1,11 +1,10 @@
1
  import gradio as gr
2
  import yfinance as yf
3
  from prophet import Prophet
4
- import plotly.express as px
5
  from datetime import datetime
6
  import plotly.graph_objects as go
7
 
8
-
9
  def download_data(ticker, start_date='2010-01-01'):
10
  """ 데이터를 다운로드하고 적절히 포맷팅하는 함수 """
11
  data = yf.download(ticker, start=start_date)
@@ -20,38 +19,33 @@ def download_data(ticker, start_date='2010-01-01'):
20
  return data
21
 
22
  def predict_future_prices(ticker, periods=1825): # 5년간의 데이터 예측
23
- data = yf.download(ticker, start="2010-01-01")['Adj Close'].reset_index()
24
- data.rename(columns={'Date': 'ds', 'Adj Close': 'y'}, inplace=True)
25
 
 
26
  model = Prophet(daily_seasonality=False, weekly_seasonality=False, yearly_seasonality=True)
27
  model.fit(data)
28
 
 
29
  future = model.make_future_dataframe(periods=periods, freq='D')
30
  forecast = model.predict(future)
31
 
32
- # Timestamp를 문자열로 변환
33
- forecast['ds'] = forecast['ds'].dt.strftime('%Y-%m-%d')
34
-
35
  fig = go.Figure(data=[go.Scatter(x=forecast['ds'], y=forecast['yhat'], mode='lines', name='Forecast')])
36
- fig.update_layout(title=f'5-Year Future Price Forecast for {ticker}', xaxis_title='Date', yaxis_title='Predicted Price')
37
-
38
- return fig, forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']].to_dict(orient='records')
39
 
40
- # 이제 이 함수를 Gradio 인터페이스와 연결하여 사용자가 시각화를 볼 수 있게 합니다.
41
-
42
-
43
- # Gradio 인터페이스 설정
44
  with gr.Blocks() as app:
45
  with gr.Row():
46
- ticker_input = gr.Textbox(value="AAPL", label="Enter Stock Ticker")
47
- forecast_button = gr.Button("Generate 5-Year Forecast")
48
-
49
- forecast_chart = gr.Plot(label="Future Price Forecast")
50
- forecast_data = gr.Dataframe(label="Forecast Data Table")
 
51
 
52
  forecast_button.click(
53
  fn=predict_future_prices,
54
- inputs=ticker_input,
55
  outputs=[forecast_chart, forecast_data]
56
  )
57
 
 
1
  import gradio as gr
2
  import yfinance as yf
3
  from prophet import Prophet
4
+ import pandas as pd
5
  from datetime import datetime
6
  import plotly.graph_objects as go
7
 
 
8
  def download_data(ticker, start_date='2010-01-01'):
9
  """ 데이터를 다운로드하고 적절히 포맷팅하는 함수 """
10
  data = yf.download(ticker, start=start_date)
 
19
  return data
20
 
21
  def predict_future_prices(ticker, periods=1825): # 5년간의 데이터 예측
22
+ data = download_data(ticker)
 
23
 
24
+ # Prophet 모델 생성 및 훈련
25
  model = Prophet(daily_seasonality=False, weekly_seasonality=False, yearly_seasonality=True)
26
  model.fit(data)
27
 
28
+ # 미래 데이터프레임 생성 및 예측
29
  future = model.make_future_dataframe(periods=periods, freq='D')
30
  forecast = model.predict(future)
31
 
32
+ # 예측 결과를 그래프로 표현
 
 
33
  fig = go.Figure(data=[go.Scatter(x=forecast['ds'], y=forecast['yhat'], mode='lines', name='Forecast')])
34
+ return fig, forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']] # Pydantic과 같은 엄격한 형식 검사 없이 DataFrame으로 직접 전달
 
 
35
 
36
+ # Gradio 인터페이스 설정 실행
 
 
 
37
  with gr.Blocks() as app:
38
  with gr.Row():
39
+ ticker_input = gr.Textbox(value="AAPL", label="Enter Stock Ticker for Forecast")
40
+ periods_input = gr.Number(value=1825, label="Forecast Period (days)")
41
+ forecast_button = gr.Button("Generate Forecast")
42
+
43
+ forecast_chart = gr.Plot(label="Forecast Chart")
44
+ forecast_data = gr.Dataframe(label="Forecast Data")
45
 
46
  forecast_button.click(
47
  fn=predict_future_prices,
48
+ inputs=[ticker_input, periods_input],
49
  outputs=[forecast_chart, forecast_data]
50
  )
51