File size: 1,892 Bytes
23c88d4
 
0d01ec2
f05802f
7185b69
 
696277a
 
 
 
 
 
1fb9abf
696277a
 
 
a8076e2
696277a
1238274
9c77d88
f05802f
fa183af
0d01ec2
240671f
0d01ec2
a8076e2
 
 
 
 
 
 
a5272ad
4db4956
 
f05802f
 
 
 
 
 
4db4956
0d01ec2
 
f05802f
bc387e2
0d01ec2
a5272ad
4db4956
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import gradio as gr
import yfinance as yf
from prophet import Prophet
import pandas as pd
import plotly.graph_objects as go

def download_data(ticker, start_date='2010-01-01'):
    """Download daily prices for *ticker* and shape them for Prophet.

    Args:
        ticker: Yahoo Finance symbol, e.g. ``"AAPL"``.
        start_date: First date to request, as a ``YYYY-MM-DD`` string.

    Returns:
        A DataFrame with a default RangeIndex and exactly two columns:
        ``ds`` (timezone-naive datetimes) and ``y`` (the adjusted close).

    Raises:
        ValueError: if yfinance returns no rows, or the response has no
            ``'Adj Close'`` column.
    """
    # Ask for unadjusted OHLC explicitly: newer yfinance releases default to
    # auto_adjust=True, which omits the 'Adj Close' column used below.
    # progress=False keeps the console progress bar out of the server log.
    raw = yf.download(ticker, start=start_date, auto_adjust=False, progress=False)
    if raw is None or raw.empty:
        raise ValueError(f"No data returned for {ticker}")
    return _to_prophet_frame(raw)


def _to_prophet_frame(raw):
    """Extract the 'Adj Close' series of a price frame as Prophet's (ds, y) layout."""
    columns = raw.columns
    # Recent yfinance versions return (field, ticker) MultiIndex columns even
    # for a single symbol; only the field level is relevant here.
    fields = columns.get_level_values(0) if isinstance(columns, pd.MultiIndex) else columns
    if 'Adj Close' not in fields:
        raise ValueError("Expected 'Adj Close' in columns")
    close = raw['Adj Close']
    if isinstance(close, pd.DataFrame):
        # MultiIndex case: one column per ticker; only one ticker was requested.
        close = close.iloc[:, 0]
    ds = pd.to_datetime(close.index)
    # Prophet refuses a timezone-aware ds column, so drop any tz info.
    if ds.tz is not None:
        ds = ds.tz_localize(None)
    return pd.DataFrame({'ds': ds, 'y': close.to_numpy()})

def predict_future_prices(ticker, periods=1825):
    """Fit a Prophet model to *ticker*'s price history and forecast ahead.

    Args:
        ticker: Yahoo Finance symbol; surrounding whitespace is ignored.
        periods: Number of days (``freq='D'``) to extend past the last
            observed date. A Gradio ``gr.Number`` may deliver a float, or
            ``None`` when the box is cleared, so the value is coerced to an
            int (``None`` falls back to 1825).

    Returns:
        ``(fig, table)``: a Plotly figure with the actual series in black and
        the forecast in blue, and a DataFrame with columns ``ds``, ``yhat``,
        ``yhat_lower`` and ``yhat_upper`` for every row Prophet predicted.

    Raises:
        ValueError: if the ticker is blank, ``periods`` is negative, or
            ``download_data`` finds no usable data.
    """
    ticker = str(ticker or '').strip()
    if not ticker:
        raise ValueError("Please enter a stock ticker")
    # make_future_dataframe hands `periods` to pandas.date_range, which
    # expects an integer count.
    periods = 1825 if periods is None else int(periods)
    if periods < 0:
        raise ValueError(f"Forecast period must be non-negative, got {periods}")

    data = download_data(ticker)
    model = Prophet(daily_seasonality=False, weekly_seasonality=True, yearly_seasonality=True)
    model.fit(data)
    # freq='D' steps by calendar day, so the horizon includes weekends even
    # though the downloaded history contains only trading days.
    future = model.make_future_dataframe(periods=periods, freq='D')
    forecast = model.predict(future)

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=data['ds'], y=data['y'], mode='lines',
                             name='Actual (Black)', line=dict(color='black')))
    # Pin the colour so the trace matches its legend label; without it Plotly
    # assigns one from the template colorway, which need not be blue.
    fig.add_trace(go.Scatter(x=forecast['ds'], y=forecast['yhat'], mode='lines',
                             name='Forecast (Blue)', line=dict(color='blue')))

    # Return both the figure and the forecast table (with uncertainty bounds).
    return fig, forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']]

# Gradio UI: ticker + horizon inputs, a button that runs predict_future_prices,
# and a chart/table pair that receive its two return values.
with gr.Blocks() as app:
    with gr.Row():
        ticker_input = gr.Textbox(value="AAPL", label="Enter Stock Ticker for Forecast")
        # precision=0 makes Gradio hand the callback an int instead of a float.
        periods_input = gr.Number(value=1825, label="Forecast Period (days)", precision=0)
        forecast_button = gr.Button("Generate Forecast")

    forecast_chart = gr.Plot(label="Forecast Chart")
    forecast_data = gr.Dataframe(label="Forecast Data")

    forecast_button.click(
        fn=predict_future_prices,
        inputs=[ticker_input, periods_input],
        outputs=[forecast_chart, forecast_data],
    )

# Start the server only when executed as a script, so importing this module
# (tests, reload tooling) does not block on launch().
if __name__ == "__main__":
    app.launch()