aiqtech commited on
Commit
a8076e2
·
verified ·
1 Parent(s): bc387e2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -17
app.py CHANGED
@@ -2,11 +2,7 @@ import gradio as gr
2
  import yfinance as yf
3
  from prophet import Prophet
4
  import pandas as pd
5
- from datetime import datetime
6
  import plotly.graph_objects as go
7
- import plotly.express as px
8
- import json
9
- import numpy as np
10
 
11
  def download_data(ticker, start_date='2010-01-01'):
12
  data = yf.download(ticker, start=start_date)
@@ -16,29 +12,24 @@ def download_data(ticker, start_date='2010-01-01'):
16
  if 'Adj Close' in data.columns:
17
  data = data[['Date', 'Adj Close']].copy()
18
  data.rename(columns={'Date': 'ds', 'Adj Close': 'y'}, inplace=True)
19
- data['ds'] = pd.to_datetime(data['ds'])
20
  else:
21
  raise ValueError("Expected 'Adj Close' in columns")
 
22
  return data
23
 
24
- def json_serial(obj):
25
- """JSON serializer for objects not serializable by default json code"""
26
- if isinstance(obj, (datetime, np.datetime64)):
27
- return obj.isoformat()
28
- raise TypeError("Type not serializable")
29
-
30
  def predict_future_prices(ticker, periods=1825):
31
  data = download_data(ticker)
32
  model = Prophet(daily_seasonality=False, weekly_seasonality=True, yearly_seasonality=True)
33
  model.fit(data)
34
  future = model.make_future_dataframe(periods=periods, freq='D')
35
  forecast = model.predict(future)
36
- forecast['ds'] = forecast['ds'].apply(pd.to_datetime)
37
- fig_main = go.Figure()
38
- fig_main.add_trace(go.Scatter(x=data['ds'], y=data['y'], mode='lines', name='Actual (Black)', line=dict(color='black')))
39
- fig_main.add_trace(go.Scatter(x=forecast['ds'], y=forecast['yhat'], mode='lines', name='Forecast (Blue)'))
40
- fig_main.update_layout(title="Forecast vs Actual Stock Prices")
41
- return json.dumps({"figure": fig_main.to_json(), "forecast": forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']].to_dict('records')}, default=json_serial)
 
42
 
43
  with gr.Blocks() as app:
44
  with gr.Row():
 
2
  import yfinance as yf
3
  from prophet import Prophet
4
  import pandas as pd
 
5
  import plotly.graph_objects as go
 
 
 
6
 
7
  def download_data(ticker, start_date='2010-01-01'):
8
  data = yf.download(ticker, start=start_date)
 
12
  if 'Adj Close' in data.columns:
13
  data = data[['Date', 'Adj Close']].copy()
14
  data.rename(columns={'Date': 'ds', 'Adj Close': 'y'}, inplace=True)
 
15
  else:
16
  raise ValueError("Expected 'Adj Close' in columns")
17
+ data['ds'] = pd.to_datetime(data['ds'])
18
  return data
19
 
 
 
 
 
 
 
20
  def predict_future_prices(ticker, periods=1825):
21
  data = download_data(ticker)
22
  model = Prophet(daily_seasonality=False, weekly_seasonality=True, yearly_seasonality=True)
23
  model.fit(data)
24
  future = model.make_future_dataframe(periods=periods, freq='D')
25
  forecast = model.predict(future)
26
+
27
+ fig = go.Figure()
28
+ fig.add_trace(go.Scatter(x=data['ds'], y=data['y'], mode='lines', name='Actual (Black)', line=dict(color='black')))
29
+ fig.add_trace(go.Scatter(x=forecast['ds'], y=forecast['yhat'], mode='lines', name='Forecast (Blue)'))
30
+
31
+ # Make sure to return both the figure and the forecast data
32
+ return fig, forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']]
33
 
34
  with gr.Blocks() as app:
35
  with gr.Row():