JayLacoma commited on
Commit
804260b
·
verified ·
1 Parent(s): ad0d830

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -0
app.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import yfinance as yf
2
+ import pandas as pd
3
+ import plotly.express as px
4
+ import plotly.graph_objects as go
5
+ import gradio as gr
6
+ import timesfm
7
+
8
+ # Function to fetch stock data, generate forecast, and create an interactive plot
9
+ def stock_forecast(ticker, start_date, end_date):
10
+ try:
11
+ # Fetch historical data
12
+ stock_data = yf.download(ticker, start=start_date, end=end_date)
13
+
14
+ # If the DataFrame has a MultiIndex for columns, drop the 'Ticker' level
15
+ if isinstance(stock_data.columns, pd.MultiIndex):
16
+ stock_data.columns = stock_data.columns.droplevel(level=1)
17
+
18
+ # Explicitly set column names
19
+ stock_data.columns = ['Close', 'High', 'Low', 'Open', 'Volume']
20
+
21
+ # Reset index to have 'Date' as a column
22
+ stock_data.reset_index(inplace=True)
23
+
24
+ # Select relevant columns and rename them
25
+ df = stock_data[['Date', 'Close']].rename(columns={'Date': 'ds', 'Close': 'y'})
26
+
27
+ # Ensure the dates are in datetime format
28
+ df['ds'] = pd.to_datetime(df['ds'])
29
+
30
+ # Add a unique identifier for the time series
31
+ df['unique_id'] = ticker
32
+
33
+ # Initialize the TimesFM model
34
+ tfm = timesfm.TimesFm(
35
+ hparams=timesfm.TimesFmHparams(
36
+ backend="pytorch",
37
+ per_core_batch_size=32,
38
+ horizon_len=30, # Predicting the next 30 days
39
+ input_patch_len=32,
40
+ output_patch_len=128,
41
+ num_layers=50,
42
+ model_dims=1280,
43
+ use_positional_embedding=False,
44
+ ),
45
+ checkpoint=timesfm.TimesFmCheckpoint(
46
+ huggingface_repo_id="google/timesfm-2.0-500m-pytorch"
47
+ ),
48
+ )
49
+
50
+ # Forecast using the prepared DataFrame
51
+ forecast_df = tfm.forecast_on_df(
52
+ inputs=df,
53
+ freq="D", # Daily frequency
54
+ value_name="y",
55
+ num_jobs=-1,
56
+ )
57
+
58
+ # Combine actual and forecasted data for plotting
59
+ combined_df = pd.concat([df, forecast_df], axis=0)
60
+
61
+ # Create an interactive plot with Plotly
62
+ fig = px.line(combined_df, x='ds', y=['y', 'timesfm'], labels={'value': 'Price', 'ds': 'Date'},
63
+ title=f'{ticker} Stock Price Forecast')
64
+
65
+ # Enhance interactivity
66
+ fig.update_layout(
67
+ legend_title_text='Type',
68
+ hovermode='x unified', # Show hover info for all series at the same x-value
69
+ xaxis=dict(
70
+ rangeselector=dict(
71
+ buttons=list([
72
+ dict(count=1, label="1m", step="month", stepmode="backward"),
73
+ dict(count=6, label="6m", step="month", stepmode="backward"),
74
+ dict(count=1, label="YTD", step="year", stepmode="todate"),
75
+ dict(count=1, label="1y", step="year", stepmode="backward"),
76
+ dict(step="all")
77
+ ])
78
+ ),
79
+ rangeslider=dict(visible=True), # Add a range slider
80
+ type="date"
81
+ ),
82
+ yaxis=dict(title="Price (USD)")
83
+ )
84
+
85
+ # Add custom hover data
86
+ fig.update_traces(
87
+ hovertemplate="<b>Date:</b> %{x}<br><b>Price:</b> %{y:.2f}<extra></extra>"
88
+ )
89
+
90
+ return fig
91
+
92
+ except Exception as e:
93
+ # Return an empty Plotly figure with an error message
94
+ error_fig = go.Figure()
95
+ error_fig.update_layout(
96
+ title=f"Error: {str(e)}",
97
+ xaxis_title="Date",
98
+ yaxis_title="Price",
99
+ annotations=[dict(text="No data available", x=0.5, y=0.5, showarrow=False)]
100
+ )
101
+ return error_fig
102
+
103
+ # Create Gradio interface with an "Enter" button
104
+ with gr.Blocks() as demo:
105
+ gr.Markdown("# Stock Price Forecast App")
106
+ gr.Markdown("Enter a stock ticker, start date, and end date to visualize historical and forecasted stock prices.")
107
+
108
+ with gr.Row():
109
+ ticker_input = gr.Textbox(label="Enter Stock Ticker", value="AAPL")
110
+ start_date_input = gr.Textbox(label="Enter Start Date (YYYY-MM-DD)", value="2022-01-01")
111
+ end_date_input = gr.Textbox(label="Enter End Date (YYYY-MM-DD)", value="2025-01-01")
112
+
113
+ submit_button = gr.Button("Enter")
114
+ plot_output = gr.Plot()
115
+
116
+ # Link the button to the function
117
+ submit_button.click(
118
+ stock_forecast,
119
+ inputs=[ticker_input, start_date_input, end_date_input],
120
+ outputs=plot_output
121
+ )
122
+
123
+ # Launch the Gradio app
124
+ demo.launch()