Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import yfinance as yf
|
2 |
+
import pandas as pd
|
3 |
+
import plotly.express as px
|
4 |
+
import plotly.graph_objects as go
|
5 |
+
import gradio as gr
|
6 |
+
import timesfm
|
7 |
+
|
8 |
+
# Function to fetch stock data, generate forecast, and create an interactive plot
|
9 |
+
def stock_forecast(ticker, start_date, end_date):
|
10 |
+
try:
|
11 |
+
# Fetch historical data
|
12 |
+
stock_data = yf.download(ticker, start=start_date, end=end_date)
|
13 |
+
|
14 |
+
# If the DataFrame has a MultiIndex for columns, drop the 'Ticker' level
|
15 |
+
if isinstance(stock_data.columns, pd.MultiIndex):
|
16 |
+
stock_data.columns = stock_data.columns.droplevel(level=1)
|
17 |
+
|
18 |
+
# Explicitly set column names
|
19 |
+
stock_data.columns = ['Close', 'High', 'Low', 'Open', 'Volume']
|
20 |
+
|
21 |
+
# Reset index to have 'Date' as a column
|
22 |
+
stock_data.reset_index(inplace=True)
|
23 |
+
|
24 |
+
# Select relevant columns and rename them
|
25 |
+
df = stock_data[['Date', 'Close']].rename(columns={'Date': 'ds', 'Close': 'y'})
|
26 |
+
|
27 |
+
# Ensure the dates are in datetime format
|
28 |
+
df['ds'] = pd.to_datetime(df['ds'])
|
29 |
+
|
30 |
+
# Add a unique identifier for the time series
|
31 |
+
df['unique_id'] = ticker
|
32 |
+
|
33 |
+
# Initialize the TimesFM model
|
34 |
+
tfm = timesfm.TimesFm(
|
35 |
+
hparams=timesfm.TimesFmHparams(
|
36 |
+
backend="pytorch",
|
37 |
+
per_core_batch_size=32,
|
38 |
+
horizon_len=30, # Predicting the next 30 days
|
39 |
+
input_patch_len=32,
|
40 |
+
output_patch_len=128,
|
41 |
+
num_layers=50,
|
42 |
+
model_dims=1280,
|
43 |
+
use_positional_embedding=False,
|
44 |
+
),
|
45 |
+
checkpoint=timesfm.TimesFmCheckpoint(
|
46 |
+
huggingface_repo_id="google/timesfm-2.0-500m-pytorch"
|
47 |
+
),
|
48 |
+
)
|
49 |
+
|
50 |
+
# Forecast using the prepared DataFrame
|
51 |
+
forecast_df = tfm.forecast_on_df(
|
52 |
+
inputs=df,
|
53 |
+
freq="D", # Daily frequency
|
54 |
+
value_name="y",
|
55 |
+
num_jobs=-1,
|
56 |
+
)
|
57 |
+
|
58 |
+
# Combine actual and forecasted data for plotting
|
59 |
+
combined_df = pd.concat([df, forecast_df], axis=0)
|
60 |
+
|
61 |
+
# Create an interactive plot with Plotly
|
62 |
+
fig = px.line(combined_df, x='ds', y=['y', 'timesfm'], labels={'value': 'Price', 'ds': 'Date'},
|
63 |
+
title=f'{ticker} Stock Price Forecast')
|
64 |
+
|
65 |
+
# Enhance interactivity
|
66 |
+
fig.update_layout(
|
67 |
+
legend_title_text='Type',
|
68 |
+
hovermode='x unified', # Show hover info for all series at the same x-value
|
69 |
+
xaxis=dict(
|
70 |
+
rangeselector=dict(
|
71 |
+
buttons=list([
|
72 |
+
dict(count=1, label="1m", step="month", stepmode="backward"),
|
73 |
+
dict(count=6, label="6m", step="month", stepmode="backward"),
|
74 |
+
dict(count=1, label="YTD", step="year", stepmode="todate"),
|
75 |
+
dict(count=1, label="1y", step="year", stepmode="backward"),
|
76 |
+
dict(step="all")
|
77 |
+
])
|
78 |
+
),
|
79 |
+
rangeslider=dict(visible=True), # Add a range slider
|
80 |
+
type="date"
|
81 |
+
),
|
82 |
+
yaxis=dict(title="Price (USD)")
|
83 |
+
)
|
84 |
+
|
85 |
+
# Add custom hover data
|
86 |
+
fig.update_traces(
|
87 |
+
hovertemplate="<b>Date:</b> %{x}<br><b>Price:</b> %{y:.2f}<extra></extra>"
|
88 |
+
)
|
89 |
+
|
90 |
+
return fig
|
91 |
+
|
92 |
+
except Exception as e:
|
93 |
+
# Return an empty Plotly figure with an error message
|
94 |
+
error_fig = go.Figure()
|
95 |
+
error_fig.update_layout(
|
96 |
+
title=f"Error: {str(e)}",
|
97 |
+
xaxis_title="Date",
|
98 |
+
yaxis_title="Price",
|
99 |
+
annotations=[dict(text="No data available", x=0.5, y=0.5, showarrow=False)]
|
100 |
+
)
|
101 |
+
return error_fig
|
102 |
+
|
103 |
+
# Create Gradio interface with an "Enter" button
|
104 |
+
with gr.Blocks() as demo:
|
105 |
+
gr.Markdown("# Stock Price Forecast App")
|
106 |
+
gr.Markdown("Enter a stock ticker, start date, and end date to visualize historical and forecasted stock prices.")
|
107 |
+
|
108 |
+
with gr.Row():
|
109 |
+
ticker_input = gr.Textbox(label="Enter Stock Ticker", value="AAPL")
|
110 |
+
start_date_input = gr.Textbox(label="Enter Start Date (YYYY-MM-DD)", value="2022-01-01")
|
111 |
+
end_date_input = gr.Textbox(label="Enter End Date (YYYY-MM-DD)", value="2025-01-01")
|
112 |
+
|
113 |
+
submit_button = gr.Button("Enter")
|
114 |
+
plot_output = gr.Plot()
|
115 |
+
|
116 |
+
# Link the button to the function
|
117 |
+
submit_button.click(
|
118 |
+
stock_forecast,
|
119 |
+
inputs=[ticker_input, start_date_input, end_date_input],
|
120 |
+
outputs=plot_output
|
121 |
+
)
|
122 |
+
|
123 |
+
# Launch the Gradio app
|
124 |
+
demo.launch()
|