Spaces:
Running
on
T4
Running
on
T4
Nikita
commited on
Commit
·
76a84b9
1
Parent(s):
3272ebc
waiting for the review
Browse files- app.py +286 -0
- static/nxai_logo.png +0 -0
- static/tirex.jpeg +0 -0
- stocks_data.csv +11 -0
- stocks_data_forecast.pt +0 -0
- stocks_data_noindex.csv +11 -0
- test.ipynb +1682 -0
- tirex +1 -0
app.py
ADDED
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import pandas as pd
|
3 |
+
import torch
|
4 |
+
import plotly.graph_objects as go
|
5 |
+
from PIL import Image
|
6 |
+
import numpy as np
|
7 |
+
import gradio as gr
|
8 |
+
|
9 |
+
# ----------------------------
|
10 |
+
# Helper functions (logic mostly unchanged)
|
11 |
+
# ----------------------------
|
12 |
+
|
13 |
+
torch.manual_seed(42)
|
14 |
+
_forecast_tensor = torch.load("stocks_data_forecast.pt") # shape = (n_series, pred_len, n_q)
|
15 |
+
|
16 |
+
|
17 |
+
def model_forecast(input_data):
|
18 |
+
return _forecast_tensor
|
19 |
+
|
20 |
+
|
21 |
+
def plot_forecast_plotly(timeseries, quantile_predictions, timeseries_name):
|
22 |
+
# Create an interactive Plotly figure
|
23 |
+
fig = go.Figure()
|
24 |
+
x_hist = list(range(len(timeseries)))
|
25 |
+
# Historical data trace
|
26 |
+
fig.add_trace(go.Scatter(
|
27 |
+
x=x_hist,
|
28 |
+
y=timeseries,
|
29 |
+
mode='lines+markers',
|
30 |
+
name=f"{timeseries_name} - Given Data",
|
31 |
+
line=dict(width=2),
|
32 |
+
))
|
33 |
+
|
34 |
+
# Prediction data traces for each quantile
|
35 |
+
x_pred = list(range(len(timeseries) - 1, len(timeseries) - 1 + len(quantile_predictions)))
|
36 |
+
for i in range(quantile_predictions.shape[1]):
|
37 |
+
fig.add_trace(go.Scatter(
|
38 |
+
x=x_pred,
|
39 |
+
y=quantile_predictions[:, i],
|
40 |
+
mode='lines',
|
41 |
+
name=f"{timeseries_name} - Quantile {i+1}",
|
42 |
+
opacity=0.8,
|
43 |
+
))
|
44 |
+
|
45 |
+
fig.update_layout(
|
46 |
+
title=dict(text=f"Timeseries: {timeseries_name}", x=0.5, font=dict(size=16, family="Arial", color="#000")),
|
47 |
+
xaxis_title="Time",
|
48 |
+
yaxis_title="Value",
|
49 |
+
hovermode='x unified'
|
50 |
+
)
|
51 |
+
return fig
|
52 |
+
|
53 |
+
|
54 |
+
def load_table(file_path):
|
55 |
+
ext = file_path.split(".")[-1].lower()
|
56 |
+
if ext == "csv":
|
57 |
+
return pd.read_csv(file_path)
|
58 |
+
elif ext in ("xls", "xlsx"):
|
59 |
+
return pd.read_excel(file_path)
|
60 |
+
elif ext == "parquet":
|
61 |
+
return pd.read_parquet(file_path)
|
62 |
+
else:
|
63 |
+
raise ValueError("Unsupported format. Use CSV, XLS, XLSX, or PARQUET.")
|
64 |
+
|
65 |
+
|
66 |
+
def extract_names_and_update(file, preset_filename):
|
67 |
+
try:
|
68 |
+
if file is not None:
|
69 |
+
df = load_table(file.name)
|
70 |
+
else:
|
71 |
+
if not preset_filename:
|
72 |
+
return gr.update(choices=[], value=[]), []
|
73 |
+
df = load_table(preset_filename)
|
74 |
+
|
75 |
+
if df.shape[1] > 0 and df.iloc[:, 0].dtype == object and not df.iloc[:, 0].str.isnumeric().all():
|
76 |
+
names = df.iloc[:, 0].tolist()
|
77 |
+
else:
|
78 |
+
names = [f"Series {i}" for i in range(len(df))]
|
79 |
+
return gr.update(choices=names, value=names), names
|
80 |
+
except Exception:
|
81 |
+
return gr.update(choices=[], value=[]), []
|
82 |
+
|
83 |
+
|
84 |
+
def filter_names(search_term, all_names):
|
85 |
+
if not all_names:
|
86 |
+
return gr.update(choices=[], value=[])
|
87 |
+
if not search_term:
|
88 |
+
return gr.update(choices=all_names, value=all_names)
|
89 |
+
lower = search_term.lower()
|
90 |
+
filtered = [n for n in all_names if lower in str(n).lower()]
|
91 |
+
return gr.update(choices=filtered, value=filtered)
|
92 |
+
|
93 |
+
|
94 |
+
def check_all(names_list):
|
95 |
+
return gr.update(value=names_list)
|
96 |
+
|
97 |
+
|
98 |
+
def uncheck_all(_):
|
99 |
+
return gr.update(value=[])
|
100 |
+
|
101 |
+
|
102 |
+
def display_filtered_forecast(file, preset_filename, selected_names):
|
103 |
+
try:
|
104 |
+
# If no file uploaded and no valid preset chosen, return early
|
105 |
+
if file is None and (preset_filename is None or preset_filename == "-- No preset selected --"):
|
106 |
+
return None, "No file selected."
|
107 |
+
|
108 |
+
# Load data
|
109 |
+
if file is not None:
|
110 |
+
df = load_table(file.name)
|
111 |
+
else:
|
112 |
+
df = load_table(preset_filename)
|
113 |
+
|
114 |
+
if df.shape[1] > 0 and df.iloc[:, 0].dtype == object and not df.iloc[:, 0].str.isnumeric().all():
|
115 |
+
all_names = df.iloc[:, 0].tolist()
|
116 |
+
data_only = df.iloc[:, 1:].astype(float)
|
117 |
+
else:
|
118 |
+
all_names = [f"Series {i}" for i in range(len(df))]
|
119 |
+
data_only = df.astype(float)
|
120 |
+
|
121 |
+
mask = [name in selected_names for name in all_names]
|
122 |
+
if not any(mask):
|
123 |
+
return None, "No timeseries chosen to plot."
|
124 |
+
|
125 |
+
filtered_data = data_only.iloc[mask, :].values
|
126 |
+
filtered_names = [all_names[i] for i, m in enumerate(mask) if m]
|
127 |
+
out = _forecast_tensor[mask] # slice forecasts to match filtered rows
|
128 |
+
inp = torch.tensor(filtered_data)
|
129 |
+
|
130 |
+
# If multiple series selected, create a subplot for each in a single figure
|
131 |
+
fig = go.Figure()
|
132 |
+
for idx in range(inp.shape[0]):
|
133 |
+
ts = inp[idx].numpy().tolist()
|
134 |
+
qp = out[idx].numpy()
|
135 |
+
series_name = filtered_names[idx]
|
136 |
+
x_hist = list(range(len(ts)))
|
137 |
+
# Historical data
|
138 |
+
fig.add_trace(go.Scatter(
|
139 |
+
x=x_hist,
|
140 |
+
y=ts,
|
141 |
+
mode='lines+markers',
|
142 |
+
name=f"{series_name} - Given Data"
|
143 |
+
))
|
144 |
+
# Quantiles
|
145 |
+
x_pred = list(range(len(ts) - 1, len(ts) - 1 + qp.shape[0]))
|
146 |
+
for i in range(qp.shape[1]):
|
147 |
+
fig.add_trace(go.Scatter(
|
148 |
+
x=x_pred,
|
149 |
+
y=qp[:, i],
|
150 |
+
mode='lines',
|
151 |
+
name=f"{series_name} - Quantile {i+1}",
|
152 |
+
opacity=0.6
|
153 |
+
))
|
154 |
+
|
155 |
+
fig.update_layout(
|
156 |
+
title=dict(text="Forecasts for Selected Timeseries", x=0.5, font=dict(size=16, family="Arial", color="#000")),
|
157 |
+
xaxis_title="Time",
|
158 |
+
yaxis_title="Value",
|
159 |
+
hovermode='x unified'
|
160 |
+
)
|
161 |
+
return fig, ""
|
162 |
+
except Exception as e:
|
163 |
+
return None, f"Error: {e}. Use CSV, XLS, XLSX, or PARQUET."
|
164 |
+
|
165 |
+
|
166 |
+
# ----------------------------
|
167 |
+
# Gradio layout: two columns + instructions
|
168 |
+
# ----------------------------
|
169 |
+
|
170 |
+
with gr.Blocks() as demo:
|
171 |
+
gr.Markdown("# 📈 TiRex - timeseries forecasting 📊")
|
172 |
+
gr.Markdown("Upload data or choose a preset, filter by name, then click Plot.")
|
173 |
+
|
174 |
+
with gr.Row():
|
175 |
+
# Left column: controls
|
176 |
+
with gr.Column():
|
177 |
+
gr.Markdown("## Data Selection")
|
178 |
+
file_input = gr.File(
|
179 |
+
label="Upload CSV / XLSX / PARQUET",
|
180 |
+
file_types=[".csv", ".xls", ".xlsx", ".parquet"]
|
181 |
+
)
|
182 |
+
preset_choices = ["-- No preset selected --", "stocks_data_noindex.csv", "stocks_data.csv"]
|
183 |
+
|
184 |
+
preset_dropdown = gr.Dropdown(
|
185 |
+
label="Or choose a preset:",
|
186 |
+
choices=preset_choices,
|
187 |
+
value="-- No preset selected --"
|
188 |
+
)
|
189 |
+
|
190 |
+
|
191 |
+
|
192 |
+
gr.Markdown("## Search / Filter")
|
193 |
+
search_box = gr.Textbox(placeholder="Type to filter (e.g. 'AMZN')")
|
194 |
+
filter_checkbox = gr.CheckboxGroup(
|
195 |
+
choices=[], value=[], label="Select which timeseries to show"
|
196 |
+
)
|
197 |
+
|
198 |
+
with gr.Row():
|
199 |
+
check_all_btn = gr.Button("✅ Check All")
|
200 |
+
uncheck_all_btn = gr.Button("❎ Uncheck All")
|
201 |
+
|
202 |
+
plot_button = gr.Button("▶️ Plot Forecasts")
|
203 |
+
errbox = gr.Textbox(label="Error Message", interactive=False)
|
204 |
+
with gr.Row():
|
205 |
+
gr.Image("static/nxai_logo.png", width=150, show_label=False, container=False)
|
206 |
+
gr.Image("static/tirex.jpeg", width=150, show_label=False, container=False)
|
207 |
+
|
208 |
+
# Right column: interactive plot + instructions
|
209 |
+
with gr.Column():
|
210 |
+
gr.Markdown("## Forecast Plot")
|
211 |
+
plot_output = gr.Plot()
|
212 |
+
|
213 |
+
# Instruction text below plot
|
214 |
+
gr.Markdown("## Instructions")
|
215 |
+
gr.Markdown(
|
216 |
+
"""
|
217 |
+
**How to format your data:**
|
218 |
+
- Your file must be a table (CSV, XLS, XLSX, or Parquet).
|
219 |
+
- **One row per timeseries.** Each row is treated as a separate series.
|
220 |
+
- If you want to **name** each series, put the name as the first value in **every** row:
|
221 |
+
- Example (CSV):
|
222 |
+
`AAPL, 120.5, 121.0, 119.8, ...`
|
223 |
+
`AMZN, 3300.0, 3310.5, 3295.2, ...`
|
224 |
+
- In that case, the first column is not numeric, so it will be used as the series name.
|
225 |
+
- If you do **not** want named series, simply leave out the first column entirely and have all values numeric:
|
226 |
+
- Example:
|
227 |
+
`120.5, 121.0, 119.8, ...`
|
228 |
+
`3300.0, 3310.5, 3295.2, ...`
|
229 |
+
- Then every row will be auto-named “Series 0, Series 1, …” in order.
|
230 |
+
- **Consistency rule:** Either all rows have a non-numeric first entry for the name, or none do. Do not mix.
|
231 |
+
- The rest of the columns (after the optional name) must be numeric data points for that series.
|
232 |
+
- You can filter by typing in the search box. Then check or uncheck individual names before plotting.
|
233 |
+
- Use “Check All” / “Uncheck All” to quickly select or deselect every series.
|
234 |
+
- Finally, click **Plot Forecasts** to view the quantile forecast for each selected series (for 64 steps ahead).
|
235 |
+
"""
|
236 |
+
)
|
237 |
+
|
238 |
+
names_state = gr.State([])
|
239 |
+
|
240 |
+
# When file or preset changes, update names
|
241 |
+
file_input.change(
|
242 |
+
fn=extract_names_and_update,
|
243 |
+
inputs=[file_input, preset_dropdown],
|
244 |
+
outputs=[filter_checkbox, names_state]
|
245 |
+
)
|
246 |
+
preset_dropdown.change(
|
247 |
+
fn=extract_names_and_update,
|
248 |
+
inputs=[file_input, preset_dropdown],
|
249 |
+
outputs=[filter_checkbox, names_state]
|
250 |
+
)
|
251 |
+
|
252 |
+
# When search term changes, filter names
|
253 |
+
search_box.change(
|
254 |
+
fn=filter_names,
|
255 |
+
inputs=[search_box, names_state],
|
256 |
+
outputs=[filter_checkbox]
|
257 |
+
)
|
258 |
+
|
259 |
+
# Check All / Uncheck All
|
260 |
+
check_all_btn.click(fn=check_all, inputs=names_state, outputs=filter_checkbox)
|
261 |
+
uncheck_all_btn.click(fn=uncheck_all, inputs=names_state, outputs=filter_checkbox)
|
262 |
+
|
263 |
+
# Plot button
|
264 |
+
plot_button.click(
|
265 |
+
fn=display_filtered_forecast,
|
266 |
+
inputs=[file_input, preset_dropdown, filter_checkbox],
|
267 |
+
outputs=[plot_output, errbox]
|
268 |
+
)
|
269 |
+
demo.launch()
|
270 |
+
|
271 |
+
|
272 |
+
|
273 |
+
|
274 |
+
|
275 |
+
|
276 |
+
# '''
|
277 |
+
# 1. Prepared datasets
|
278 |
+
# 2. Plots of different quiantilies (different colors)
|
279 |
+
# 3. Filters for plots...
|
280 |
+
# 4. Different input options
|
281 |
+
# 5. README.md in there (in UI) (contact us for fine-tuning)
|
282 |
+
# 6. Requirements for dimensions
|
283 |
+
# 7. LOGO of NX-AI and xLSTM and tirex
|
284 |
+
# 8. *Range of prediction length customizable
|
285 |
+
# 9. *Multivariate data (x_t is vector)
|
286 |
+
# '''
|
static/nxai_logo.png
ADDED
![]() |
static/tirex.jpeg
ADDED
![]() |
stocks_data.csv
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Ticker,2024-12-04,2024-12-05,2024-12-06,2024-12-09,2024-12-10,2024-12-11,2024-12-12,2024-12-13,2024-12-16,2024-12-17,2024-12-18,2024-12-19,2024-12-20,2024-12-23,2024-12-24,2024-12-26,2024-12-27,2024-12-30,2024-12-31,2025-01-02,2025-01-03,2025-01-06,2025-01-07,2025-01-08,2025-01-10,2025-01-13,2025-01-14,2025-01-15,2025-01-16,2025-01-17,2025-01-21,2025-01-22,2025-01-23,2025-01-24,2025-01-27,2025-01-28,2025-01-29,2025-01-30,2025-01-31,2025-02-03,2025-02-04,2025-02-05,2025-02-06,2025-02-07,2025-02-10,2025-02-11,2025-02-12,2025-02-13,2025-02-14,2025-02-18,2025-02-19,2025-02-20,2025-02-21,2025-02-24,2025-02-25,2025-02-26,2025-02-27,2025-02-28,2025-03-03,2025-03-04,2025-03-05,2025-03-06,2025-03-07,2025-03-10,2025-03-11,2025-03-12,2025-03-13,2025-03-14,2025-03-17,2025-03-18,2025-03-19,2025-03-20,2025-03-21,2025-03-24,2025-03-25,2025-03-26,2025-03-27,2025-03-28,2025-03-31,2025-04-01,2025-04-02,2025-04-03,2025-04-04,2025-04-07,2025-04-08,2025-04-09,2025-04-10,2025-04-11,2025-04-14,2025-04-15,2025-04-16,2025-04-17,2025-04-21,2025-04-22,2025-04-23,2025-04-24,2025-04-25,2025-04-28,2025-04-29,2025-04-30,2025-05-01,2025-05-02,2025-05-05,2025-05-06,2025-05-07,2025-05-08,2025-05-09,2025-05-12,2025-05-13,2025-05-14,2025-05-15,2025-05-16,2025-05-19,2025-05-20,2025-05-21,2025-05-22,2025-05-23,2025-05-27,2025-05-28,2025-05-29,2025-05-30,2025-06-02,2025-06-03,2025-06-04
|
2 |
+
AAPL,242.42520141601562,242.4551239013672,242.25559997558594,246.1562042236328,247.1737518310547,245.89682006835938,247.36329650878906,247.5328826904297,250.4358673095703,252.8699951171875,247.4530792236328,249.1888885498047,253.87759399414062,254.6557159423828,257.57867431640625,258.39666748046875,254.9749298095703,251.5930938720703,249.8173828125,243.26319885253906,242.7743682861328,244.41041564941406,241.62713623046875,242.11593627929688,236.280029296875,233.83592224121094,232.71861267089844,237.2975616455078,227.710693359375,229.4265594482422,222.10421752929688,223.29135131835938,223.1217803955078,222.24388122558594,229.3068389892578,237.68663024902344,238.78399658203125,237.0182342529297,235.4320831298828,227.4613037109375,232.23977661132812,231.9105682373047,232.6587677001953,227.08221435546875,227.3518524169922,232.3153533935547,236.55978393554688,241.21368408203125,244.2796630859375,244.14984130859375,244.54930114746094,245.508056640625,245.22842407226562,246.77639770507812,246.71646118164062,240.0452117919922,236.98922729492188,241.5232696533203,237.71826171875,235.6210174560547,235.4312744140625,235.0218048095703,238.7569122314453,227.1820831298828,220.55078125,216.69583129882812,209.4053955078125,213.21041870117188,213.71974182128906,212.4114532470703,214.95811462402344,213.81961059570312,217.98414611816406,220.44091796875,223.45697021484375,221.2398681640625,223.5568389892578,217.6146240234375,221.83909606933594,222.897705078125,223.59678649902344,202.9239044189453,188.13330078125,181.2223663330078,172.19419860839844,198.58958435058594,190.17062377929688,197.89048767089844,202.25477600097656,201.87527465820312,194.0155792236328,196.72203063964844,192.9070281982422,199.47842407226562,204.33206176757812,208.09710693359375,209.00592041015625,209.8647918701172,210.9333953857422,212.22171020507812,213.04063415527344,205.08106994628906,198.62953186035156,198.25001525878906,195.99298095703125,197.2313690185547,198.27000427246094,210.7899932861328,212.92999267578125,212.3300018310547,211.4499969482422,211.25999450683594,208.77999877929688,206.86000061035156,202.08999633789062,201.36000061035156,195.27000427246094,200.2100067138672,200.4199981689453,199.9499969482422,200.85000610351562,201.6999969482422,203.27000427246094,203.1501007080078
|
3 |
+
AMZN,218.16000366210938,220.5500030517578,227.02999877929688,226.08999633789062,225.0399932861328,230.25999450683594,228.97000122070312,227.4600067138672,232.92999267578125,231.14999389648438,220.52000427246094,223.2899932861328,224.9199981689453,225.05999755859375,229.0500030517578,227.0500030517578,223.75,221.3000030517578,219.38999938964844,220.22000122070312,224.19000244140625,227.61000061035156,222.11000061035156,222.1300048828125,218.94000244140625,218.4600067138672,217.75999450683594,223.35000610351562,220.66000366210938,225.94000244140625,230.7100067138672,235.00999450683594,235.4199981689453,234.85000610351562,235.4199981689453,238.14999389648438,237.07000732421875,234.63999938964844,237.67999267578125,237.4199981689453,242.05999755859375,236.1699981689453,238.8300018310547,229.14999389648438,233.13999938964844,232.75999450683594,228.92999267578125,230.3699951171875,228.67999267578125,226.64999389648438,226.6300048828125,222.8800048828125,216.5800018310547,212.7100067138672,212.8000030517578,214.35000610351562,208.74000549316406,212.27999877929688,205.02000427246094,203.8000030517578,208.36000061035156,200.6999969482422,199.25,194.5399932861328,196.58999633789062,198.88999938964844,193.88999938964844,197.9499969482422,195.74000549316406,192.82000732421875,195.5399932861328,194.9499969482422,196.2100067138672,203.25999450683594,205.7100067138672,201.1300048828125,201.36000061035156,192.72000122070312,190.25999450683594,192.1699981689453,196.00999450683594,178.41000366210938,171.0,175.25999450683594,170.66000366210938,191.10000610351562,181.22000122070312,184.8699951171875,182.1199951171875,179.58999633789062,174.3300018310547,172.61000061035156,167.32000732421875,173.17999267578125,180.60000610351562,186.5399932861328,188.99000549316406,187.6999969482422,187.38999938964844,184.4199981689453,190.1999969482422,189.97999572753906,186.35000610351562,185.00999450683594,188.7100067138672,192.0800018310547,193.05999755859375,208.63999938964844,211.3699951171875,210.25,205.1699981689453,205.58999633789062,206.16000366210938,204.07000732421875,201.1199951171875,203.10000610351562,200.99000549316406,206.02000427246094,204.72000122070312,205.6999969482422,205.00999450683594,206.64999389648438,205.7100067138672,207.4720001220703
|
4 |
+
GOOGL,173.9700164794922,172.24400329589844,174.30926513671875,175.1682586669922,184.9569854736328,195.1752166748047,191.7391815185547,189.6016387939453,196.43377685546875,195.1951904296875,188.18325805664062,188.2931365966797,191.18980407714844,194.40611267089844,195.8843994140625,195.375,192.5382537841797,191.02000427246094,189.08224487304688,189.2120819091797,191.56936645507812,196.64352416992188,195.26512145996094,193.7268829345703,191.81907653808594,190.79026794433594,189.4418182373047,195.32504272460938,192.68807983398438,195.77452087402344,197.82217407226562,198.1417999267578,197.7522430419922,199.9796905517578,191.58934020996094,195.07533264160156,195.18521118164062,200.638916015625,203.78530883789062,200.99850463867188,206.14259338378906,191.1099090576172,191.3795928955078,185.1267852783203,186.2554931640625,185.10682678222656,183.39878845214844,185.92587280273438,185.01690673828125,183.55859375,185.05686950683594,184.34768676757812,179.4533233642578,179.04379272460938,175.21820068359375,172.5312957763672,168.30616760253906,170.0841064453125,166.81787109375,170.72337341308594,172.8209686279297,172.1517333984375,173.66000366210938,165.8699951171875,164.0399932861328,167.11000061035156,162.75999450683594,165.49000549316406,164.2899932861328,160.6699981689453,163.88999938964844,162.8000030517578,163.99000549316406,167.67999267578125,170.55999755859375,165.05999755859375,162.24000549316406,154.3300018310547,154.63999938964844,157.07000732421875,157.0399932861328,150.72000122070312,145.60000610351562,146.75,144.6999969482422,158.7100067138672,152.82000732421875,157.13999938964844,159.07000732421875,156.30999755859375,153.3300018310547,151.16000366210938,147.6699981689453,151.47000122070312,155.35000610351562,159.27999877929688,161.9600067138672,160.61000061035156,160.16000366210938,158.8000030517578,161.3000030517578,164.02999877929688,164.2100067138672,163.22999572753906,151.3800048828125,154.27999877929688,152.75,158.4600067138672,159.52999877929688,165.3699951171875,163.9600067138672,166.19000244140625,166.5399932861328,163.97999572753906,168.55999755859375,170.8699951171875,168.47000122070312,172.89999389648438,172.36000061035156,171.86000061035156,171.74000549316406,169.02999877929688,166.17999267578125,167.78500366210938
|
5 |
+
JNJ,148.00625610351562,147.0718231201172,146.86526489257812,147.1505126953125,146.78656005859375,144.2389678955078,143.8455352783203,144.21929931640625,141.49465942382812,144.0127410888672,142.3799285888672,141.2290802001953,142.10450744628906,142.8914031982422,143.4619140625,143.19631958007812,142.67501831054688,140.9929962158203,142.25204467773438,141.661865234375,141.8291015625,141.30776977539062,143.83567810058594,139.94052124023438,139.7339630126953,142.10450744628906,142.3799285888672,142.5963134765625,145.3504638671875,144.62258911132812,145.7242431640625,142.8914031982422,144.2389678955078,144.41604614257812,150.38662719726562,147.91775512695312,148.6751251220703,150.36695861816406,149.65875244140625,149.3833465576172,150.97682189941406,152.1571807861328,150.99649047851562,150.6128692626953,151.71453857421875,153.57359313964844,152.71783447265625,154.67526245117188,153.59324645996094,153.67257690429688,156.5479278564453,158.3227081298828,160.92044067382812,162.34820556640625,164.67822265625,161.69381713867188,162.33828735351562,163.61732482910156,165.8581085205078,164.013916015625,163.71646118164062,164.42044067382812,165.27313232421875,166.2745361328125,164.45018005371094,161.4657745361328,161.60458374023438,161.4261016845703,161.4558563232422,162.85386657714844,161.60458374023438,161.6343231201172,162.2391357421875,161.9020233154297,159.6513214111328,160.34536743164062,161.74339294433594,162.31846618652344,164.4303436279297,151.94737243652344,154.0394287109375,158.46153259277344,151.9374542236328,149.3397216796875,148.72499084472656,149.68675231933594,147.42613220214844,150.44029235839844,153.0479278564453,152.31422424316406,152.6017608642578,156.13150024414062,155.58616638183594,156.40911865234375,154.05926513671875,153.6130828857422,153.2660675048828,154.02952575683594,154.5847625732422,154.98135375976562,153.1470947265625,154.79296875,153.6824951171875,153.15699768066406,155.96295166015625,154.33688354492188,152.90911865234375,152.82980346679688,147.17825317382812,145.11593627929688,148.3383026123047,150.04368591308594,151.19383239746094,152.3538818359375,151.87796020507812,151.31280517578125,151.63999938964844,153.25,152.42999267578125,153.5800018310547,155.2100067138672,155.39999389648438,154.4199981689453,153.4199981689453
|
6 |
+
JPM,240.6669921875,242.7236328125,244.58251953125,241.0723876953125,240.133056640625,240.7955322265625,238.81797790527344,237.245849609375,236.889892578125,235.68359375,227.78329467773438,230.34422302246094,234.93212890625,235.7132568359375,239.5892333984375,240.409912109375,238.4620361328125,236.6328125,237.0184326171875,237.30517578125,240.54833984375,239.3755645751953,241.6813507080078,241.6416015625,238.40155029296875,242.71499633789062,245.9550323486328,250.80516052246094,252.71340942382812,257.573486328125,261.4197692871094,261.2309265136719,264.3219299316406,263.21868896484375,264.2225341796875,265.504638671875,264.9480285644531,266.58795166015625,265.66363525390625,265.1766357421875,266.2997131347656,268.77447509765625,275.2048645019531,274.1116027832031,269.3807373046875,273.3065490722656,273.7637634277344,274.62841796875,274.8967590332031,278.2362060546875,277.5404968261719,265.16668701171875,262.62237548828125,259.7401123046875,255.82423400878906,257.20574951171875,257.4641418457031,263.02984619140625,259.0245361328125,248.718017578125,249.99017333984375,245.0307159423828,240.7967987060547,230.79840087890625,227.73724365234375,226.5048370361328,223.81143188476562,231.0170440673828,232.49790954589844,233.53155517578125,237.64620971679688,237.54681396484375,240.15078735351562,246.5414276123047,249.59263610839844,249.4932403564453,246.60104370117188,241.3633270263672,243.79832458496094,242.16836547851562,244.3151397705078,227.29000854492188,210.27999877929688,214.44000244140625,216.8699951171875,234.33999633789062,227.11000061035156,236.1999969482422,234.72000122070312,233.1300048828125,229.61000061035156,231.9600067138672,228.99000549316406,235.58999633789062,240.8800048828125,244.63999938964844,243.5500030517578,243.22000122070312,244.6199951171875,244.6199951171875,246.88999938964844,252.50999450683594,252.55999755859375,249.25,249.38999938964844,253.47000122070312,253.0800018310547,260.04998779296875,263.010009765625,265.6400146484375,267.489990234375,267.55999755859375,264.8800048828125,265.67999267578125,261.0400085449219,260.6700134277344,260.7099914550781,265.2900085449219,263.489990234375,264.3699951171875,264.0,264.6600036621094,266.2699890136719,265.06500244140625
|
7 |
+
META,612.7401733398438,607.8983764648438,622.7132568359375,612.530517578125,618.2708129882812,631.608154296875,629.7213745117188,619.299072265625,623.6851196289062,618.889404296875,596.6591796875,595.0405883789062,584.7297973632812,599.3167724609375,607.2097778320312,602.8136596679688,599.27685546875,590.7144165039062,584.9895629882812,598.7073364257812,604.092529296875,629.6398315429688,617.3407592773438,610.1771240234375,615.3125610351562,607.789306640625,593.7218017578125,616.5714721679688,610.756591796875,612.225341796875,615.9120483398438,622.94580078125,635.88427734375,646.9144287109375,659.29345703125,673.7305908203125,675.888671875,686.3893432617188,688.5673828125,696.840087890625,703.5640869140625,704.2434692382812,711.3571166992188,713.8848876953125,716.7623291015625,719.16015625,724.7352294921875,727.9124145507812,736.0151977539062,715.7332153320312,703.1444702148438,694.222412109375,682.9423828125,667.5361328125,656.9155883789062,673.1011962890625,657.6549072265625,667.6060791015625,654.4677124023438,639.4310913085938,655.886474609375,627.371826171875,625.1038208007812,597.4584350585938,605.171630859375,619.00927734375,590.114990234375,607.5999755859375,604.9000244140625,582.3599853515625,584.0599975585938,586.0,596.25,618.8499755859375,626.3099975585938,610.97998046875,602.5800170898438,576.739990234375,576.3599853515625,586.0,583.9299926757812,531.6199951171875,504.7300109863281,516.25,510.45001220703125,585.77001953125,546.2899780273438,543.5700073242188,531.47998046875,521.52001953125,502.30999755859375,501.4800109863281,484.6600036621094,500.2799987792969,520.27001953125,533.1500244140625,547.27001953125,549.739990234375,554.4400024414062,549.0,572.2100219726562,597.02001953125,599.27001953125,587.3099975585938,596.8099975585938,598.010009765625,592.489990234375,639.4299926757812,656.030029296875,659.3599853515625,643.8800048828125,640.3400268554688,640.4299926757812,637.0999755859375,635.5,636.5700073242188,627.0599975585938,642.3200073242188,643.5800170898438,645.0499877929688,647.489990234375,670.9000244140625,666.8499755859375,685.1599731445312
|
8 |
+
MSFT,435.7447204589844,440.9248046875,441.87115478515625,444.311767578125,441.6320495605469,447.2704162597656,447.83819580078125,445.5569763183594,449.86041259765625,452.7194519042969,435.71484375,435.356201171875,434.9278869628906,433.5830383300781,437.64739990234375,436.43206787109375,428.881103515625,423.2029113769531,419.88568115234375,416.97686767578125,421.7286071777344,426.21136474609375,420.7523498535156,422.9339599609375,417.345458984375,415.5921936035156,414.0780334472656,424.67724609375,422.953857421875,427.3868408203125,426.85888671875,444.4910888671875,444.9991149902344,442.3592834472656,432.8956604003906,445.4872741699219,440.6358947753906,413.4006042480469,413.4703369140625,409.3462219238281,410.7906494140625,411.7071228027344,414.2274475097656,408.1806945800781,410.6412048339844,409.86419677734375,407.473388671875,408.9676513671875,406.86572265625,408.07110595703125,413.1814270019531,415.3674621582031,407.4619445800781,403.2596740722656,397.17083740234375,398.99749755859375,391.8106994628906,396.26251220703125,387.778076171875,387.8978576660156,400.2851257324219,396.1627197265625,392.5892639160156,379.46337890625,379.7528381347656,382.5676574707031,378.0758972167969,387.84796142578125,387.98773193359375,382.81719970703125,387.1093444824219,386.131103515625,390.54302978515625,392.35968017578125,394.4358825683594,389.2554016113281,389.8642578125,378.1058349609375,374.7021179199219,381.4896545410156,381.43975830078125,372.42626953125,359.18060302734375,357.2042236328125,353.9102783203125,389.7744140625,380.65118408203125,387.7381896972656,387.0993347167969,385.0231628417969,370.92901611328125,367.1060485839844,358.4619140625,366.1478271484375,373.7039489746094,386.59027099609375,391.1319580078125,390.4432067871094,393.31793212890625,394.53570556640625,424.6204528808594,434.48236083984375,435.3707275390625,432.5159606933594,432.555908203125,437.3670654296875,437.9260559082031,448.4367370605469,448.3169860839844,452.1099853515625,453.1300048828125,454.2699890136719,458.8699951171875,458.1700134277344,452.57000732421875,454.8599853515625,450.17999267578125,460.69000244140625,457.3599853515625,458.67999267578125,460.3599853515625,461.9700012207031,462.9700012207031,464.19000244140625
|
9 |
+
NVDA,145.1166534423828,145.04666137695312,142.42689514160156,138.79722595214844,135.05758666992188,139.29718017578125,137.32736206054688,134.23765563964844,131.98785400390625,130.3780059814453,128.8981475830078,130.66796875,134.68760681152344,139.6571502685547,140.2071075439453,139.91712951660156,136.9973907470703,137.47735595703125,134.27764892578125,138.29727172851562,144.4567108154297,149.41624450683594,140.12710571289062,140.0971221923828,135.8975067138672,133.21774291992188,131.7478790283203,136.22747802734375,133.55772399902344,137.6973419189453,140.81704711914062,147.05648803710938,147.2064666748047,142.60687255859375,118.40910339355469,128.9781494140625,123.6886215209961,124.63853454589844,120.0589599609375,116.64927673339844,118.63909149169922,124.81851959228516,128.66815185546875,129.8280487060547,133.55772399902344,132.7877960205078,131.12794494628906,135.27755737304688,138.8372344970703,139.38717651367188,139.21719360351562,140.0971221923828,134.41763305664062,130.2680206298828,126.61835479736328,131.26792907714844,120.13895416259766,124.90850830078125,114.04950714111328,115.97933197021484,117.28921508789062,110.55982971191406,112.67964172363281,106.97016143798828,108.75,115.73999786376953,115.58000183105469,121.66999816894531,119.52999877929688,115.43000030517578,117.5199966430664,118.52999877929688,117.69999694824219,121.41000366210938,120.69000244140625,113.76000213623047,111.43000030517578,109.66999816894531,108.37999725341797,110.1500015258789,110.41999816894531,101.80000305175781,94.30999755859375,97.63999938964844,96.30000305175781,114.33000183105469,107.56999969482422,110.93000030517578,110.70999908447266,112.19999694824219,104.48999786376953,101.48999786376953,96.91000366210938,98.88999938964844,102.70999908447266,106.43000030517578,111.01000213623047,108.7300033569336,109.0199966430664,108.91999816894531,111.61000061035156,114.5,113.81999969482422,113.54000091552734,117.05999755859375,117.37000274658203,116.6500015258789,123.0,129.92999267578125,135.33999633789062,134.8300018310547,135.39999389648438,135.57000732421875,134.3800048828125,131.8000030517578,132.8300018310547,131.2899932861328,135.5,134.80999755859375,139.19000244140625,135.1300048828125,137.3800048828125,141.22000122070312,141.85499572753906
|
10 |
+
TSLA,357.92999267578125,369.489990234375,389.2200012207031,389.7900085449219,400.989990234375,424.7699890136719,418.1000061035156,436.2300109863281,463.0199890136719,479.8599853515625,440.1300048828125,436.1700134277344,421.05999755859375,430.6000061035156,462.2799987792969,454.1300048828125,431.6600036621094,417.4100036621094,403.8399963378906,379.2799987792969,410.44000244140625,411.04998779296875,394.3599853515625,394.94000244140625,394.739990234375,403.30999755859375,396.3599853515625,428.2200012207031,413.82000732421875,426.5,424.07000732421875,415.1099853515625,412.3800048828125,406.5799865722656,397.1499938964844,398.0899963378906,389.1000061035156,400.2799987792969,404.6000061035156,383.67999267578125,392.2099914550781,378.1700134277344,374.32000732421875,361.6199951171875,350.7300109863281,328.5,336.510009765625,355.94000244140625,355.8399963378906,354.1099853515625,360.55999755859375,354.3999938964844,337.79998779296875,330.5299987792969,302.79998779296875,290.79998779296875,281.95001220703125,292.9800109863281,284.6499938964844,272.0400085449219,279.1000061035156,263.45001220703125,262.6700134277344,222.14999389648438,230.5800018310547,248.08999633789062,240.67999267578125,249.97999572753906,238.00999450683594,225.30999755859375,235.86000061035156,236.25999450683594,248.7100067138672,278.3900146484375,288.1400146484375,272.05999755859375,273.1300048828125,263.54998779296875,259.1600036621094,268.4599914550781,282.760009765625,267.2799987792969,239.42999267578125,233.2899932861328,221.86000061035156,272.20001220703125,252.39999389648438,252.30999755859375,252.35000610351562,254.11000061035156,241.5500030517578,241.3699951171875,227.5,237.97000122070312,250.74000549316406,259.510009765625,284.95001220703125,285.8800048828125,292.0299987792969,282.1600036621094,280.5199890136719,287.2099914550781,280.260009765625,275.3500061035156,276.2200012207031,284.82000732421875,298.260009765625,318.3800048828125,334.07000732421875,347.67999267578125,342.82000732421875,349.9800109863281,342.0899963378906,343.82000732421875,334.6199951171875,341.0400085449219,339.3399963378906,362.8900146484375,356.8999938964844,358.42999267578125,346.4599914550781,342.69000244140625,344.2699890136719,334.6716003417969
|
11 |
+
V,308.866455078125,308.0491943359375,309.9727783203125,307.27178955078125,311.33819580078125,312.7434997558594,313.1820373535156,313.6903076171875,314.8365173339844,317.2384338378906,308.7468566894531,313.8298645019531,316.6504211425781,316.16204833984375,319.5805969238281,319.8397521972656,317.59722900390625,314.2584228515625,314.9859924316406,313.3514709472656,313.8597717285156,311.9960021972656,310.6305847167969,311.55743408203125,306.6837463378906,305.89642333984375,308.05914306640625,315.22515869140625,316.19195556640625,318.5540466308594,322.5506591796875,322.48089599609375,327.1153869628906,329.0987548828125,333.4242858886719,333.364501953125,334.75982666015625,341.9058837890625,340.6600646972656,344.66668701171875,343.9989013671875,348.27459716796875,346.3211364746094,346.85931396484375,350.05865478515625,350.13848876953125,350.9072265625,355.0403747558594,353.2233581542969,356.1385498046875,354.6410217285156,349.90887451171875,347.9521179199219,349.2799072265625,351.5062255859375,350.04864501953125,355.1501770019531,362.1086120605469,361.2200927734375,351.64599609375,352.0952453613281,343.5893859863281,344.7474670410156,340.913818359375,331.5893249511719,332.28814697265625,328.0052490234375,331.2498474121094,333.99530029296875,334.2149353027344,339.3064880371094,338.9371032714844,335.1034851074219,343.29986572265625,344.0486145019531,343.6093444824219,349.2799072265625,342.28155517578125,349.87890625,345.7757568359375,345.7557678222656,338.8273010253906,312.61083984375,311.80218505859375,307.7588806152344,331.8788146972656,324.07177734375,332.8471984863281,334.624267578125,335.2532043457031,330.7806396484375,329.0634765625,319.56927490234375,330.85052490234375,333.81561279296875,335.15338134765625,334.6142883300781,336.9504089355469,340.9537353515625,344.9271545410156,341.8822326660156,347.023681640625,348.06195068359375,347.1235046386719,349.26995849609375,350.68756103515625,351.9554748535156,355.260009765625,356.1400146484375,356.4599914550781,362.29998779296875,365.1199951171875,367.8999938964844,366.8399963378906,358.29998779296875,357.9700012207031,353.5400085449219,359.29998779296875,359.7300109863281,362.3999938964844,365.19000244140625,365.32000732421875,365.8599853515625,368.17999267578125
|
stocks_data_forecast.pt
ADDED
Binary file (24.3 kB). View file
|
|
stocks_data_noindex.csv
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
2024-12-04,2024-12-05,2024-12-06,2024-12-09,2024-12-10,2024-12-11,2024-12-12,2024-12-13,2024-12-16,2024-12-17,2024-12-18,2024-12-19,2024-12-20,2024-12-23,2024-12-24,2024-12-26,2024-12-27,2024-12-30,2024-12-31,2025-01-02,2025-01-03,2025-01-06,2025-01-07,2025-01-08,2025-01-10,2025-01-13,2025-01-14,2025-01-15,2025-01-16,2025-01-17,2025-01-21,2025-01-22,2025-01-23,2025-01-24,2025-01-27,2025-01-28,2025-01-29,2025-01-30,2025-01-31,2025-02-03,2025-02-04,2025-02-05,2025-02-06,2025-02-07,2025-02-10,2025-02-11,2025-02-12,2025-02-13,2025-02-14,2025-02-18,2025-02-19,2025-02-20,2025-02-21,2025-02-24,2025-02-25,2025-02-26,2025-02-27,2025-02-28,2025-03-03,2025-03-04,2025-03-05,2025-03-06,2025-03-07,2025-03-10,2025-03-11,2025-03-12,2025-03-13,2025-03-14,2025-03-17,2025-03-18,2025-03-19,2025-03-20,2025-03-21,2025-03-24,2025-03-25,2025-03-26,2025-03-27,2025-03-28,2025-03-31,2025-04-01,2025-04-02,2025-04-03,2025-04-04,2025-04-07,2025-04-08,2025-04-09,2025-04-10,2025-04-11,2025-04-14,2025-04-15,2025-04-16,2025-04-17,2025-04-21,2025-04-22,2025-04-23,2025-04-24,2025-04-25,2025-04-28,2025-04-29,2025-04-30,2025-05-01,2025-05-02,2025-05-05,2025-05-06,2025-05-07,2025-05-08,2025-05-09,2025-05-12,2025-05-13,2025-05-14,2025-05-15,2025-05-16,2025-05-19,2025-05-20,2025-05-21,2025-05-22,2025-05-23,2025-05-27,2025-05-28,2025-05-29,2025-05-30,2025-06-02,2025-06-03,2025-06-04
|
2 |
+
242.42520141601562,242.4551239013672,242.25559997558594,246.1562042236328,247.1737518310547,245.89683532714844,247.36329650878906,247.5328826904297,250.43588256835938,252.8699951171875,247.4530792236328,249.1888885498047,253.87759399414062,254.6557159423828,257.57867431640625,258.39666748046875,254.9749298095703,251.59307861328125,249.81736755371094,243.26319885253906,242.7743682861328,244.41041564941406,241.62713623046875,242.11595153808594,236.28004455566406,233.83592224121094,232.71861267089844,237.2975616455078,227.710693359375,229.4265594482422,222.10421752929688,223.29136657714844,223.1217803955078,222.24388122558594,229.30685424804688,237.68663024902344,238.7839813232422,237.0182342529297,235.4320831298828,227.46128845214844,232.23977661132812,231.91058349609375,232.6587677001953,227.08221435546875,227.3518524169922,232.3153533935547,236.55978393554688,241.21368408203125,244.2796630859375,244.14984130859375,244.54930114746094,245.508056640625,245.22842407226562,246.77639770507812,246.71646118164062,240.0452117919922,236.98922729492188,241.5232696533203,237.71826171875,235.6210174560547,235.4312744140625,235.0218048095703,238.7569122314453,227.1820831298828,220.55078125,216.69583129882812,209.4053955078125,213.21041870117188,213.71974182128906,212.4114532470703,214.95811462402344,213.81961059570312,217.98414611816406,220.44091796875,223.45697021484375,221.2398681640625,223.5568389892578,217.6146240234375,221.83909606933594,222.897705078125,223.59678649902344,202.9239044189453,188.13330078125,181.2223663330078,172.19419860839844,198.58958435058594,190.17062377929688,197.89048767089844,202.25477600097656,201.87527465820312,194.0155792236328,196.72203063964844,192.9070281982422,199.47842407226562,204.33206176757812,208.09710693359375,209.00592041015625,209.8647918701172,210.9333953857422,212.22171020507812,213.04063415527344,205.08106994628906,198.62953186035156,198.25001525878906,195.99298095703125,197.2313690185547,198.27000427246094,210.7899932861328,212.92999267578125,212.3300018310547,211.4499969482422,211.25999450683594,208.77999877929688,206.86000061035156,202.08999633789062,201.36000061035156,195.27000427246094,200.2100067138672,200.4199981689453,199.9499969482422,200.85000610351562,201.6999969482422,203.27000427246094,202.95260620117188
|
3 |
+
218.16000366210938,220.5500030517578,227.02999877929688,226.08999633789062,225.0399932861328,230.25999450683594,228.97000122070312,227.4600067138672,232.92999267578125,231.14999389648438,220.52000427246094,223.2899932861328,224.9199981689453,225.05999755859375,229.0500030517578,227.0500030517578,223.75,221.3000030517578,219.38999938964844,220.22000122070312,224.19000244140625,227.61000061035156,222.11000061035156,222.1300048828125,218.94000244140625,218.4600067138672,217.75999450683594,223.35000610351562,220.66000366210938,225.94000244140625,230.7100067138672,235.00999450683594,235.4199981689453,234.85000610351562,235.4199981689453,238.14999389648438,237.07000732421875,234.63999938964844,237.67999267578125,237.4199981689453,242.05999755859375,236.1699981689453,238.8300018310547,229.14999389648438,233.13999938964844,232.75999450683594,228.92999267578125,230.3699951171875,228.67999267578125,226.64999389648438,226.6300048828125,222.8800048828125,216.5800018310547,212.7100067138672,212.8000030517578,214.35000610351562,208.74000549316406,212.27999877929688,205.02000427246094,203.8000030517578,208.36000061035156,200.6999969482422,199.25,194.5399932861328,196.58999633789062,198.88999938964844,193.88999938964844,197.9499969482422,195.74000549316406,192.82000732421875,195.5399932861328,194.9499969482422,196.2100067138672,203.25999450683594,205.7100067138672,201.1300048828125,201.36000061035156,192.72000122070312,190.25999450683594,192.1699981689453,196.00999450683594,178.41000366210938,171.0,175.25999450683594,170.66000366210938,191.10000610351562,181.22000122070312,184.8699951171875,182.1199951171875,179.58999633789062,174.3300018310547,172.61000061035156,167.32000732421875,173.17999267578125,180.60000610351562,186.5399932861328,188.99000549316406,187.6999969482422,187.38999938964844,184.4199981689453,190.1999969482422,189.97999572753906,186.35000610351562,185.00999450683594,188.7100067138672,192.0800018310547,193.05999755859375,208.63999938964844,211.3699951171875,210.25,205.1699981689453,205.58999633789062,206.16000366210938,204.07000732421875,201.1199951171875,203.10000610351562,200.99000549316406,206.02000427246094,204.72000122070312,205.6999969482422,205.00999450683594,206.64999389648438,205.7100067138672,207.14990234375
|
4 |
+
173.9700164794922,172.24398803710938,174.30926513671875,175.1682586669922,184.9569854736328,195.1752166748047,191.7391815185547,189.6016387939453,196.43377685546875,195.1951904296875,188.18325805664062,188.2931365966797,191.18980407714844,194.40611267089844,195.8843994140625,195.375,192.5382537841797,191.02000427246094,189.08224487304688,189.2120819091797,191.56936645507812,196.64352416992188,195.26512145996094,193.7268829345703,191.81907653808594,190.79026794433594,189.4418182373047,195.32504272460938,192.68807983398438,195.77452087402344,197.82217407226562,198.1417999267578,197.7522430419922,199.9796905517578,191.58934020996094,195.07533264160156,195.18521118164062,200.638916015625,203.78530883789062,200.99850463867188,206.14259338378906,191.1099090576172,191.3795928955078,185.1267852783203,186.2554931640625,185.10682678222656,183.39878845214844,185.92587280273438,185.01690673828125,183.55859375,185.05686950683594,184.34768676757812,179.4533233642578,179.04379272460938,175.21820068359375,172.5312957763672,168.30616760253906,170.0841064453125,166.81787109375,170.72337341308594,172.8209686279297,172.1517333984375,173.66000366210938,165.8699951171875,164.0399932861328,167.11000061035156,162.75999450683594,165.49000549316406,164.2899932861328,160.6699981689453,163.88999938964844,162.8000030517578,163.99000549316406,167.67999267578125,170.55999755859375,165.05999755859375,162.24000549316406,154.3300018310547,154.63999938964844,157.07000732421875,157.0399932861328,150.72000122070312,145.60000610351562,146.75,144.6999969482422,158.7100067138672,152.82000732421875,157.13999938964844,159.07000732421875,156.30999755859375,153.3300018310547,151.16000366210938,147.6699981689453,151.47000122070312,155.35000610351562,159.27999877929688,161.9600067138672,160.61000061035156,160.16000366210938,158.8000030517578,161.3000030517578,164.02999877929688,164.2100067138672,163.22999572753906,151.3800048828125,154.27999877929688,152.75,158.4600067138672,159.52999877929688,165.3699951171875,163.9600067138672,166.19000244140625,166.5399932861328,163.97999572753906,168.55999755859375,170.8699951171875,168.47000122070312,172.89999389648438,172.36000061035156,171.86000061035156,171.74000549316406,169.02999877929688,166.17999267578125,167.97000122070312
|
5 |
+
148.00625610351562,147.0718231201172,146.86526489257812,147.1505126953125,146.78656005859375,144.2389678955078,143.8455352783203,144.21929931640625,141.49465942382812,144.0127410888672,142.3799285888672,141.2290802001953,142.10450744628906,142.8914031982422,143.4619140625,143.19631958007812,142.67501831054688,140.9929962158203,142.25204467773438,141.661865234375,141.8291015625,141.30776977539062,143.83567810058594,139.94052124023438,139.7339630126953,142.10450744628906,142.3799285888672,142.5963134765625,145.3504638671875,144.62258911132812,145.7242431640625,142.8914031982422,144.2389678955078,144.41604614257812,150.38662719726562,147.91775512695312,148.6751251220703,150.36695861816406,149.65875244140625,149.3833465576172,150.97682189941406,152.1571807861328,150.99649047851562,150.6128692626953,151.71453857421875,153.57359313964844,152.71783447265625,154.67526245117188,153.59324645996094,153.67257690429688,156.5479278564453,158.3227081298828,160.92044067382812,162.34820556640625,164.67822265625,161.69381713867188,162.33828735351562,163.61732482910156,165.8581085205078,164.013916015625,163.71646118164062,164.42044067382812,165.27313232421875,166.2745361328125,164.45018005371094,161.4657745361328,161.60458374023438,161.4261016845703,161.4558563232422,162.85386657714844,161.60458374023438,161.6343231201172,162.2391357421875,161.9020233154297,159.6513214111328,160.34536743164062,161.74339294433594,162.31846618652344,164.4303436279297,151.94737243652344,154.0394287109375,158.46153259277344,151.9374542236328,149.3397216796875,148.72499084472656,149.68675231933594,147.42613220214844,150.44029235839844,153.0479278564453,152.31422424316406,152.6017608642578,156.13150024414062,155.58616638183594,156.40911865234375,154.05926513671875,153.6130828857422,153.2660675048828,154.02952575683594,154.5847625732422,154.98135375976562,153.1470947265625,154.79296875,153.6824951171875,153.15699768066406,155.96295166015625,154.33688354492188,152.90911865234375,152.82980346679688,147.17825317382812,145.11593627929688,148.3383026123047,150.04368591308594,151.19383239746094,152.3538818359375,151.87796020507812,151.31280517578125,151.63999938964844,153.25,152.42999267578125,153.5800018310547,155.2100067138672,155.39999389648438,154.4199981689453,153.6750030517578
|
6 |
+
240.6669921875,242.7236328125,244.58251953125,241.0723876953125,240.133056640625,240.7955322265625,238.81797790527344,237.245849609375,236.889892578125,235.68359375,227.78329467773438,230.34422302246094,234.93212890625,235.7132568359375,239.5892333984375,240.409912109375,238.4620361328125,236.6328125,237.0184326171875,237.30517578125,240.54833984375,239.3755645751953,241.6813507080078,241.6416015625,238.40155029296875,242.71499633789062,245.9550323486328,250.80516052246094,252.71340942382812,257.573486328125,261.4197692871094,261.2309265136719,264.3219299316406,263.21868896484375,264.2225341796875,265.504638671875,264.9480285644531,266.58795166015625,265.66363525390625,265.1766357421875,266.2997131347656,268.77447509765625,275.2048645019531,274.1116027832031,269.3807373046875,273.3065490722656,273.7637634277344,274.62841796875,274.8967590332031,278.2362060546875,277.5404968261719,265.16668701171875,262.62237548828125,259.7401123046875,255.82423400878906,257.20574951171875,257.4641418457031,263.02984619140625,259.0245361328125,248.718017578125,249.99017333984375,245.0307159423828,240.7967987060547,230.79840087890625,227.73724365234375,226.5048370361328,223.81143188476562,231.0170440673828,232.49790954589844,233.53155517578125,237.64620971679688,237.54681396484375,240.15078735351562,246.5414276123047,249.59263610839844,249.4932403564453,246.60104370117188,241.3633270263672,243.79832458496094,242.16836547851562,244.3151397705078,227.29000854492188,210.27999877929688,214.44000244140625,216.8699951171875,234.33999633789062,227.11000061035156,236.1999969482422,234.72000122070312,233.1300048828125,229.61000061035156,231.9600067138672,228.99000549316406,235.58999633789062,240.8800048828125,244.63999938964844,243.5500030517578,243.22000122070312,244.6199951171875,244.6199951171875,246.88999938964844,252.50999450683594,252.55999755859375,249.25,249.38999938964844,253.47000122070312,253.0800018310547,260.04998779296875,263.010009765625,265.6400146484375,267.489990234375,267.55999755859375,264.8800048828125,265.67999267578125,261.0400085449219,260.6700134277344,260.7099914550781,265.2900085449219,263.489990234375,264.3699951171875,264.0,264.6600036621094,266.2699890136719,264.9649963378906
|
7 |
+
612.7401733398438,607.8983764648438,622.7132568359375,612.530517578125,618.2708129882812,631.6080932617188,629.7213134765625,619.2990112304688,623.6851196289062,618.889404296875,596.6591796875,595.0405883789062,584.7297973632812,599.3167724609375,607.2097778320312,602.8136596679688,599.27685546875,590.7144165039062,584.9895629882812,598.7073364257812,604.092529296875,629.6398315429688,617.3407592773438,610.1771240234375,615.3125610351562,607.789306640625,593.7218017578125,616.5714721679688,610.756591796875,612.225341796875,615.9120483398438,622.94580078125,635.88427734375,646.9144287109375,659.29345703125,673.7305908203125,675.888671875,686.3893432617188,688.5673828125,696.840087890625,703.5640869140625,704.2434692382812,711.3571166992188,713.8848876953125,716.7623291015625,719.16015625,724.7352294921875,727.9124145507812,736.0151977539062,715.7332153320312,703.1444702148438,694.222412109375,682.9423828125,667.5361328125,656.9155883789062,673.1011962890625,657.6549072265625,667.6060791015625,654.4677124023438,639.4310913085938,655.886474609375,627.371826171875,625.1038208007812,597.4584350585938,605.171630859375,619.00927734375,590.114990234375,607.5999755859375,604.9000244140625,582.3599853515625,584.0599975585938,586.0,596.25,618.8499755859375,626.3099975585938,610.97998046875,602.5800170898438,576.739990234375,576.3599853515625,586.0,583.9299926757812,531.6199951171875,504.7300109863281,516.25,510.45001220703125,585.77001953125,546.2899780273438,543.5700073242188,531.47998046875,521.52001953125,502.30999755859375,501.4800109863281,484.6600036621094,500.2799987792969,520.27001953125,533.1500244140625,547.27001953125,549.739990234375,554.4400024414062,549.0,572.2100219726562,597.02001953125,599.27001953125,587.3099975585938,596.8099975585938,598.010009765625,592.489990234375,639.4299926757812,656.030029296875,659.3599853515625,643.8800048828125,640.3400268554688,640.4299926757812,637.0999755859375,635.5,636.5700073242188,627.0599975585938,642.3200073242188,643.5800170898438,645.0499877929688,647.489990234375,670.9000244140625,666.8499755859375,685.2650146484375
|
8 |
+
435.7447204589844,440.9248046875,441.87115478515625,444.311767578125,441.6320495605469,447.2704162597656,447.83819580078125,445.5569763183594,449.86041259765625,452.7194519042969,435.71484375,435.356201171875,434.9278869628906,433.5830383300781,437.64739990234375,436.43206787109375,428.881103515625,423.2029113769531,419.88568115234375,416.97686767578125,421.7286071777344,426.21136474609375,420.7523498535156,422.9339599609375,417.345458984375,415.5921936035156,414.0780334472656,424.67724609375,422.953857421875,427.3868408203125,426.85888671875,444.4910888671875,444.9991149902344,442.3592834472656,432.8956604003906,445.4872741699219,440.6358947753906,413.4006042480469,413.4703369140625,409.3462219238281,410.7906494140625,411.7071228027344,414.2274475097656,408.1806945800781,410.6412048339844,409.86419677734375,407.473388671875,408.9676513671875,406.86572265625,408.07110595703125,413.1814270019531,415.3674621582031,407.4619445800781,403.2596740722656,397.17083740234375,398.99749755859375,391.8106994628906,396.26251220703125,387.778076171875,387.8978576660156,400.2851257324219,396.1627197265625,392.5892639160156,379.46337890625,379.7528381347656,382.5676574707031,378.0758972167969,387.84796142578125,387.98773193359375,382.81719970703125,387.1093444824219,386.131103515625,390.54302978515625,392.35968017578125,394.4358825683594,389.2554016113281,389.8642578125,378.1058349609375,374.7021179199219,381.4896545410156,381.43975830078125,372.42626953125,359.18060302734375,357.2042236328125,353.9102783203125,389.7744140625,380.65118408203125,387.7381896972656,387.0993347167969,385.0231628417969,370.92901611328125,367.1060485839844,358.4619140625,366.1478271484375,373.7039489746094,386.59027099609375,391.1319580078125,390.4432067871094,393.31793212890625,394.53570556640625,424.6204528808594,434.48236083984375,435.3707275390625,432.5159606933594,432.555908203125,437.3670654296875,437.9260559082031,448.4367370605469,448.3169860839844,452.1099853515625,453.1300048828125,454.2699890136719,458.8699951171875,458.1700134277344,452.57000732421875,454.8599853515625,450.17999267578125,460.69000244140625,457.3599853515625,458.67999267578125,460.3599853515625,461.9700012207031,462.9700012207031,464.260009765625
|
9 |
+
145.1166534423828,145.04666137695312,142.42689514160156,138.79722595214844,135.05758666992188,139.29718017578125,137.32736206054688,134.23765563964844,131.98785400390625,130.3780059814453,128.8981475830078,130.66796875,134.68760681152344,139.6571502685547,140.2071075439453,139.91712951660156,136.9973907470703,137.47735595703125,134.27764892578125,138.29727172851562,144.4567108154297,149.41624450683594,140.12710571289062,140.0971221923828,135.8975067138672,133.21774291992188,131.7478790283203,136.22747802734375,133.55772399902344,137.6973419189453,140.81704711914062,147.05648803710938,147.2064666748047,142.60687255859375,118.40910339355469,128.9781494140625,123.6886215209961,124.63853454589844,120.0589599609375,116.64927673339844,118.63909149169922,124.81851959228516,128.66815185546875,129.8280487060547,133.55772399902344,132.7877960205078,131.12794494628906,135.27755737304688,138.8372344970703,139.38717651367188,139.21719360351562,140.0971221923828,134.41763305664062,130.2680206298828,126.61835479736328,131.26792907714844,120.13895416259766,124.90850830078125,114.04950714111328,115.97933197021484,117.28921508789062,110.55982971191406,112.67964172363281,106.97016143798828,108.75,115.73999786376953,115.58000183105469,121.66999816894531,119.52999877929688,115.43000030517578,117.5199966430664,118.52999877929688,117.69999694824219,121.41000366210938,120.69000244140625,113.76000213623047,111.43000030517578,109.66999816894531,108.37999725341797,110.1500015258789,110.41999816894531,101.80000305175781,94.30999755859375,97.63999938964844,96.30000305175781,114.33000183105469,107.56999969482422,110.93000030517578,110.70999908447266,112.19999694824219,104.48999786376953,101.48999786376953,96.91000366210938,98.88999938964844,102.70999908447266,106.43000030517578,111.01000213623047,108.7300033569336,109.0199966430664,108.91999816894531,111.61000061035156,114.5,113.81999969482422,113.54000091552734,117.05999755859375,117.37000274658203,116.6500015258789,123.0,129.92999267578125,135.33999633789062,134.8300018310547,135.39999389648438,135.57000732421875,134.3800048828125,131.8000030517578,132.8300018310547,131.2899932861328,135.5,134.80999755859375,139.19000244140625,135.1300048828125,137.3800048828125,141.22000122070312,141.68099975585938
|
10 |
+
357.92999267578125,369.489990234375,389.2200012207031,389.7900085449219,400.989990234375,424.7699890136719,418.1000061035156,436.2300109863281,463.0199890136719,479.8599853515625,440.1300048828125,436.1700134277344,421.05999755859375,430.6000061035156,462.2799987792969,454.1300048828125,431.6600036621094,417.4100036621094,403.8399963378906,379.2799987792969,410.44000244140625,411.04998779296875,394.3599853515625,394.94000244140625,394.739990234375,403.30999755859375,396.3599853515625,428.2200012207031,413.82000732421875,426.5,424.07000732421875,415.1099853515625,412.3800048828125,406.5799865722656,397.1499938964844,398.0899963378906,389.1000061035156,400.2799987792969,404.6000061035156,383.67999267578125,392.2099914550781,378.1700134277344,374.32000732421875,361.6199951171875,350.7300109863281,328.5,336.510009765625,355.94000244140625,355.8399963378906,354.1099853515625,360.55999755859375,354.3999938964844,337.79998779296875,330.5299987792969,302.79998779296875,290.79998779296875,281.95001220703125,292.9800109863281,284.6499938964844,272.0400085449219,279.1000061035156,263.45001220703125,262.6700134277344,222.14999389648438,230.5800018310547,248.08999633789062,240.67999267578125,249.97999572753906,238.00999450683594,225.30999755859375,235.86000061035156,236.25999450683594,248.7100067138672,278.3900146484375,288.1400146484375,272.05999755859375,273.1300048828125,263.54998779296875,259.1600036621094,268.4599914550781,282.760009765625,267.2799987792969,239.42999267578125,233.2899932861328,221.86000061035156,272.20001220703125,252.39999389648438,252.30999755859375,252.35000610351562,254.11000061035156,241.5500030517578,241.3699951171875,227.5,237.97000122070312,250.74000549316406,259.510009765625,284.95001220703125,285.8800048828125,292.0299987792969,282.1600036621094,280.5199890136719,287.2099914550781,280.260009765625,275.3500061035156,276.2200012207031,284.82000732421875,298.260009765625,318.3800048828125,334.07000732421875,347.67999267578125,342.82000732421875,349.9800109863281,342.0899963378906,343.82000732421875,334.6199951171875,341.0400085449219,339.3399963378906,362.8900146484375,356.8999938964844,358.42999267578125,346.4599914550781,342.69000244140625,344.2699890136719,333.302001953125
|
11 |
+
308.866455078125,308.0491943359375,309.9727783203125,307.27178955078125,311.33819580078125,312.7434997558594,313.1820373535156,313.6903076171875,314.8365173339844,317.2384338378906,308.7468566894531,313.8298645019531,316.6504211425781,316.16204833984375,319.5805969238281,319.8397521972656,317.59722900390625,314.2584228515625,314.9859924316406,313.3514709472656,313.8597717285156,311.9960021972656,310.6305847167969,311.55743408203125,306.6837463378906,305.89642333984375,308.05914306640625,315.22515869140625,316.19195556640625,318.5540466308594,322.5506591796875,322.48089599609375,327.1153869628906,329.0987548828125,333.4242858886719,333.364501953125,334.75982666015625,341.9058837890625,340.6600646972656,344.66668701171875,343.9989013671875,348.27459716796875,346.3211364746094,346.85931396484375,350.05865478515625,350.13848876953125,350.9072265625,355.0403747558594,353.2233581542969,356.1385498046875,354.6410217285156,349.90887451171875,347.9521179199219,349.2799072265625,351.5062255859375,350.04864501953125,355.1501770019531,362.1086120605469,361.2200927734375,351.64599609375,352.0952453613281,343.5893859863281,344.7474670410156,340.913818359375,331.5893249511719,332.28814697265625,328.0052490234375,331.2498474121094,333.99530029296875,334.2149353027344,339.3064880371094,338.9371032714844,335.1034851074219,343.29986572265625,344.0486145019531,343.6093444824219,349.2799072265625,342.28155517578125,349.87890625,345.7757568359375,345.7557678222656,338.8273010253906,312.61083984375,311.80218505859375,307.7588806152344,331.8788146972656,324.07177734375,332.8471984863281,334.624267578125,335.2532043457031,330.7806396484375,329.0634765625,319.56927490234375,330.85052490234375,333.81561279296875,335.15338134765625,334.6142883300781,336.9504089355469,340.9537353515625,344.9271545410156,341.8822326660156,347.023681640625,348.06195068359375,347.1235046386719,349.26995849609375,350.68756103515625,351.9554748535156,355.260009765625,356.1400146484375,356.4599914550781,362.29998779296875,365.1199951171875,367.8999938964844,366.8399963378906,358.29998779296875,357.9700012207031,353.5400085449219,359.29998779296875,359.7300109863281,362.3999938964844,365.19000244140625,365.32000732421875,365.8599853515625,368.2650146484375
|
test.ipynb
ADDED
@@ -0,0 +1,1682 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [
|
8 |
+
{
|
9 |
+
"name": "stderr",
|
10 |
+
"output_type": "stream",
|
11 |
+
"text": [
|
12 |
+
"[*********************100%***********************] 10 of 10 completed\n"
|
13 |
+
]
|
14 |
+
}
|
15 |
+
],
|
16 |
+
"source": [
|
17 |
+
"import yfinance as yf\n",
|
18 |
+
"import pandas as pd\n",
|
19 |
+
"\n",
|
20 |
+
"# List of 10 example tickers (you can replace these with any tickers you prefer)\n",
|
21 |
+
"tickers = [\"AAPL\", \"MSFT\", \"GOOGL\", \"AMZN\", \"META\", \"TSLA\", \"NVDA\", \"JPM\", \"JNJ\", \"V\"]\n",
|
22 |
+
"\n",
|
23 |
+
"# Download daily adjusted close prices for the last month (≈30 calendar days)\n",
|
24 |
+
"data = yf.download(tickers, period=\"6mo\", interval=\"1d\", auto_adjust=False)[\"Adj Close\"]\n",
|
25 |
+
"\n",
|
26 |
+
"# Transpose so that each row is one company’s month-long timeseries\n",
|
27 |
+
"df = pd.DataFrame(data.transpose())\n",
|
28 |
+
"\n",
|
29 |
+
"# At this point, `df` has:\n",
|
30 |
+
"# • Index: the 10 tickers\n",
|
31 |
+
"# • Columns: one column per trading day in the last month\n",
|
32 |
+
"# • Values: adjusted close price for that ticker on that date\n",
|
33 |
+
"df.to_csv(\"stocks_data_noindex.csv\", index=False)\n"
|
34 |
+
]
|
35 |
+
},
|
36 |
+
{
|
37 |
+
"cell_type": "markdown",
|
38 |
+
"metadata": {},
|
39 |
+
"source": [
|
40 |
+
"# Stable version"
|
41 |
+
]
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"cell_type": "code",
|
45 |
+
"execution_count": null,
|
46 |
+
"metadata": {},
|
47 |
+
"outputs": [],
|
48 |
+
"source": [
|
49 |
+
"import io\n",
|
50 |
+
"import gradio as gr\n",
|
51 |
+
"import pandas as pd\n",
|
52 |
+
"import torch\n",
|
53 |
+
"import matplotlib.pyplot as plt\n",
|
54 |
+
"from PIL import Image\n",
|
55 |
+
"import numpy as np\n",
|
56 |
+
"\n",
|
57 |
+
"torch.manual_seed(42)\n",
|
58 |
+
"output = torch.load(\"stocks_data_forecast.pt\") # (n_timeseries, pred_len, n_quantiles)\n",
|
59 |
+
"\n",
|
60 |
+
"def model_forecast(input_data):\n",
|
61 |
+
" return output\n",
|
62 |
+
"\n",
|
63 |
+
"def plot_forecast_image(timeseries, quantile_predictions, timeseries_name):\n",
|
64 |
+
" \"\"\"Returns a NumPy array of the plotted figure.\"\"\"\n",
|
65 |
+
" fig, ax = plt.subplots(figsize=(10, 6), dpi=150)\n",
|
66 |
+
" ax.plot(timeseries, color=\"blue\")\n",
|
67 |
+
" x_pred = range(len(timeseries) - 1, len(timeseries) - 1 + len(quantile_predictions))\n",
|
68 |
+
" for i in range(quantile_predictions.shape[1]):\n",
|
69 |
+
" ax.plot(x_pred, quantile_predictions[:, i], color=f\"C{i}\")\n",
|
70 |
+
" buf = io.BytesIO()\n",
|
71 |
+
"\n",
|
72 |
+
" # Add title\n",
|
73 |
+
" ax.set_title(f\"Timeseries: {timeseries_name}\")\n",
|
74 |
+
" # Add labels to the legend (quantiles)\n",
|
75 |
+
" labels = [f\"Quantile {i+1}\" for i in range(quantile_predictions.shape[1])]\n",
|
76 |
+
" ax.legend(labels, loc='center left', bbox_to_anchor=(1, 0.5))\n",
|
77 |
+
" plt.tight_layout(rect=[0, 0, 0.85, 1])\n",
|
78 |
+
"\n",
|
79 |
+
" fig.savefig(buf, format=\"png\", bbox_inches=\"tight\")\n",
|
80 |
+
" plt.close(fig)\n",
|
81 |
+
" buf.seek(0)\n",
|
82 |
+
" img = Image.open(buf).convert(\"RGB\")\n",
|
83 |
+
" return np.array(img) # Return as an H×W×3 array\n",
|
84 |
+
"\n",
|
85 |
+
"def display_forecast(file, preset_filename):\n",
|
86 |
+
" accepted_formats = ['csv', 'xls', 'xlsx', 'parquet']\n",
|
87 |
+
"\n",
|
88 |
+
" def load_table(file_path):\n",
|
89 |
+
" ext = file_path.split('.')[-1].lower()\n",
|
90 |
+
" if ext == 'csv':\n",
|
91 |
+
" return pd.read_csv(file_path)\n",
|
92 |
+
" elif ext in ['xls', 'xlsx']:\n",
|
93 |
+
" return pd.read_excel(file_path)\n",
|
94 |
+
" elif ext == 'parquet':\n",
|
95 |
+
" return pd.read_parquet(file_path)\n",
|
96 |
+
" else:\n",
|
97 |
+
" raise ValueError(f\"Unsupported file format '.{ext}'. Acceptable formats: CSV, XLS, XLSX, PARQUET.\")\n",
|
98 |
+
" \n",
|
99 |
+
" try:\n",
|
100 |
+
" if file is not None:\n",
|
101 |
+
" df = load_table(file.name)\n",
|
102 |
+
" else:\n",
|
103 |
+
" if not preset_filename:\n",
|
104 |
+
" return [], \"Please upload a file or select a preset.\"\n",
|
105 |
+
" df = load_table(preset_filename)\n",
|
106 |
+
" \n",
|
107 |
+
" # Check first column for timeseries names\n",
|
108 |
+
" if df.shape[1] > 0 and df.iloc[:, 0].dtype == object:\n",
|
109 |
+
" if not df.iloc[:, 0].str.isnumeric().all():\n",
|
110 |
+
" timeseries_names = df.iloc[:, 0].tolist()\n",
|
111 |
+
" df = df.iloc[:, 1:]\n",
|
112 |
+
" else:\n",
|
113 |
+
" timeseries_names = [f\"Series {i}\" for i in range(len(df))]\n",
|
114 |
+
" else:\n",
|
115 |
+
" timeseries_names = [f\"Series {i}\" for i in range(len(df))]\n",
|
116 |
+
"\n",
|
117 |
+
" _input = torch.tensor(df.values)\n",
|
118 |
+
" _output = model_forecast(_input)\n",
|
119 |
+
"\n",
|
120 |
+
" gallery_images = []\n",
|
121 |
+
" for i in range(_input.shape[0]):\n",
|
122 |
+
" img_array = plot_forecast_image(_input[i], _output[i], timeseries_names[i])\n",
|
123 |
+
" gallery_images.append(img_array)\n",
|
124 |
+
"\n",
|
125 |
+
" return gallery_images, \"\"\n",
|
126 |
+
" except Exception as e:\n",
|
127 |
+
" return [], f\"Error: {e}. Please upload files in one of the following formats: CSV, XLS, XLSX, PARQUET.\"\n",
|
128 |
+
"\n",
|
129 |
+
"\n",
|
130 |
+
"\n",
|
131 |
+
"iface = gr.Interface(\n",
|
132 |
+
" fn=display_forecast,\n",
|
133 |
+
" inputs=[\n",
|
134 |
+
" gr.File(label=\"Upload your CSV file (optional)\"),\n",
|
135 |
+
" gr.Dropdown(\n",
|
136 |
+
" label=\"Or select a preset CSV file\",\n",
|
137 |
+
" choices=[\"stocks_data_noindex.csv\", \"stocks_data.csv\"],\n",
|
138 |
+
" value=\"stocks_data_noindex.csv\"\n",
|
139 |
+
" )\n",
|
140 |
+
" ],\n",
|
141 |
+
" outputs=[\n",
|
142 |
+
" gr.Gallery(label=\"Forecast Plots (one per row)\"), \n",
|
143 |
+
" gr.Textbox(label=\"Error Message\")\n",
|
144 |
+
" ],\n",
|
145 |
+
" title=\"CSV→Dynamic Forecast Gallery\",\n",
|
146 |
+
" description=\"Upload a CSV with any number of rows; each row’s forecast becomes one image in a gallery.\",\n",
|
147 |
+
" allow_flagging=\"never\",\n",
|
148 |
+
")\n",
|
149 |
+
"\n",
|
150 |
+
"if __name__ == \"__main__\":\n",
|
151 |
+
" iface.launch()\n",
|
152 |
+
"\n",
|
153 |
+
"\n",
|
154 |
+
"\n",
|
155 |
+
"# '''\n",
|
156 |
+
"# 1. Prepared datasets\n",
|
157 |
+
"# 2. Plots of different quiantilies (different colors)\n",
|
158 |
+
"# 3. Filters for plots...\n",
|
159 |
+
"# 4. Different input options\n",
|
160 |
+
"# 5. README.md in there (in UI) (contact us for fine-tuning)\n",
|
161 |
+
"# 6. Requirements for dimensions\n",
|
162 |
+
"# 7. Multivariate data (x_t is vector)\n",
|
163 |
+
"# 8. LOGO of NX-AI and xLSTM and tirex\n",
|
164 |
+
"# '''"
|
165 |
+
]
|
166 |
+
},
|
167 |
+
{
|
168 |
+
"cell_type": "code",
|
169 |
+
"execution_count": 15,
|
170 |
+
"metadata": {},
|
171 |
+
"outputs": [
|
172 |
+
{
|
173 |
+
"data": {
|
174 |
+
"text/html": [
|
175 |
+
"<div>\n",
|
176 |
+
"<style scoped>\n",
|
177 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
178 |
+
" vertical-align: middle;\n",
|
179 |
+
" }\n",
|
180 |
+
"\n",
|
181 |
+
" .dataframe tbody tr th {\n",
|
182 |
+
" vertical-align: top;\n",
|
183 |
+
" }\n",
|
184 |
+
"\n",
|
185 |
+
" .dataframe thead th {\n",
|
186 |
+
" text-align: right;\n",
|
187 |
+
" }\n",
|
188 |
+
"</style>\n",
|
189 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
190 |
+
" <thead>\n",
|
191 |
+
" <tr style=\"text-align: right;\">\n",
|
192 |
+
" <th></th>\n",
|
193 |
+
" <th>Ticker</th>\n",
|
194 |
+
" <th>2024-12-04</th>\n",
|
195 |
+
" <th>2024-12-05</th>\n",
|
196 |
+
" <th>2024-12-06</th>\n",
|
197 |
+
" <th>2024-12-09</th>\n",
|
198 |
+
" <th>2024-12-10</th>\n",
|
199 |
+
" <th>2024-12-11</th>\n",
|
200 |
+
" <th>2024-12-12</th>\n",
|
201 |
+
" <th>2024-12-13</th>\n",
|
202 |
+
" <th>2024-12-16</th>\n",
|
203 |
+
" <th>...</th>\n",
|
204 |
+
" <th>2025-05-21</th>\n",
|
205 |
+
" <th>2025-05-22</th>\n",
|
206 |
+
" <th>2025-05-23</th>\n",
|
207 |
+
" <th>2025-05-27</th>\n",
|
208 |
+
" <th>2025-05-28</th>\n",
|
209 |
+
" <th>2025-05-29</th>\n",
|
210 |
+
" <th>2025-05-30</th>\n",
|
211 |
+
" <th>2025-06-02</th>\n",
|
212 |
+
" <th>2025-06-03</th>\n",
|
213 |
+
" <th>2025-06-04</th>\n",
|
214 |
+
" </tr>\n",
|
215 |
+
" </thead>\n",
|
216 |
+
" <tbody>\n",
|
217 |
+
" <tr>\n",
|
218 |
+
" <th>0</th>\n",
|
219 |
+
" <td>AAPL</td>\n",
|
220 |
+
" <td>242.425201</td>\n",
|
221 |
+
" <td>242.455124</td>\n",
|
222 |
+
" <td>242.255600</td>\n",
|
223 |
+
" <td>246.156204</td>\n",
|
224 |
+
" <td>247.173752</td>\n",
|
225 |
+
" <td>245.896820</td>\n",
|
226 |
+
" <td>247.363297</td>\n",
|
227 |
+
" <td>247.532883</td>\n",
|
228 |
+
" <td>250.435867</td>\n",
|
229 |
+
" <td>...</td>\n",
|
230 |
+
" <td>202.089996</td>\n",
|
231 |
+
" <td>201.360001</td>\n",
|
232 |
+
" <td>195.270004</td>\n",
|
233 |
+
" <td>200.210007</td>\n",
|
234 |
+
" <td>200.419998</td>\n",
|
235 |
+
" <td>199.949997</td>\n",
|
236 |
+
" <td>200.850006</td>\n",
|
237 |
+
" <td>201.699997</td>\n",
|
238 |
+
" <td>203.270004</td>\n",
|
239 |
+
" <td>203.150101</td>\n",
|
240 |
+
" </tr>\n",
|
241 |
+
" <tr>\n",
|
242 |
+
" <th>1</th>\n",
|
243 |
+
" <td>AMZN</td>\n",
|
244 |
+
" <td>218.160004</td>\n",
|
245 |
+
" <td>220.550003</td>\n",
|
246 |
+
" <td>227.029999</td>\n",
|
247 |
+
" <td>226.089996</td>\n",
|
248 |
+
" <td>225.039993</td>\n",
|
249 |
+
" <td>230.259995</td>\n",
|
250 |
+
" <td>228.970001</td>\n",
|
251 |
+
" <td>227.460007</td>\n",
|
252 |
+
" <td>232.929993</td>\n",
|
253 |
+
" <td>...</td>\n",
|
254 |
+
" <td>201.119995</td>\n",
|
255 |
+
" <td>203.100006</td>\n",
|
256 |
+
" <td>200.990005</td>\n",
|
257 |
+
" <td>206.020004</td>\n",
|
258 |
+
" <td>204.720001</td>\n",
|
259 |
+
" <td>205.699997</td>\n",
|
260 |
+
" <td>205.009995</td>\n",
|
261 |
+
" <td>206.649994</td>\n",
|
262 |
+
" <td>205.710007</td>\n",
|
263 |
+
" <td>207.472000</td>\n",
|
264 |
+
" </tr>\n",
|
265 |
+
" <tr>\n",
|
266 |
+
" <th>2</th>\n",
|
267 |
+
" <td>GOOGL</td>\n",
|
268 |
+
" <td>173.970016</td>\n",
|
269 |
+
" <td>172.244003</td>\n",
|
270 |
+
" <td>174.309265</td>\n",
|
271 |
+
" <td>175.168259</td>\n",
|
272 |
+
" <td>184.956985</td>\n",
|
273 |
+
" <td>195.175217</td>\n",
|
274 |
+
" <td>191.739182</td>\n",
|
275 |
+
" <td>189.601639</td>\n",
|
276 |
+
" <td>196.433777</td>\n",
|
277 |
+
" <td>...</td>\n",
|
278 |
+
" <td>168.559998</td>\n",
|
279 |
+
" <td>170.869995</td>\n",
|
280 |
+
" <td>168.470001</td>\n",
|
281 |
+
" <td>172.899994</td>\n",
|
282 |
+
" <td>172.360001</td>\n",
|
283 |
+
" <td>171.860001</td>\n",
|
284 |
+
" <td>171.740005</td>\n",
|
285 |
+
" <td>169.029999</td>\n",
|
286 |
+
" <td>166.179993</td>\n",
|
287 |
+
" <td>167.785004</td>\n",
|
288 |
+
" </tr>\n",
|
289 |
+
" <tr>\n",
|
290 |
+
" <th>3</th>\n",
|
291 |
+
" <td>JNJ</td>\n",
|
292 |
+
" <td>148.006256</td>\n",
|
293 |
+
" <td>147.071823</td>\n",
|
294 |
+
" <td>146.865265</td>\n",
|
295 |
+
" <td>147.150513</td>\n",
|
296 |
+
" <td>146.786560</td>\n",
|
297 |
+
" <td>144.238968</td>\n",
|
298 |
+
" <td>143.845535</td>\n",
|
299 |
+
" <td>144.219299</td>\n",
|
300 |
+
" <td>141.494659</td>\n",
|
301 |
+
" <td>...</td>\n",
|
302 |
+
" <td>151.877960</td>\n",
|
303 |
+
" <td>151.312805</td>\n",
|
304 |
+
" <td>151.639999</td>\n",
|
305 |
+
" <td>153.250000</td>\n",
|
306 |
+
" <td>152.429993</td>\n",
|
307 |
+
" <td>153.580002</td>\n",
|
308 |
+
" <td>155.210007</td>\n",
|
309 |
+
" <td>155.399994</td>\n",
|
310 |
+
" <td>154.419998</td>\n",
|
311 |
+
" <td>153.419998</td>\n",
|
312 |
+
" </tr>\n",
|
313 |
+
" <tr>\n",
|
314 |
+
" <th>4</th>\n",
|
315 |
+
" <td>JPM</td>\n",
|
316 |
+
" <td>240.666992</td>\n",
|
317 |
+
" <td>242.723633</td>\n",
|
318 |
+
" <td>244.582520</td>\n",
|
319 |
+
" <td>241.072388</td>\n",
|
320 |
+
" <td>240.133057</td>\n",
|
321 |
+
" <td>240.795532</td>\n",
|
322 |
+
" <td>238.817978</td>\n",
|
323 |
+
" <td>237.245850</td>\n",
|
324 |
+
" <td>236.889893</td>\n",
|
325 |
+
" <td>...</td>\n",
|
326 |
+
" <td>261.040009</td>\n",
|
327 |
+
" <td>260.670013</td>\n",
|
328 |
+
" <td>260.709991</td>\n",
|
329 |
+
" <td>265.290009</td>\n",
|
330 |
+
" <td>263.489990</td>\n",
|
331 |
+
" <td>264.369995</td>\n",
|
332 |
+
" <td>264.000000</td>\n",
|
333 |
+
" <td>264.660004</td>\n",
|
334 |
+
" <td>266.269989</td>\n",
|
335 |
+
" <td>265.065002</td>\n",
|
336 |
+
" </tr>\n",
|
337 |
+
" <tr>\n",
|
338 |
+
" <th>5</th>\n",
|
339 |
+
" <td>META</td>\n",
|
340 |
+
" <td>612.740173</td>\n",
|
341 |
+
" <td>607.898376</td>\n",
|
342 |
+
" <td>622.713257</td>\n",
|
343 |
+
" <td>612.530518</td>\n",
|
344 |
+
" <td>618.270813</td>\n",
|
345 |
+
" <td>631.608154</td>\n",
|
346 |
+
" <td>629.721375</td>\n",
|
347 |
+
" <td>619.299072</td>\n",
|
348 |
+
" <td>623.685120</td>\n",
|
349 |
+
" <td>...</td>\n",
|
350 |
+
" <td>635.500000</td>\n",
|
351 |
+
" <td>636.570007</td>\n",
|
352 |
+
" <td>627.059998</td>\n",
|
353 |
+
" <td>642.320007</td>\n",
|
354 |
+
" <td>643.580017</td>\n",
|
355 |
+
" <td>645.049988</td>\n",
|
356 |
+
" <td>647.489990</td>\n",
|
357 |
+
" <td>670.900024</td>\n",
|
358 |
+
" <td>666.849976</td>\n",
|
359 |
+
" <td>685.159973</td>\n",
|
360 |
+
" </tr>\n",
|
361 |
+
" <tr>\n",
|
362 |
+
" <th>6</th>\n",
|
363 |
+
" <td>MSFT</td>\n",
|
364 |
+
" <td>435.744720</td>\n",
|
365 |
+
" <td>440.924805</td>\n",
|
366 |
+
" <td>441.871155</td>\n",
|
367 |
+
" <td>444.311768</td>\n",
|
368 |
+
" <td>441.632050</td>\n",
|
369 |
+
" <td>447.270416</td>\n",
|
370 |
+
" <td>447.838196</td>\n",
|
371 |
+
" <td>445.556976</td>\n",
|
372 |
+
" <td>449.860413</td>\n",
|
373 |
+
" <td>...</td>\n",
|
374 |
+
" <td>452.570007</td>\n",
|
375 |
+
" <td>454.859985</td>\n",
|
376 |
+
" <td>450.179993</td>\n",
|
377 |
+
" <td>460.690002</td>\n",
|
378 |
+
" <td>457.359985</td>\n",
|
379 |
+
" <td>458.679993</td>\n",
|
380 |
+
" <td>460.359985</td>\n",
|
381 |
+
" <td>461.970001</td>\n",
|
382 |
+
" <td>462.970001</td>\n",
|
383 |
+
" <td>464.190002</td>\n",
|
384 |
+
" </tr>\n",
|
385 |
+
" <tr>\n",
|
386 |
+
" <th>7</th>\n",
|
387 |
+
" <td>NVDA</td>\n",
|
388 |
+
" <td>145.116653</td>\n",
|
389 |
+
" <td>145.046661</td>\n",
|
390 |
+
" <td>142.426895</td>\n",
|
391 |
+
" <td>138.797226</td>\n",
|
392 |
+
" <td>135.057587</td>\n",
|
393 |
+
" <td>139.297180</td>\n",
|
394 |
+
" <td>137.327362</td>\n",
|
395 |
+
" <td>134.237656</td>\n",
|
396 |
+
" <td>131.987854</td>\n",
|
397 |
+
" <td>...</td>\n",
|
398 |
+
" <td>131.800003</td>\n",
|
399 |
+
" <td>132.830002</td>\n",
|
400 |
+
" <td>131.289993</td>\n",
|
401 |
+
" <td>135.500000</td>\n",
|
402 |
+
" <td>134.809998</td>\n",
|
403 |
+
" <td>139.190002</td>\n",
|
404 |
+
" <td>135.130005</td>\n",
|
405 |
+
" <td>137.380005</td>\n",
|
406 |
+
" <td>141.220001</td>\n",
|
407 |
+
" <td>141.854996</td>\n",
|
408 |
+
" </tr>\n",
|
409 |
+
" <tr>\n",
|
410 |
+
" <th>8</th>\n",
|
411 |
+
" <td>TSLA</td>\n",
|
412 |
+
" <td>357.929993</td>\n",
|
413 |
+
" <td>369.489990</td>\n",
|
414 |
+
" <td>389.220001</td>\n",
|
415 |
+
" <td>389.790009</td>\n",
|
416 |
+
" <td>400.989990</td>\n",
|
417 |
+
" <td>424.769989</td>\n",
|
418 |
+
" <td>418.100006</td>\n",
|
419 |
+
" <td>436.230011</td>\n",
|
420 |
+
" <td>463.019989</td>\n",
|
421 |
+
" <td>...</td>\n",
|
422 |
+
" <td>334.619995</td>\n",
|
423 |
+
" <td>341.040009</td>\n",
|
424 |
+
" <td>339.339996</td>\n",
|
425 |
+
" <td>362.890015</td>\n",
|
426 |
+
" <td>356.899994</td>\n",
|
427 |
+
" <td>358.429993</td>\n",
|
428 |
+
" <td>346.459991</td>\n",
|
429 |
+
" <td>342.690002</td>\n",
|
430 |
+
" <td>344.269989</td>\n",
|
431 |
+
" <td>334.671600</td>\n",
|
432 |
+
" </tr>\n",
|
433 |
+
" <tr>\n",
|
434 |
+
" <th>9</th>\n",
|
435 |
+
" <td>V</td>\n",
|
436 |
+
" <td>308.866455</td>\n",
|
437 |
+
" <td>308.049194</td>\n",
|
438 |
+
" <td>309.972778</td>\n",
|
439 |
+
" <td>307.271790</td>\n",
|
440 |
+
" <td>311.338196</td>\n",
|
441 |
+
" <td>312.743500</td>\n",
|
442 |
+
" <td>313.182037</td>\n",
|
443 |
+
" <td>313.690308</td>\n",
|
444 |
+
" <td>314.836517</td>\n",
|
445 |
+
" <td>...</td>\n",
|
446 |
+
" <td>358.299988</td>\n",
|
447 |
+
" <td>357.970001</td>\n",
|
448 |
+
" <td>353.540009</td>\n",
|
449 |
+
" <td>359.299988</td>\n",
|
450 |
+
" <td>359.730011</td>\n",
|
451 |
+
" <td>362.399994</td>\n",
|
452 |
+
" <td>365.190002</td>\n",
|
453 |
+
" <td>365.320007</td>\n",
|
454 |
+
" <td>365.859985</td>\n",
|
455 |
+
" <td>368.179993</td>\n",
|
456 |
+
" </tr>\n",
|
457 |
+
" </tbody>\n",
|
458 |
+
"</table>\n",
|
459 |
+
"<p>10 rows × 125 columns</p>\n",
|
460 |
+
"</div>"
|
461 |
+
],
|
462 |
+
"text/plain": [
|
463 |
+
" Ticker 2024-12-04 2024-12-05 2024-12-06 2024-12-09 2024-12-10 \\\n",
|
464 |
+
"0 AAPL 242.425201 242.455124 242.255600 246.156204 247.173752 \n",
|
465 |
+
"1 AMZN 218.160004 220.550003 227.029999 226.089996 225.039993 \n",
|
466 |
+
"2 GOOGL 173.970016 172.244003 174.309265 175.168259 184.956985 \n",
|
467 |
+
"3 JNJ 148.006256 147.071823 146.865265 147.150513 146.786560 \n",
|
468 |
+
"4 JPM 240.666992 242.723633 244.582520 241.072388 240.133057 \n",
|
469 |
+
"5 META 612.740173 607.898376 622.713257 612.530518 618.270813 \n",
|
470 |
+
"6 MSFT 435.744720 440.924805 441.871155 444.311768 441.632050 \n",
|
471 |
+
"7 NVDA 145.116653 145.046661 142.426895 138.797226 135.057587 \n",
|
472 |
+
"8 TSLA 357.929993 369.489990 389.220001 389.790009 400.989990 \n",
|
473 |
+
"9 V 308.866455 308.049194 309.972778 307.271790 311.338196 \n",
|
474 |
+
"\n",
|
475 |
+
" 2024-12-11 2024-12-12 2024-12-13 2024-12-16 ... 2025-05-21 \\\n",
|
476 |
+
"0 245.896820 247.363297 247.532883 250.435867 ... 202.089996 \n",
|
477 |
+
"1 230.259995 228.970001 227.460007 232.929993 ... 201.119995 \n",
|
478 |
+
"2 195.175217 191.739182 189.601639 196.433777 ... 168.559998 \n",
|
479 |
+
"3 144.238968 143.845535 144.219299 141.494659 ... 151.877960 \n",
|
480 |
+
"4 240.795532 238.817978 237.245850 236.889893 ... 261.040009 \n",
|
481 |
+
"5 631.608154 629.721375 619.299072 623.685120 ... 635.500000 \n",
|
482 |
+
"6 447.270416 447.838196 445.556976 449.860413 ... 452.570007 \n",
|
483 |
+
"7 139.297180 137.327362 134.237656 131.987854 ... 131.800003 \n",
|
484 |
+
"8 424.769989 418.100006 436.230011 463.019989 ... 334.619995 \n",
|
485 |
+
"9 312.743500 313.182037 313.690308 314.836517 ... 358.299988 \n",
|
486 |
+
"\n",
|
487 |
+
" 2025-05-22 2025-05-23 2025-05-27 2025-05-28 2025-05-29 2025-05-30 \\\n",
|
488 |
+
"0 201.360001 195.270004 200.210007 200.419998 199.949997 200.850006 \n",
|
489 |
+
"1 203.100006 200.990005 206.020004 204.720001 205.699997 205.009995 \n",
|
490 |
+
"2 170.869995 168.470001 172.899994 172.360001 171.860001 171.740005 \n",
|
491 |
+
"3 151.312805 151.639999 153.250000 152.429993 153.580002 155.210007 \n",
|
492 |
+
"4 260.670013 260.709991 265.290009 263.489990 264.369995 264.000000 \n",
|
493 |
+
"5 636.570007 627.059998 642.320007 643.580017 645.049988 647.489990 \n",
|
494 |
+
"6 454.859985 450.179993 460.690002 457.359985 458.679993 460.359985 \n",
|
495 |
+
"7 132.830002 131.289993 135.500000 134.809998 139.190002 135.130005 \n",
|
496 |
+
"8 341.040009 339.339996 362.890015 356.899994 358.429993 346.459991 \n",
|
497 |
+
"9 357.970001 353.540009 359.299988 359.730011 362.399994 365.190002 \n",
|
498 |
+
"\n",
|
499 |
+
" 2025-06-02 2025-06-03 2025-06-04 \n",
|
500 |
+
"0 201.699997 203.270004 203.150101 \n",
|
501 |
+
"1 206.649994 205.710007 207.472000 \n",
|
502 |
+
"2 169.029999 166.179993 167.785004 \n",
|
503 |
+
"3 155.399994 154.419998 153.419998 \n",
|
504 |
+
"4 264.660004 266.269989 265.065002 \n",
|
505 |
+
"5 670.900024 666.849976 685.159973 \n",
|
506 |
+
"6 461.970001 462.970001 464.190002 \n",
|
507 |
+
"7 137.380005 141.220001 141.854996 \n",
|
508 |
+
"8 342.690002 344.269989 334.671600 \n",
|
509 |
+
"9 365.320007 365.859985 368.179993 \n",
|
510 |
+
"\n",
|
511 |
+
"[10 rows x 125 columns]"
|
512 |
+
]
|
513 |
+
},
|
514 |
+
"execution_count": 15,
|
515 |
+
"metadata": {},
|
516 |
+
"output_type": "execute_result"
|
517 |
+
}
|
518 |
+
],
|
519 |
+
"source": [
|
520 |
+
"pd.read_csv(\"stocks_data.csv\")"
|
521 |
+
]
|
522 |
+
},
|
523 |
+
{
|
524 |
+
"cell_type": "markdown",
|
525 |
+
"metadata": {},
|
526 |
+
"source": [
|
527 |
+
"# Not checked but with labels filter"
|
528 |
+
]
|
529 |
+
},
|
530 |
+
{
|
531 |
+
"cell_type": "code",
|
532 |
+
"execution_count": null,
|
533 |
+
"metadata": {},
|
534 |
+
"outputs": [],
|
535 |
+
"source": [
|
536 |
+
"import io\n",
|
537 |
+
"import pandas as pd\n",
|
538 |
+
"import torch\n",
|
539 |
+
"import matplotlib.pyplot as plt\n",
|
540 |
+
"from PIL import Image\n",
|
541 |
+
"import numpy as np\n",
|
542 |
+
"import gradio as gr\n",
|
543 |
+
"\n",
|
544 |
+
"# Set random seed and load your pretrained forecast tensor\n",
|
545 |
+
"torch.manual_seed(42)\n",
|
546 |
+
"_forecast_tensor = torch.load(\"stocks_data_forecast.pt\") # shape = (n_series, pred_len, n_q)\n",
|
547 |
+
"\n",
|
548 |
+
"def model_forecast(input_data):\n",
|
549 |
+
" return _forecast_tensor\n",
|
550 |
+
"\n",
|
551 |
+
"def plot_forecast_image(timeseries, quantile_predictions, timeseries_name):\n",
|
552 |
+
" \"\"\"Given one 1D series + quantile‐matrix, return a NumPy array of the plotted figure.\"\"\"\n",
|
553 |
+
" fig, ax = plt.subplots(figsize=(10, 6), dpi=150)\n",
|
554 |
+
" ax.plot(timeseries, color=\"blue\")\n",
|
555 |
+
" x_pred = range(len(timeseries) - 1, len(timeseries) - 1 + len(quantile_predictions))\n",
|
556 |
+
" for i in range(quantile_predictions.shape[1]):\n",
|
557 |
+
" ax.plot(x_pred, quantile_predictions[:, i], color=f\"C{i}\")\n",
|
558 |
+
" ax.set_title(f\"Timeseries: {timeseries_name}\")\n",
|
559 |
+
"\n",
|
560 |
+
" labels = [f\"Quantile {i}\" for i in range(quantile_predictions.shape[1])]\n",
|
561 |
+
" ax.legend(labels, loc=\"center left\", bbox_to_anchor=(1, 0.5))\n",
|
562 |
+
" plt.tight_layout(rect=[0, 0, 0.85, 1])\n",
|
563 |
+
"\n",
|
564 |
+
" buf = io.BytesIO()\n",
|
565 |
+
" fig.savefig(buf, format=\"png\", bbox_inches=\"tight\")\n",
|
566 |
+
" plt.close(fig)\n",
|
567 |
+
" buf.seek(0)\n",
|
568 |
+
" img = Image.open(buf).convert(\"RGB\")\n",
|
569 |
+
" return np.array(img)\n",
|
570 |
+
"\n",
|
571 |
+
"def load_table(file_path):\n",
|
572 |
+
" \"\"\"Load CSV / XLS(X) / Parquet by extension, else raise.\"\"\"\n",
|
573 |
+
" ext = file_path.split(\".\")[-1].lower()\n",
|
574 |
+
" if ext == \"csv\":\n",
|
575 |
+
" return pd.read_csv(file_path)\n",
|
576 |
+
" elif ext in (\"xls\", \"xlsx\"):\n",
|
577 |
+
" return pd.read_excel(file_path)\n",
|
578 |
+
" elif ext == \"parquet\":\n",
|
579 |
+
" return pd.read_parquet(file_path)\n",
|
580 |
+
" else:\n",
|
581 |
+
" raise ValueError(\n",
|
582 |
+
" f\"Unsupported file format '.{ext}'. Accepted: CSV, XLS, XLSX, PARQUET.\"\n",
|
583 |
+
" )\n",
|
584 |
+
"\n",
|
585 |
+
"def extract_names_and_update(file, preset_filename):\n",
|
586 |
+
" \"\"\"\n",
|
587 |
+
" Read the table (uploaded or preset), extract timeseries names, and return:\n",
|
588 |
+
" 1) gr.update for the CheckboxGroup (all names pre‐checked)\n",
|
589 |
+
" 2) the full list of names to store in state.\n",
|
590 |
+
" \"\"\"\n",
|
591 |
+
" try:\n",
|
592 |
+
" if file is not None:\n",
|
593 |
+
" df = load_table(file.name)\n",
|
594 |
+
" else:\n",
|
595 |
+
" if not preset_filename:\n",
|
596 |
+
" return gr.update(choices=[], value=[]), []\n",
|
597 |
+
" df = load_table(preset_filename)\n",
|
598 |
+
"\n",
|
599 |
+
" if df.shape[1] > 0 and df.iloc[:, 0].dtype == object and not df.iloc[:, 0].str.isnumeric().all():\n",
|
600 |
+
" names = df.iloc[:, 0].tolist()\n",
|
601 |
+
" else:\n",
|
602 |
+
" names = [f\"Series {i}\" for i in range(len(df))]\n",
|
603 |
+
"\n",
|
604 |
+
" return gr.update(choices=names, value=names), names\n",
|
605 |
+
" except Exception:\n",
|
606 |
+
" return gr.update(choices=[], value=[]), []\n",
|
607 |
+
"\n",
|
608 |
+
"def filter_names(search_term, all_names):\n",
|
609 |
+
" \"\"\"\n",
|
610 |
+
" Filter the full list of names (all_names) by the search_term (case‐insensitive substring).\n",
|
611 |
+
" Return gr.update with filtered choices and keep checked those that remain in both.\n",
|
612 |
+
" \"\"\"\n",
|
613 |
+
" if not all_names:\n",
|
614 |
+
" return gr.update(choices=[], value=[])\n",
|
615 |
+
" if not search_term:\n",
|
616 |
+
" # No search term → show all\n",
|
617 |
+
" return gr.update(choices=all_names, value=all_names)\n",
|
618 |
+
" lower = search_term.lower()\n",
|
619 |
+
" filtered = [n for n in all_names if lower in str(n).lower()]\n",
|
620 |
+
" return gr.update(choices=filtered, value=filtered)\n",
|
621 |
+
"\n",
|
622 |
+
"def check_all(names_list):\n",
|
623 |
+
" \"\"\"Return an update that checks all names in the checkbox.\"\"\"\n",
|
624 |
+
" return gr.update(value=names_list)\n",
|
625 |
+
"\n",
|
626 |
+
"def uncheck_all(_):\n",
|
627 |
+
" \"\"\"Return an update that unchecks all names.\"\"\"\n",
|
628 |
+
" return gr.update(value=[])\n",
|
629 |
+
"\n",
|
630 |
+
"def display_filtered_forecast(file, preset_filename, selected_names):\n",
|
631 |
+
" \"\"\"\n",
|
632 |
+
" Load the table, filter by selected_names, run forecast, and return:\n",
|
633 |
+
" - list of images (NumPy arrays) for the gallery\n",
|
634 |
+
" - error string (empty if OK)\n",
|
635 |
+
" \"\"\"\n",
|
636 |
+
" try:\n",
|
637 |
+
" if file is not None:\n",
|
638 |
+
" df = load_table(file.name)\n",
|
639 |
+
" else:\n",
|
640 |
+
" if not preset_filename:\n",
|
641 |
+
" return [], \"No file selected.\"\n",
|
642 |
+
" df = load_table(preset_filename)\n",
|
643 |
+
"\n",
|
644 |
+
" if df.shape[1] > 0 and df.iloc[:, 0].dtype == object and not df.iloc[:, 0].str.isnumeric().all():\n",
|
645 |
+
" all_names = df.iloc[:, 0].tolist()\n",
|
646 |
+
" data_only = df.iloc[:, 1:].astype(float)\n",
|
647 |
+
" else:\n",
|
648 |
+
" all_names = [f\"Series {i}\" for i in range(len(df))]\n",
|
649 |
+
" data_only = df.astype(float)\n",
|
650 |
+
"\n",
|
651 |
+
" mask = [name in selected_names for name in all_names]\n",
|
652 |
+
" if not any(mask):\n",
|
653 |
+
" return [], \"No timeseries chosen to plot.\"\n",
|
654 |
+
"\n",
|
655 |
+
" filtered_data = data_only.iloc[mask, :].values\n",
|
656 |
+
" filtered_names = [all_names[i] for i, m in enumerate(mask) if m]\n",
|
657 |
+
"\n",
|
658 |
+
" inp = torch.tensor(filtered_data) # (n_chosen, length)\n",
|
659 |
+
" out = model_forecast(inp) # (n_chosen, pred_len, n_q)\n",
|
660 |
+
"\n",
|
661 |
+
" gallery_images = []\n",
|
662 |
+
" for i in range(inp.shape[0]):\n",
|
663 |
+
" gallery_images.append(\n",
|
664 |
+
" plot_forecast_image(inp[i], out[i], filtered_names[i])\n",
|
665 |
+
" )\n",
|
666 |
+
"\n",
|
667 |
+
" return gallery_images, \"\"\n",
|
668 |
+
" except Exception as e:\n",
|
669 |
+
" return [], f\"Error: {e}. Please upload a valid CSV, XLS, XLSX, or PARQUET file.\"\n",
|
670 |
+
"\n",
|
671 |
+
"with gr.Blocks() as demo:\n",
|
672 |
+
" gr.Markdown(\"## Upload or select a preset → search/filter by name → click Plot\")\n",
|
673 |
+
"\n",
|
674 |
+
" with gr.Row():\n",
|
675 |
+
" file_input = gr.File(\n",
|
676 |
+
" label=\"Upload CSV/XLSX/PARQUET (optional)\",\n",
|
677 |
+
" file_types=[\".csv\", \".xls\", \".xlsx\", \".parquet\"]\n",
|
678 |
+
" )\n",
|
679 |
+
" preset_dropdown = gr.Dropdown(\n",
|
680 |
+
" label=\"Or pick a preset:\",\n",
|
681 |
+
" choices=[\"stocks_data_noindex.csv\", \"stocks_data.csv\"],\n",
|
682 |
+
" value=\"stocks_data_noindex.csv\"\n",
|
683 |
+
" )\n",
|
684 |
+
"\n",
|
685 |
+
" # A text box to type a substring (search term)\n",
|
686 |
+
" search_box = gr.Textbox(\n",
|
687 |
+
" label=\"Search/Filter timeseries by name\",\n",
|
688 |
+
" placeholder=\"Type to filter (e.g. 'AMZN')\",\n",
|
689 |
+
" value=\"\"\n",
|
690 |
+
" )\n",
|
691 |
+
"\n",
|
692 |
+
" # A CheckboxGroup to show matching names; choices/value will be updated dynamically\n",
|
693 |
+
" filter_checkbox = gr.CheckboxGroup(\n",
|
694 |
+
" choices=[], value=[], label=\"Select which timeseries to show\"\n",
|
695 |
+
" )\n",
|
696 |
+
"\n",
|
697 |
+
" # Buttons to check or uncheck all\n",
|
698 |
+
" with gr.Row():\n",
|
699 |
+
" check_all_btn = gr.Button(\"Check All\")\n",
|
700 |
+
" uncheck_all_btn = gr.Button(\"Uncheck All\")\n",
|
701 |
+
"\n",
|
702 |
+
" plot_button = gr.Button(\"Plot\")\n",
|
703 |
+
"\n",
|
704 |
+
" gallery = gr.Gallery(label=\"Forecast Plots (filtered)\")\n",
|
705 |
+
" errbox = gr.Textbox(label=\"Error Message\")\n",
|
706 |
+
"\n",
|
707 |
+
" # State to hold the full list of names\n",
|
708 |
+
" names_state = gr.State([])\n",
|
709 |
+
"\n",
|
710 |
+
" # 1) When file or preset changes, extract full names and update the checkbox + state\n",
|
711 |
+
" file_input.change(\n",
|
712 |
+
" fn=extract_names_and_update,\n",
|
713 |
+
" inputs=[file_input, preset_dropdown],\n",
|
714 |
+
" outputs=[filter_checkbox, names_state]\n",
|
715 |
+
" )\n",
|
716 |
+
" preset_dropdown.change(\n",
|
717 |
+
" fn=extract_names_and_update,\n",
|
718 |
+
" inputs=[file_input, preset_dropdown],\n",
|
719 |
+
" outputs=[filter_checkbox, names_state]\n",
|
720 |
+
" )\n",
|
721 |
+
"\n",
|
722 |
+
" # 2) When search text changes, filter names_state and update the checkbox\n",
|
723 |
+
" search_box.change(\n",
|
724 |
+
" fn=filter_names,\n",
|
725 |
+
" inputs=[search_box, names_state],\n",
|
726 |
+
" outputs=filter_checkbox\n",
|
727 |
+
" )\n",
|
728 |
+
"\n",
|
729 |
+
" # 3) Check All button: set checkbox value to all names in state\n",
|
730 |
+
" check_all_btn.click(\n",
|
731 |
+
" fn=check_all,\n",
|
732 |
+
" inputs=names_state,\n",
|
733 |
+
" outputs=filter_checkbox\n",
|
734 |
+
" )\n",
|
735 |
+
"\n",
|
736 |
+
" # 4) Uncheck All button: set checkbox value to empty list\n",
|
737 |
+
" uncheck_all_btn.click(\n",
|
738 |
+
" fn=uncheck_all,\n",
|
739 |
+
" inputs=names_state,\n",
|
740 |
+
" outputs=filter_checkbox\n",
|
741 |
+
" )\n",
|
742 |
+
"\n",
|
743 |
+
" # 5) When \"Plot\" is clicked, generate the filtered plots\n",
|
744 |
+
" plot_button.click(\n",
|
745 |
+
" fn=display_filtered_forecast,\n",
|
746 |
+
" inputs=[file_input, preset_dropdown, filter_checkbox],\n",
|
747 |
+
" outputs=[gallery, errbox],\n",
|
748 |
+
" )\n",
|
749 |
+
"\n",
|
750 |
+
"demo.launch()\n",
|
751 |
+
"\n"
|
752 |
+
]
|
753 |
+
},
|
754 |
+
{
|
755 |
+
"cell_type": "markdown",
|
756 |
+
"metadata": {},
|
757 |
+
"source": [
|
758 |
+
"# Checked, filter"
|
759 |
+
]
|
760 |
+
},
|
761 |
+
{
|
762 |
+
"cell_type": "code",
|
763 |
+
"execution_count": null,
|
764 |
+
"metadata": {},
|
765 |
+
"outputs": [],
|
766 |
+
"source": [
|
767 |
+
"import io\n",
|
768 |
+
"import pandas as pd\n",
|
769 |
+
"import torch\n",
|
770 |
+
"import matplotlib.pyplot as plt\n",
|
771 |
+
"from PIL import Image\n",
|
772 |
+
"import numpy as np\n",
|
773 |
+
"import gradio as gr\n",
|
774 |
+
"\n",
|
775 |
+
"# ----------------------------\n",
|
776 |
+
"# Helper functions (logic unchanged)\n",
|
777 |
+
"# ----------------------------\n",
|
778 |
+
"\n",
|
779 |
+
"torch.manual_seed(42)\n",
|
780 |
+
"_forecast_tensor = torch.load(\"stocks_data_forecast.pt\") # shape = (n_series, pred_len, n_q)\n",
|
781 |
+
"\n",
|
782 |
+
"def model_forecast(input_data):\n",
|
783 |
+
" return _forecast_tensor\n",
|
784 |
+
"\n",
|
785 |
+
"def plot_forecast_image(timeseries, quantile_predictions, timeseries_name):\n",
|
786 |
+
" fig, ax = plt.subplots(figsize=(10, 6), dpi=150)\n",
|
787 |
+
" ax.plot(timeseries, color=\"blue\")\n",
|
788 |
+
" x_pred = range(len(timeseries) - 1, len(timeseries) - 1 + len(quantile_predictions))\n",
|
789 |
+
" for i in range(quantile_predictions.shape[1]):\n",
|
790 |
+
" ax.plot(x_pred, quantile_predictions[:, i], color=f\"C{i}\")\n",
|
791 |
+
" ax.set_title(f\"Timeseries: {timeseries_name}\")\n",
|
792 |
+
" labels = [f\"Quantile {i}\" for i in range(quantile_predictions.shape[1])]\n",
|
793 |
+
" ax.legend(labels, loc=\"center left\", bbox_to_anchor=(1, 0.5))\n",
|
794 |
+
" plt.tight_layout(rect=[0, 0, 0.85, 1])\n",
|
795 |
+
" buf = io.BytesIO()\n",
|
796 |
+
" fig.savefig(buf, format=\"png\", bbox_inches=\"tight\")\n",
|
797 |
+
" plt.close(fig)\n",
|
798 |
+
" buf.seek(0)\n",
|
799 |
+
" img = Image.open(buf).convert(\"RGB\")\n",
|
800 |
+
" return np.array(img)\n",
|
801 |
+
"\n",
|
802 |
+
"def load_table(file_path):\n",
|
803 |
+
" ext = file_path.split(\".\")[-1].lower()\n",
|
804 |
+
" if ext == \"csv\":\n",
|
805 |
+
" return pd.read_csv(file_path)\n",
|
806 |
+
" elif ext in (\"xls\", \"xlsx\"):\n",
|
807 |
+
" return pd.read_excel(file_path)\n",
|
808 |
+
" elif ext == \"parquet\":\n",
|
809 |
+
" return pd.read_parquet(file_path)\n",
|
810 |
+
" else:\n",
|
811 |
+
" raise ValueError(\"Unsupported format. Use CSV, XLS, XLSX, or PARQUET.\")\n",
|
812 |
+
"\n",
|
813 |
+
"def extract_names_and_update(file, preset_filename):\n",
|
814 |
+
" try:\n",
|
815 |
+
" if file is not None:\n",
|
816 |
+
" df = load_table(file.name)\n",
|
817 |
+
" else:\n",
|
818 |
+
" if not preset_filename:\n",
|
819 |
+
" return gr.update(choices=[], value=[]), []\n",
|
820 |
+
" df = load_table(preset_filename)\n",
|
821 |
+
"\n",
|
822 |
+
" if df.shape[1] > 0 and df.iloc[:, 0].dtype == object and not df.iloc[:, 0].str.isnumeric().all():\n",
|
823 |
+
" names = df.iloc[:, 0].tolist()\n",
|
824 |
+
" else:\n",
|
825 |
+
" names = [f\"Series {i}\" for i in range(len(df))]\n",
|
826 |
+
" return gr.update(choices=names, value=names), names\n",
|
827 |
+
" except Exception:\n",
|
828 |
+
" return gr.update(choices=[], value=[]), []\n",
|
829 |
+
"\n",
|
830 |
+
"def filter_names(search_term, all_names):\n",
|
831 |
+
" if not all_names:\n",
|
832 |
+
" return gr.update(choices=[], value=[])\n",
|
833 |
+
" if not search_term:\n",
|
834 |
+
" return gr.update(choices=all_names, value=all_names)\n",
|
835 |
+
" lower = search_term.lower()\n",
|
836 |
+
" filtered = [n for n in all_names if lower in str(n).lower()]\n",
|
837 |
+
" return gr.update(choices=filtered, value=filtered)\n",
|
838 |
+
"\n",
|
839 |
+
"def check_all(names_list):\n",
|
840 |
+
" return gr.update(value=names_list)\n",
|
841 |
+
"\n",
|
842 |
+
"def uncheck_all(_):\n",
|
843 |
+
" return gr.update(value=[])\n",
|
844 |
+
"\n",
|
845 |
+
"def display_filtered_forecast(file, preset_filename, selected_names):\n",
|
846 |
+
" \"\"\"\n",
|
847 |
+
" Load the table, filter by selected_names, run forecast (correctly sliced),\n",
|
848 |
+
" and return a gallery + error string.\n",
|
849 |
+
" \"\"\"\n",
|
850 |
+
" try:\n",
|
851 |
+
" if file is not None:\n",
|
852 |
+
" df = load_table(file.name)\n",
|
853 |
+
" else:\n",
|
854 |
+
" if not preset_filename:\n",
|
855 |
+
" return [], \"No file selected.\"\n",
|
856 |
+
" df = load_table(preset_filename)\n",
|
857 |
+
"\n",
|
858 |
+
" # Extract all_names and numeric data\n",
|
859 |
+
" if df.shape[1] > 0 and df.iloc[:, 0].dtype == object \\\n",
|
860 |
+
" and not df.iloc[:, 0].str.isnumeric().all():\n",
|
861 |
+
" all_names = df.iloc[:, 0].tolist()\n",
|
862 |
+
" data_only = df.iloc[:, 1:].astype(float)\n",
|
863 |
+
" else:\n",
|
864 |
+
" all_names = [f\"Series {i}\" for i in range(len(df))]\n",
|
865 |
+
" data_only = df.astype(float)\n",
|
866 |
+
"\n",
|
867 |
+
" # Build mask and filtered subset\n",
|
868 |
+
" mask = [name in selected_names for name in all_names]\n",
|
869 |
+
" if not any(mask):\n",
|
870 |
+
" return [], \"No timeseries chosen to plot.\"\n",
|
871 |
+
"\n",
|
872 |
+
" filtered_data = data_only.iloc[mask, :].values\n",
|
873 |
+
" filtered_names = [all_names[i] for i, m in enumerate(mask) if m]\n",
|
874 |
+
"\n",
|
875 |
+
" # ------------------------\n",
|
876 |
+
" # HERE is the only change:\n",
|
877 |
+
" # Instead of calling model_forecast(inp), slice the full tensor by mask:\n",
|
878 |
+
" # ------------------------\n",
|
879 |
+
" out = _forecast_tensor[mask] # shape = (n_chosen, pred_len, n_q)\n",
|
880 |
+
" inp = torch.tensor(filtered_data)\n",
|
881 |
+
"\n",
|
882 |
+
" # Plot each chosen series against its properly‐aligned forecast\n",
|
883 |
+
" gallery_images = []\n",
|
884 |
+
" for i in range(inp.shape[0]):\n",
|
885 |
+
" gallery_images.append(\n",
|
886 |
+
" plot_forecast_image(inp[i], out[i], filtered_names[i])\n",
|
887 |
+
" )\n",
|
888 |
+
"\n",
|
889 |
+
" return gallery_images, \"\"\n",
|
890 |
+
" except Exception as e:\n",
|
891 |
+
" return [], f\"Error: {e}. Please upload a valid CSV, XLS, XLSX, or PARQUET file.\"\n",
|
892 |
+
"\n",
|
893 |
+
"\n",
|
894 |
+
"# ----------------------------\n",
|
895 |
+
"# Gradio layout: two columns\n",
|
896 |
+
"# ----------------------------\n",
|
897 |
+
"\n",
|
898 |
+
"with gr.Blocks() as demo:\n",
|
899 |
+
" gr.Markdown(\"# 📈 Stock Forecast Viewer 📊\")\n",
|
900 |
+
" gr.Markdown(\"Upload data or choose a preset, filter by name, then click Plot.\")\n",
|
901 |
+
"\n",
|
902 |
+
" with gr.Row():\n",
|
903 |
+
" # Left column: controls\n",
|
904 |
+
" with gr.Column():\n",
|
905 |
+
" gr.Markdown(\"## Data Selection\")\n",
|
906 |
+
" file_input = gr.File(\n",
|
907 |
+
" label=\"Upload CSV / XLSX / PARQUET\",\n",
|
908 |
+
" file_types=[\".csv\", \".xls\", \".xlsx\", \".parquet\"]\n",
|
909 |
+
" )\n",
|
910 |
+
" preset_dropdown = gr.Dropdown(\n",
|
911 |
+
" label=\"Or choose a preset:\",\n",
|
912 |
+
" choices=[\"stocks_data_noindex.csv\", \"stocks_data.csv\"],\n",
|
913 |
+
" value=\"stocks_data_noindex.csv\"\n",
|
914 |
+
" )\n",
|
915 |
+
"\n",
|
916 |
+
" gr.Markdown(\"## Search / Filter\")\n",
|
917 |
+
" search_box = gr.Textbox(\n",
|
918 |
+
" placeholder=\"Type to filter (e.g. 'AMZN')\"\n",
|
919 |
+
" )\n",
|
920 |
+
" filter_checkbox = gr.CheckboxGroup(\n",
|
921 |
+
" choices=[], value=[], label=\"Select which timeseries to show\"\n",
|
922 |
+
" )\n",
|
923 |
+
"\n",
|
924 |
+
" with gr.Row():\n",
|
925 |
+
" check_all_btn = gr.Button(\"✅ Check All\")\n",
|
926 |
+
" uncheck_all_btn = gr.Button(\"❎ Uncheck All\")\n",
|
927 |
+
"\n",
|
928 |
+
" plot_button = gr.Button(\"▶️ Plot Forecasts\")\n",
|
929 |
+
" errbox = gr.Textbox(interactive=False, placeholder=\"\")\n",
|
930 |
+
"\n",
|
931 |
+
" # Right column: gallery\n",
|
932 |
+
" with gr.Column():\n",
|
933 |
+
" gr.Markdown(\"## Forecast Gallery\")\n",
|
934 |
+
" gallery = gr.Gallery()\n",
|
935 |
+
"\n",
|
936 |
+
" names_state = gr.State([])\n",
|
937 |
+
"\n",
|
938 |
+
" # When file or preset changes, update names\n",
|
939 |
+
" file_input.change(\n",
|
940 |
+
" fn=extract_names_and_update,\n",
|
941 |
+
" inputs=[file_input, preset_dropdown],\n",
|
942 |
+
" outputs=[filter_checkbox, names_state]\n",
|
943 |
+
" )\n",
|
944 |
+
" preset_dropdown.change(\n",
|
945 |
+
" fn=extract_names_and_update,\n",
|
946 |
+
" inputs=[file_input, preset_dropdown],\n",
|
947 |
+
" outputs=[filter_checkbox, names_state]\n",
|
948 |
+
" )\n",
|
949 |
+
"\n",
|
950 |
+
" # When search term changes, filter names\n",
|
951 |
+
" search_box.change(\n",
|
952 |
+
" fn=filter_names,\n",
|
953 |
+
" inputs=[search_box, names_state],\n",
|
954 |
+
" outputs=filter_checkbox\n",
|
955 |
+
" )\n",
|
956 |
+
"\n",
|
957 |
+
" # Check All / Uncheck All\n",
|
958 |
+
" check_all_btn.click(fn=check_all, inputs=names_state, outputs=filter_checkbox)\n",
|
959 |
+
" uncheck_all_btn.click(fn=uncheck_all, inputs=names_state, outputs=filter_checkbox)\n",
|
960 |
+
"\n",
|
961 |
+
" # Plot button\n",
|
962 |
+
" plot_button.click(\n",
|
963 |
+
" fn=display_filtered_forecast,\n",
|
964 |
+
" inputs=[file_input, preset_dropdown, filter_checkbox],\n",
|
965 |
+
" outputs=[gallery, errbox]\n",
|
966 |
+
" )\n",
|
967 |
+
"\n",
|
968 |
+
"demo.launch()"
|
969 |
+
]
|
970 |
+
},
|
971 |
+
{
|
972 |
+
"cell_type": "markdown",
|
973 |
+
"metadata": {},
|
974 |
+
"source": [
|
975 |
+
"# Checked, almost ideal"
|
976 |
+
]
|
977 |
+
},
|
978 |
+
{
|
979 |
+
"cell_type": "code",
|
980 |
+
"execution_count": null,
|
981 |
+
"metadata": {},
|
982 |
+
"outputs": [],
|
983 |
+
"source": [
|
984 |
+
"import io\n",
|
985 |
+
"import pandas as pd\n",
|
986 |
+
"import torch\n",
|
987 |
+
"import matplotlib.pyplot as plt\n",
|
988 |
+
"from PIL import Image\n",
|
989 |
+
"import numpy as np\n",
|
990 |
+
"import gradio as gr\n",
|
991 |
+
"\n",
|
992 |
+
"# ----------------------------\n",
|
993 |
+
"# Helper functions (logic unchanged)\n",
|
994 |
+
"# ----------------------------\n",
|
995 |
+
"\n",
|
996 |
+
"torch.manual_seed(42)\n",
|
997 |
+
"_forecast_tensor = torch.load(\"stocks_data_forecast.pt\") # shape = (n_series, pred_len, n_q)\n",
|
998 |
+
"\n",
|
999 |
+
"def model_forecast(input_data):\n",
|
1000 |
+
" return _forecast_tensor\n",
|
1001 |
+
"\n",
|
1002 |
+
"def plot_forecast_image(timeseries, quantile_predictions, timeseries_name):\n",
|
1003 |
+
" fig, ax = plt.subplots(figsize=(10, 6), dpi=150)\n",
|
1004 |
+
" ax.plot(timeseries, color=\"blue\")\n",
|
1005 |
+
" x_pred = range(len(timeseries) - 1, len(timeseries) - 1 + len(quantile_predictions))\n",
|
1006 |
+
" for i in range(quantile_predictions.shape[1]):\n",
|
1007 |
+
" ax.plot(x_pred, quantile_predictions[:, i], color=f\"C{i}\")\n",
|
1008 |
+
" ax.set_title(f\"Timeseries: {timeseries_name}\")\n",
|
1009 |
+
" labels = [f\"Quantile {i}\" for i in range(quantile_predictions.shape[1])]\n",
|
1010 |
+
" ax.legend(labels, loc=\"center left\", bbox_to_anchor=(1, 0.5))\n",
|
1011 |
+
" plt.tight_layout(rect=[0, 0, 0.85, 1])\n",
|
1012 |
+
" buf = io.BytesIO()\n",
|
1013 |
+
" fig.savefig(buf, format=\"png\", bbox_inches=\"tight\")\n",
|
1014 |
+
" plt.close(fig)\n",
|
1015 |
+
" buf.seek(0)\n",
|
1016 |
+
" img = Image.open(buf).convert(\"RGB\")\n",
|
1017 |
+
" return np.array(img)\n",
|
1018 |
+
"\n",
|
1019 |
+
"def load_table(file_path):\n",
|
1020 |
+
" ext = file_path.split(\".\")[-1].lower()\n",
|
1021 |
+
" if ext == \"csv\":\n",
|
1022 |
+
" return pd.read_csv(file_path)\n",
|
1023 |
+
" elif ext in (\"xls\", \"xlsx\"):\n",
|
1024 |
+
" return pd.read_excel(file_path)\n",
|
1025 |
+
" elif ext == \"parquet\":\n",
|
1026 |
+
" return pd.read_parquet(file_path)\n",
|
1027 |
+
" else:\n",
|
1028 |
+
" raise ValueError(\"Unsupported format. Use CSV, XLS, XLSX, or PARQUET.\")\n",
|
1029 |
+
"\n",
|
1030 |
+
"def extract_names_and_update(file, preset_filename):\n",
|
1031 |
+
" try:\n",
|
1032 |
+
" if file is not None:\n",
|
1033 |
+
" df = load_table(file.name)\n",
|
1034 |
+
" else:\n",
|
1035 |
+
" if not preset_filename:\n",
|
1036 |
+
" return gr.update(choices=[], value=[]), []\n",
|
1037 |
+
" df = load_table(preset_filename)\n",
|
1038 |
+
"\n",
|
1039 |
+
" if df.shape[1] > 0 and df.iloc[:, 0].dtype == object and not df.iloc[:, 0].str.isnumeric().all():\n",
|
1040 |
+
" names = df.iloc[:, 0].tolist()\n",
|
1041 |
+
" else:\n",
|
1042 |
+
" names = [f\"Series {i}\" for i in range(len(df))]\n",
|
1043 |
+
" return gr.update(choices=names, value=names), names\n",
|
1044 |
+
" except Exception:\n",
|
1045 |
+
" return gr.update(choices=[], value=[]), []\n",
|
1046 |
+
"\n",
|
1047 |
+
"def filter_names(search_term, all_names):\n",
|
1048 |
+
" if not all_names:\n",
|
1049 |
+
" return gr.update(choices=[], value=[])\n",
|
1050 |
+
" if not search_term:\n",
|
1051 |
+
" return gr.update(choices=all_names, value=all_names)\n",
|
1052 |
+
" lower = search_term.lower()\n",
|
1053 |
+
" filtered = [n for n in all_names if lower in str(n).lower()]\n",
|
1054 |
+
" return gr.update(choices=filtered, value=filtered)\n",
|
1055 |
+
"\n",
|
1056 |
+
"def check_all(names_list):\n",
|
1057 |
+
" return gr.update(value=names_list)\n",
|
1058 |
+
"\n",
|
1059 |
+
"def uncheck_all(_):\n",
|
1060 |
+
" return gr.update(value=[])\n",
|
1061 |
+
"\n",
|
1062 |
+
"def display_filtered_forecast(file, preset_filename, selected_names):\n",
|
1063 |
+
" try:\n",
|
1064 |
+
" if file is not None:\n",
|
1065 |
+
" df = load_table(file.name)\n",
|
1066 |
+
" else:\n",
|
1067 |
+
" if not preset_filename:\n",
|
1068 |
+
" return [], \"No file selected.\"\n",
|
1069 |
+
" df = load_table(preset_filename)\n",
|
1070 |
+
"\n",
|
1071 |
+
" if df.shape[1] > 0 and df.iloc[:, 0].dtype == object and not df.iloc[:, 0].str.isnumeric().all():\n",
|
1072 |
+
" all_names = df.iloc[:, 0].tolist()\n",
|
1073 |
+
" data_only = df.iloc[:, 1:].astype(float)\n",
|
1074 |
+
" else:\n",
|
1075 |
+
" all_names = [f\"Series {i}\" for i in range(len(df))]\n",
|
1076 |
+
" data_only = df.astype(float)\n",
|
1077 |
+
"\n",
|
1078 |
+
" mask = [name in selected_names for name in all_names]\n",
|
1079 |
+
" if not any(mask):\n",
|
1080 |
+
" return [], \"No timeseries chosen to plot.\"\n",
|
1081 |
+
"\n",
|
1082 |
+
" filtered_data = data_only.iloc[mask, :].values\n",
|
1083 |
+
" filtered_names = [all_names[i] for i, m in enumerate(mask) if m]\n",
|
1084 |
+
" out = _forecast_tensor[mask] # slice forecasts to match filtered rows\n",
|
1085 |
+
" inp = torch.tensor(filtered_data)\n",
|
1086 |
+
"\n",
|
1087 |
+
" gallery_images = []\n",
|
1088 |
+
" for i in range(inp.shape[0]):\n",
|
1089 |
+
" gallery_images.append(plot_forecast_image(inp[i], out[i], filtered_names[i]))\n",
|
1090 |
+
"\n",
|
1091 |
+
" return gallery_images, \"\"\n",
|
1092 |
+
" except Exception as e:\n",
|
1093 |
+
" return [], f\"Error: {e}. Use CSV, XLS, XLSX, or PARQUET.\"\n",
|
1094 |
+
"\n",
|
1095 |
+
"# ----------------------------\n",
|
1096 |
+
"# Gradio layout: two columns + instructions\n",
|
1097 |
+
"# ----------------------------\n",
|
1098 |
+
"\n",
|
1099 |
+
"with gr.Blocks() as demo:\n",
|
1100 |
+
" gr.Markdown(\"# 📈 Stock Forecast Viewer 📊\")\n",
|
1101 |
+
" gr.Markdown(\"Upload data or choose a preset, filter by name, then click Plot.\")\n",
|
1102 |
+
"\n",
|
1103 |
+
" with gr.Row():\n",
|
1104 |
+
" # Left column: controls\n",
|
1105 |
+
" with gr.Column():\n",
|
1106 |
+
" gr.Markdown(\"## Data Selection\")\n",
|
1107 |
+
" file_input = gr.File(\n",
|
1108 |
+
" label=\"Upload CSV / XLSX / PARQUET\",\n",
|
1109 |
+
" file_types=[\".csv\", \".xls\", \".xlsx\", \".parquet\"]\n",
|
1110 |
+
" )\n",
|
1111 |
+
" preset_dropdown = gr.Dropdown(\n",
|
1112 |
+
" label=\"Or choose a preset:\",\n",
|
1113 |
+
" choices=[\"stocks_data_noindex.csv\", \"stocks_data.csv\"],\n",
|
1114 |
+
" value=\"stocks_data_noindex.csv\"\n",
|
1115 |
+
" )\n",
|
1116 |
+
"\n",
|
1117 |
+
" gr.Markdown(\"## Search / Filter\")\n",
|
1118 |
+
" search_box = gr.Textbox(placeholder=\"Type to filter (e.g. 'AMZN')\")\n",
|
1119 |
+
" filter_checkbox = gr.CheckboxGroup(\n",
|
1120 |
+
" choices=[], value=[], label=\"Select which timeseries to show\"\n",
|
1121 |
+
" )\n",
|
1122 |
+
"\n",
|
1123 |
+
" with gr.Row():\n",
|
1124 |
+
" check_all_btn = gr.Button(\"✅ Check All\")\n",
|
1125 |
+
" uncheck_all_btn = gr.Button(\"❎ Uncheck All\")\n",
|
1126 |
+
"\n",
|
1127 |
+
" plot_button = gr.Button(\"▶️ Plot Forecasts\")\n",
|
1128 |
+
" errbox = gr.Textbox(interactive=False, placeholder=\"\")\n",
|
1129 |
+
"\n",
|
1130 |
+
" # Right column: gallery + instructions\n",
|
1131 |
+
" with gr.Column():\n",
|
1132 |
+
" gr.Markdown(\"## Forecast Gallery\")\n",
|
1133 |
+
" gallery = gr.Gallery()\n",
|
1134 |
+
"\n",
|
1135 |
+
" # Instruction text below gallery\n",
|
1136 |
+
" gr.Markdown(\n",
|
1137 |
+
" \"\"\"\n",
|
1138 |
+
" **How to format your data:**\n",
|
1139 |
+
" - Your file must be a table (CSV, XLS, XLSX, or Parquet).\n",
|
1140 |
+
" - If you haven't prepared the data, the preset file will be used.\n",
|
1141 |
+
" - **One row per timeseries.** Each row is treated as a separate series.\n",
|
1142 |
+
" - If you want to **name** each series, put the name as the first value in **every** row:\n",
|
1143 |
+
" - Example (CSV): \n",
|
1144 |
+
" `AAPL, 120.5, 121.0, 119.8, ...` \n",
|
1145 |
+
" `AMZN, 3300.0, 3310.5, 3295.2, ...` \n",
|
1146 |
+
" - In that case, the first column is not numeric, so it will be used as the series name.\n",
|
1147 |
+
" - If you do **not** want named series, simply leave out the first column entirely and have all values numeric:\n",
|
1148 |
+
" - Example: \n",
|
1149 |
+
" `120.5, 121.0, 119.8, ...` \n",
|
1150 |
+
" `3300.0, 3310.5, 3295.2, ...` \n",
|
1151 |
+
" - Then every row will be auto-named “Series 0, Series 1, …” in order.\n",
|
1152 |
+
" - **Consistency rule:** Either all rows have a non-numeric first entry for the name, or none do. Do not mix.\n",
|
1153 |
+
" - The rest of the columns (after the optional name) must be numeric data points for that series.\n",
|
1154 |
+
" - You can filter by typing in the search box. Then check or uncheck individual names before plotting.\n",
|
1155 |
+
" - Use “Check All” / “Uncheck All” to quickly select or deselect every series.\n",
|
1156 |
+
" - Finally, click **Plot Forecasts** to view the quantile forecast for each selected series.\n",
|
1157 |
+
" \"\"\"\n",
|
1158 |
+
" )\n",
|
1159 |
+
"\n",
|
1160 |
+
" names_state = gr.State([])\n",
|
1161 |
+
"\n",
|
1162 |
+
" # When file or preset changes, update names\n",
|
1163 |
+
" file_input.change(\n",
|
1164 |
+
" fn=extract_names_and_update,\n",
|
1165 |
+
" inputs=[file_input, preset_dropdown],\n",
|
1166 |
+
" outputs=[filter_checkbox, names_state]\n",
|
1167 |
+
" )\n",
|
1168 |
+
" preset_dropdown.change(\n",
|
1169 |
+
" fn=extract_names_and_update,\n",
|
1170 |
+
" inputs=[file_input, preset_dropdown],\n",
|
1171 |
+
" outputs=[filter_checkbox, names_state]\n",
|
1172 |
+
" )\n",
|
1173 |
+
"\n",
|
1174 |
+
" # When search term changes, filter names\n",
|
1175 |
+
" search_box.change(\n",
|
1176 |
+
" fn=filter_names,\n",
|
1177 |
+
" inputs=[search_box, names_state],\n",
|
1178 |
+
" outputs=filter_checkbox\n",
|
1179 |
+
" )\n",
|
1180 |
+
"\n",
|
1181 |
+
" # Check All / Uncheck All\n",
|
1182 |
+
" check_all_btn.click(fn=check_all, inputs=names_state, outputs=filter_checkbox)\n",
|
1183 |
+
" uncheck_all_btn.click(fn=uncheck_all, inputs=names_state, outputs=filter_checkbox)\n",
|
1184 |
+
"\n",
|
1185 |
+
" # Plot button\n",
|
1186 |
+
" plot_button.click(\n",
|
1187 |
+
" fn=display_filtered_forecast,\n",
|
1188 |
+
" inputs=[file_input, preset_dropdown, filter_checkbox],\n",
|
1189 |
+
" outputs=[gallery, errbox]\n",
|
1190 |
+
" )\n",
|
1191 |
+
"\n",
|
1192 |
+
"demo.launch()"
|
1193 |
+
]
|
1194 |
+
},
|
1195 |
+
{
|
1196 |
+
"cell_type": "markdown",
|
1197 |
+
"metadata": {},
|
1198 |
+
"source": [
|
1199 |
+
"# The default choice isn't processed when the default choice is chosen"
|
1200 |
+
]
|
1201 |
+
},
|
1202 |
+
{
|
1203 |
+
"cell_type": "markdown",
|
1204 |
+
"metadata": {},
|
1205 |
+
"source": []
|
1206 |
+
},
|
1207 |
+
{
|
1208 |
+
"cell_type": "code",
|
1209 |
+
"execution_count": null,
|
1210 |
+
"metadata": {},
|
1211 |
+
"outputs": [],
|
1212 |
+
"source": [
|
1213 |
+
"import io\n",
|
1214 |
+
"import pandas as pd\n",
|
1215 |
+
"import torch\n",
|
1216 |
+
"import matplotlib.pyplot as plt\n",
|
1217 |
+
"from PIL import Image\n",
|
1218 |
+
"import numpy as np\n",
|
1219 |
+
"import gradio as gr\n",
|
1220 |
+
"\n",
|
1221 |
+
"# ----------------------------\n",
|
1222 |
+
"# Helper functions (logic unchanged)\n",
|
1223 |
+
"# ----------------------------\n",
|
1224 |
+
"\n",
|
1225 |
+
"torch.manual_seed(42)\n",
|
1226 |
+
"_forecast_tensor = torch.load(\"stocks_data_forecast.pt\") # shape = (n_series, pred_len, n_q)\n",
|
1227 |
+
"\n",
|
1228 |
+
"def model_forecast(input_data):\n",
|
1229 |
+
" return _forecast_tensor\n",
|
1230 |
+
"\n",
|
1231 |
+
"def plot_forecast_image(timeseries, quantile_predictions, timeseries_name):\n",
|
1232 |
+
" fig, ax = plt.subplots(figsize=(10, 6), dpi=150)\n",
|
1233 |
+
" ax.plot(timeseries, color=\"blue\")\n",
|
1234 |
+
" x_pred = range(len(timeseries) - 1, len(timeseries) - 1 + len(quantile_predictions))\n",
|
1235 |
+
" for i in range(quantile_predictions.shape[1]):\n",
|
1236 |
+
" ax.plot(x_pred, quantile_predictions[:, i], color=f\"C{i}\")\n",
|
1237 |
+
" ax.set_title(f\"Timeseries: {timeseries_name}\")\n",
|
1238 |
+
" labels = [f\"Quantile {i}\" for i in range(quantile_predictions.shape[1])]\n",
|
1239 |
+
" ax.legend(labels, loc=\"center left\", bbox_to_anchor=(1, 0.5))\n",
|
1240 |
+
" plt.tight_layout(rect=[0, 0, 0.85, 1])\n",
|
1241 |
+
" buf = io.BytesIO()\n",
|
1242 |
+
" fig.savefig(buf, format=\"png\", bbox_inches=\"tight\")\n",
|
1243 |
+
" plt.close(fig)\n",
|
1244 |
+
" buf.seek(0)\n",
|
1245 |
+
" img = Image.open(buf).convert(\"RGB\")\n",
|
1246 |
+
" return np.array(img)\n",
|
1247 |
+
"\n",
|
1248 |
+
"def load_table(file_path):\n",
|
1249 |
+
" ext = file_path.split(\".\")[-1].lower()\n",
|
1250 |
+
" if ext == \"csv\":\n",
|
1251 |
+
" return pd.read_csv(file_path)\n",
|
1252 |
+
" elif ext in (\"xls\", \"xlsx\"):\n",
|
1253 |
+
" return pd.read_excel(file_path)\n",
|
1254 |
+
" elif ext == \"parquet\":\n",
|
1255 |
+
" return pd.read_parquet(file_path)\n",
|
1256 |
+
" else:\n",
|
1257 |
+
" raise ValueError(\"Unsupported format. Use CSV, XLS, XLSX, or PARQUET.\")\n",
|
1258 |
+
"\n",
|
1259 |
+
"def extract_names_and_update(file, preset_filename):\n",
|
1260 |
+
" try:\n",
|
1261 |
+
" if file is not None:\n",
|
1262 |
+
" df = load_table(file.name)\n",
|
1263 |
+
" else:\n",
|
1264 |
+
" if not preset_filename:\n",
|
1265 |
+
" return gr.update(choices=[], value=[]), []\n",
|
1266 |
+
" df = load_table(preset_filename)\n",
|
1267 |
+
"\n",
|
1268 |
+
" if df.shape[1] > 0 and df.iloc[:, 0].dtype == object and not df.iloc[:, 0].str.isnumeric().all():\n",
|
1269 |
+
" names = df.iloc[:, 0].tolist()\n",
|
1270 |
+
" else:\n",
|
1271 |
+
" names = [f\"Series {i}\" for i in range(len(df))]\n",
|
1272 |
+
" return gr.update(choices=names, value=names), names\n",
|
1273 |
+
" except Exception:\n",
|
1274 |
+
" return gr.update(choices=[], value=[]), []\n",
|
1275 |
+
"\n",
|
1276 |
+
"def filter_names(search_term, all_names):\n",
|
1277 |
+
" if not all_names:\n",
|
1278 |
+
" return gr.update(choices=[], value=[])\n",
|
1279 |
+
" if not search_term:\n",
|
1280 |
+
" return gr.update(choices=all_names, value=all_names)\n",
|
1281 |
+
" lower = search_term.lower()\n",
|
1282 |
+
" filtered = [n for n in all_names if lower in str(n).lower()]\n",
|
1283 |
+
" return gr.update(choices=filtered, value=filtered)\n",
|
1284 |
+
"\n",
|
1285 |
+
"def check_all(names_list):\n",
|
1286 |
+
" return gr.update(value=names_list)\n",
|
1287 |
+
"\n",
|
1288 |
+
"def uncheck_all(_):\n",
|
1289 |
+
" return gr.update(value=[])\n",
|
1290 |
+
"\n",
|
1291 |
+
"def display_filtered_forecast(file, preset_filename, selected_names):\n",
|
1292 |
+
" try:\n",
|
1293 |
+
" if file is not None:\n",
|
1294 |
+
" df = load_table(file.name)\n",
|
1295 |
+
" else:\n",
|
1296 |
+
" if not preset_filename:\n",
|
1297 |
+
" return [], \"No file selected.\"\n",
|
1298 |
+
" df = load_table(preset_filename)\n",
|
1299 |
+
"\n",
|
1300 |
+
" if df.shape[1] > 0 and df.iloc[:, 0].dtype == object and not df.iloc[:, 0].str.isnumeric().all():\n",
|
1301 |
+
" all_names = df.iloc[:, 0].tolist()\n",
|
1302 |
+
" data_only = df.iloc[:, 1:].astype(float)\n",
|
1303 |
+
" else:\n",
|
1304 |
+
" all_names = [f\"Series {i}\" for i in range(len(df))]\n",
|
1305 |
+
" data_only = df.astype(float)\n",
|
1306 |
+
"\n",
|
1307 |
+
" mask = [name in selected_names for name in all_names]\n",
|
1308 |
+
" if not any(mask):\n",
|
1309 |
+
" return [], \"No timeseries chosen to plot.\"\n",
|
1310 |
+
"\n",
|
1311 |
+
" filtered_data = data_only.iloc[mask, :].values\n",
|
1312 |
+
" filtered_names = [all_names[i] for i, m in enumerate(mask) if m]\n",
|
1313 |
+
" out = _forecast_tensor[mask] # slice forecasts to match filtered rows\n",
|
1314 |
+
" inp = torch.tensor(filtered_data)\n",
|
1315 |
+
"\n",
|
1316 |
+
" gallery_images = []\n",
|
1317 |
+
" for i in range(inp.shape[0]):\n",
|
1318 |
+
" gallery_images.append(plot_forecast_image(inp[i], out[i], filtered_names[i]))\n",
|
1319 |
+
"\n",
|
1320 |
+
" return gallery_images, \"\"\n",
|
1321 |
+
" except Exception as e:\n",
|
1322 |
+
" return [], f\"Error: {e}. Use CSV, XLS, XLSX, or PARQUET.\"\n",
|
1323 |
+
"\n",
|
1324 |
+
"\n",
|
1325 |
+
"# ----------------------------\n",
|
1326 |
+
"# Gradio layout: two columns + instructions\n",
|
1327 |
+
"# ----------------------------\n",
|
1328 |
+
"\n",
|
1329 |
+
"with gr.Blocks() as demo:\n",
|
1330 |
+
" gr.Markdown(\"# 📈 Stock Forecast Viewer 📊\")\n",
|
1331 |
+
" gr.Markdown(\"Upload data or choose a preset, filter by name, then click Plot.\")\n",
|
1332 |
+
"\n",
|
1333 |
+
" with gr.Row():\n",
|
1334 |
+
" # Left column: controls\n",
|
1335 |
+
" with gr.Column():\n",
|
1336 |
+
" gr.Markdown(\"## Data Selection\")\n",
|
1337 |
+
" gr.Markdown(\"*If you haven't prepared the data, the preset file will be used.*\")\n",
|
1338 |
+
" file_input = gr.File(\n",
|
1339 |
+
" label=\"Upload CSV / XLSX / PARQUET\",\n",
|
1340 |
+
" file_types=[\".csv\", \".xls\", \".xlsx\", \".parquet\"]\n",
|
1341 |
+
" )\n",
|
1342 |
+
" preset_dropdown = gr.Dropdown(\n",
|
1343 |
+
" label=\"Or choose a preset:\",\n",
|
1344 |
+
" choices=[\"stocks_data_noindex.csv\", \"stocks_data.csv\"],\n",
|
1345 |
+
" value=\"stocks_data_noindex.csv\"\n",
|
1346 |
+
" )\n",
|
1347 |
+
"\n",
|
1348 |
+
" gr.Markdown(\"## Search / Filter\")\n",
|
1349 |
+
" search_box = gr.Textbox(placeholder=\"Type to filter (e.g. 'AMZN')\")\n",
|
1350 |
+
" filter_checkbox = gr.CheckboxGroup(\n",
|
1351 |
+
" choices=[], value=[], label=\"Select which timeseries to show\"\n",
|
1352 |
+
" )\n",
|
1353 |
+
"\n",
|
1354 |
+
" with gr.Row():\n",
|
1355 |
+
" check_all_btn = gr.Button(\"✅ Check All\")\n",
|
1356 |
+
" uncheck_all_btn = gr.Button(\"❎ Uncheck All\")\n",
|
1357 |
+
"\n",
|
1358 |
+
" plot_button = gr.Button(\"▶️ Plot Forecasts\")\n",
|
1359 |
+
" errbox = gr.Textbox(label=\"Error Message\", interactive=False)\n",
|
1360 |
+
"\n",
|
1361 |
+
" # Right column: gallery + instructions\n",
|
1362 |
+
" with gr.Column():\n",
|
1363 |
+
" gr.Markdown(\"## Forecast Gallery\")\n",
|
1364 |
+
" gallery = gr.Gallery()\n",
|
1365 |
+
"\n",
|
1366 |
+
" # Instruction text below gallery\n",
|
1367 |
+
" gr.Markdown(\n",
|
1368 |
+
" \"\"\"\n",
|
1369 |
+
" **How to format your data:**\n",
|
1370 |
+
" - Your file must be a table (CSV, XLS, XLSX, or Parquet).\n",
|
1371 |
+
" - **One row per timeseries.** Each row is treated as a separate series.\n",
|
1372 |
+
" - If you want to **name** each series, put the name as the first value in **every** row:\n",
|
1373 |
+
" - Example (CSV): \n",
|
1374 |
+
" `AAPL, 120.5, 121.0, 119.8, ...` \n",
|
1375 |
+
" `AMZN, 3300.0, 3310.5, 3295.2, ...` \n",
|
1376 |
+
" - In that case, the first column is not numeric, so it will be used as the series name.\n",
|
1377 |
+
" - If you do **not** want named series, simply leave out the first column entirely and have all values numeric:\n",
|
1378 |
+
" - Example: \n",
|
1379 |
+
" `120.5, 121.0, 119.8, ...` \n",
|
1380 |
+
" `3300.0, 3310.5, 3295.2, ...` \n",
|
1381 |
+
" - Then every row will be auto-named “Series 0, Series 1, …” in order.\n",
|
1382 |
+
" - **Consistency rule:** Either all rows have a non-numeric first entry for the name, or none do. Do not mix.\n",
|
1383 |
+
" - The rest of the columns (after the optional name) must be numeric data points for that series.\n",
|
1384 |
+
" - You can filter by typing in the search box. Then check or uncheck individual names before plotting.\n",
|
1385 |
+
" - Use “Check All” / “Uncheck All” to quickly select or deselect every series.\n",
|
1386 |
+
" - Finally, click **Plot Forecasts** to view the quantile forecast for each selected series.\n",
|
1387 |
+
" \"\"\"\n",
|
1388 |
+
" )\n",
|
1389 |
+
"\n",
|
1390 |
+
" names_state = gr.State([])\n",
|
1391 |
+
"\n",
|
1392 |
+
" # When file or preset changes, update names\n",
|
1393 |
+
" file_input.change(\n",
|
1394 |
+
" fn=extract_names_and_update,\n",
|
1395 |
+
" inputs=[file_input, preset_dropdown],\n",
|
1396 |
+
" outputs=[filter_checkbox, names_state]\n",
|
1397 |
+
" )\n",
|
1398 |
+
" preset_dropdown.change(\n",
|
1399 |
+
" fn=extract_names_and_update,\n",
|
1400 |
+
" inputs=[file_input, preset_dropdown],\n",
|
1401 |
+
" outputs=[filter_checkbox, names_state]\n",
|
1402 |
+
" )\n",
|
1403 |
+
"\n",
|
1404 |
+
" # When search term changes, filter names\n",
|
1405 |
+
" search_box.change(\n",
|
1406 |
+
" fn=filter_names,\n",
|
1407 |
+
" inputs=[search_box, names_state],\n",
|
1408 |
+
" outputs=filter_checkbox\n",
|
1409 |
+
" )\n",
|
1410 |
+
"\n",
|
1411 |
+
" # Check All / Uncheck All\n",
|
1412 |
+
" check_all_btn.click(fn=check_all, inputs=names_state, outputs=filter_checkbox)\n",
|
1413 |
+
" uncheck_all_btn.click(fn=uncheck_all, inputs=names_state, outputs=filter_checkbox)\n",
|
1414 |
+
"\n",
|
1415 |
+
" # Plot button\n",
|
1416 |
+
" plot_button.click(\n",
|
1417 |
+
" fn=display_filtered_forecast,\n",
|
1418 |
+
" inputs=[file_input, preset_dropdown, filter_checkbox],\n",
|
1419 |
+
" outputs=[gallery, errbox]\n",
|
1420 |
+
" )\n",
|
1421 |
+
"\n",
|
1422 |
+
"demo.launch()"
|
1423 |
+
]
|
1424 |
+
},
|
1425 |
+
{
|
1426 |
+
"cell_type": "markdown",
|
1427 |
+
"metadata": {},
|
1428 |
+
"source": [
|
1429 |
+
"# Default choice - None"
|
1430 |
+
]
|
1431 |
+
},
|
1432 |
+
{
|
1433 |
+
"cell_type": "code",
|
1434 |
+
"execution_count": null,
|
1435 |
+
"metadata": {},
|
1436 |
+
"outputs": [],
|
1437 |
+
"source": [
|
1438 |
+
"import io\n",
|
1439 |
+
"import pandas as pd\n",
|
1440 |
+
"import torch\n",
|
1441 |
+
"import matplotlib.pyplot as plt\n",
|
1442 |
+
"from PIL import Image\n",
|
1443 |
+
"import numpy as np\n",
|
1444 |
+
"import gradio as gr\n",
|
1445 |
+
"\n",
|
1446 |
+
"# ----------------------------\n",
|
1447 |
+
"# Helper functions (logic unchanged)\n",
|
1448 |
+
"# ----------------------------\n",
|
1449 |
+
"\n",
|
1450 |
+
"torch.manual_seed(42)\n",
|
1451 |
+
"_forecast_tensor = torch.load(\"stocks_data_forecast.pt\") # shape = (n_series, pred_len, n_q)\n",
|
1452 |
+
"\n",
|
1453 |
+
"def model_forecast(input_data):\n",
|
1454 |
+
" return _forecast_tensor\n",
|
1455 |
+
"\n",
|
1456 |
+
"def plot_forecast_image(timeseries, quantile_predictions, timeseries_name):\n",
|
1457 |
+
" fig, ax = plt.subplots(figsize=(10, 6), dpi=300)\n",
|
1458 |
+
" \n",
|
1459 |
+
" # Plot the original timeseries with thicker line and marker\n",
|
1460 |
+
" ax.plot(timeseries, color=\"blue\", linewidth=2.5, marker='o', label=\"Given Data\")\n",
|
1461 |
+
" \n",
|
1462 |
+
" x_pred = range(len(timeseries) - 1, len(timeseries) - 1 + len(quantile_predictions))\n",
|
1463 |
+
" # Use distinct colors with higher alpha for smoothness\n",
|
1464 |
+
" for i in range(quantile_predictions.shape[1]):\n",
|
1465 |
+
" ax.plot(x_pred, quantile_predictions[:, i], color=f\"C{i}\", linewidth=2, alpha=0.8, label=f\"Quantile {i+1}\")\n",
|
1466 |
+
" \n",
|
1467 |
+
" ax.set_title(f\"Timeseries: {timeseries_name}\", fontsize=16, fontweight='bold')\n",
|
1468 |
+
" ax.set_xlabel(\"Time\", fontsize=12)\n",
|
1469 |
+
" ax.set_ylabel(\"Value\", fontsize=12)\n",
|
1470 |
+
" \n",
|
1471 |
+
" ax.grid(True, which='both', linestyle='--', linewidth=0.7, alpha=0.6)\n",
|
1472 |
+
" ax.legend(loc=\"center left\", bbox_to_anchor=(1, 0.5), fontsize=10, frameon=True, shadow=True)\n",
|
1473 |
+
" \n",
|
1474 |
+
" plt.tight_layout(rect=[0, 0, 0.82, 1])\n",
|
1475 |
+
" \n",
|
1476 |
+
" buf = io.BytesIO()\n",
|
1477 |
+
" fig.savefig(buf, format=\"png\", bbox_inches=\"tight\", transparent=True)\n",
|
1478 |
+
" plt.close(fig)\n",
|
1479 |
+
" buf.seek(0)\n",
|
1480 |
+
" img = Image.open(buf).convert(\"RGB\")\n",
|
1481 |
+
" return np.array(img)\n",
|
1482 |
+
"\n",
|
1483 |
+
"def load_table(file_path):\n",
|
1484 |
+
" ext = file_path.split(\".\")[-1].lower()\n",
|
1485 |
+
" if ext == \"csv\":\n",
|
1486 |
+
" return pd.read_csv(file_path)\n",
|
1487 |
+
" elif ext in (\"xls\", \"xlsx\"):\n",
|
1488 |
+
" return pd.read_excel(file_path)\n",
|
1489 |
+
" elif ext == \"parquet\":\n",
|
1490 |
+
" return pd.read_parquet(file_path)\n",
|
1491 |
+
" else:\n",
|
1492 |
+
" raise ValueError(\"Unsupported format. Use CSV, XLS, XLSX, or PARQUET.\")\n",
|
1493 |
+
"\n",
|
1494 |
+
"def extract_names_and_update(file, preset_filename):\n",
|
1495 |
+
" try:\n",
|
1496 |
+
" if file is not None:\n",
|
1497 |
+
" df = load_table(file.name)\n",
|
1498 |
+
" else:\n",
|
1499 |
+
" if not preset_filename:\n",
|
1500 |
+
" return gr.update(choices=[], value=[]), []\n",
|
1501 |
+
" df = load_table(preset_filename)\n",
|
1502 |
+
"\n",
|
1503 |
+
" if df.shape[1] > 0 and df.iloc[:, 0].dtype == object and not df.iloc[:, 0].str.isnumeric().all():\n",
|
1504 |
+
" names = df.iloc[:, 0].tolist()\n",
|
1505 |
+
" else:\n",
|
1506 |
+
" names = [f\"Series {i}\" for i in range(len(df))]\n",
|
1507 |
+
" return gr.update(choices=names, value=names), names\n",
|
1508 |
+
" except Exception:\n",
|
1509 |
+
" return gr.update(choices=[], value=[]), []\n",
|
1510 |
+
"\n",
|
1511 |
+
"def filter_names(search_term, all_names):\n",
|
1512 |
+
" if not all_names:\n",
|
1513 |
+
" return gr.update(choices=[], value=[])\n",
|
1514 |
+
" if not search_term:\n",
|
1515 |
+
" return gr.update(choices=all_names, value=all_names)\n",
|
1516 |
+
" lower = search_term.lower()\n",
|
1517 |
+
" filtered = [n for n in all_names if lower in str(n).lower()]\n",
|
1518 |
+
" return gr.update(choices=filtered, value=filtered)\n",
|
1519 |
+
"\n",
|
1520 |
+
"def check_all(names_list):\n",
|
1521 |
+
" return gr.update(value=names_list)\n",
|
1522 |
+
"\n",
|
1523 |
+
"def uncheck_all(_):\n",
|
1524 |
+
" return gr.update(value=[])\n",
|
1525 |
+
"\n",
|
1526 |
+
"def display_filtered_forecast(file, preset_filename, selected_names):\n",
|
1527 |
+
" try:\n",
|
1528 |
+
" if file is not None:\n",
|
1529 |
+
" df = load_table(file.name)\n",
|
1530 |
+
" else:\n",
|
1531 |
+
" if not preset_filename:\n",
|
1532 |
+
" return [], \"No file selected.\"\n",
|
1533 |
+
" df = load_table(preset_filename)\n",
|
1534 |
+
"\n",
|
1535 |
+
" if df.shape[1] > 0 and df.iloc[:, 0].dtype == object and not df.iloc[:, 0].str.isnumeric().all():\n",
|
1536 |
+
" all_names = df.iloc[:, 0].tolist()\n",
|
1537 |
+
" data_only = df.iloc[:, 1:].astype(float)\n",
|
1538 |
+
" else:\n",
|
1539 |
+
" all_names = [f\"Series {i}\" for i in range(len(df))]\n",
|
1540 |
+
" data_only = df.astype(float)\n",
|
1541 |
+
"\n",
|
1542 |
+
" mask = [name in selected_names for name in all_names]\n",
|
1543 |
+
" if not any(mask):\n",
|
1544 |
+
" return [], \"No timeseries chosen to plot.\"\n",
|
1545 |
+
"\n",
|
1546 |
+
" filtered_data = data_only.iloc[mask, :].values\n",
|
1547 |
+
" filtered_names = [all_names[i] for i, m in enumerate(mask) if m]\n",
|
1548 |
+
" out = _forecast_tensor[mask] # slice forecasts to match filtered rows\n",
|
1549 |
+
" inp = torch.tensor(filtered_data)\n",
|
1550 |
+
"\n",
|
1551 |
+
" gallery_images = []\n",
|
1552 |
+
" for i in range(inp.shape[0]):\n",
|
1553 |
+
" gallery_images.append(plot_forecast_image(inp[i], out[i], filtered_names[i]))\n",
|
1554 |
+
"\n",
|
1555 |
+
" return gallery_images, \"\"\n",
|
1556 |
+
" except Exception as e:\n",
|
1557 |
+
" return [], f\"Error: {e}. Use CSV, XLS, XLSX, or PARQUET.\"\n",
|
1558 |
+
"\n",
|
1559 |
+
"\n",
|
1560 |
+
"# ----------------------------\n",
|
1561 |
+
"# Gradio layout: two columns + instructions\n",
|
1562 |
+
"# ----------------------------\n",
|
1563 |
+
"\n",
|
1564 |
+
"with gr.Blocks() as demo:\n",
|
1565 |
+
" gr.Markdown(\"# 📈 Stock Forecast Viewer 📊\")\n",
|
1566 |
+
" gr.Markdown(\"Upload data or choose a preset, filter by name, then click Plot.\")\n",
|
1567 |
+
"\n",
|
1568 |
+
" with gr.Row():\n",
|
1569 |
+
" # Left column: controls\n",
|
1570 |
+
" with gr.Column():\n",
|
1571 |
+
" gr.Markdown(\"## Data Selection\")\n",
|
1572 |
+
" file_input = gr.File(\n",
|
1573 |
+
" label=\"Upload CSV / XLSX / PARQUET\",\n",
|
1574 |
+
" file_types=[\".csv\", \".xls\", \".xlsx\", \".parquet\"]\n",
|
1575 |
+
" )\n",
|
1576 |
+
" preset_dropdown = gr.Dropdown(\n",
|
1577 |
+
" label=\"Or choose a preset:\",\n",
|
1578 |
+
" choices=[\"stocks_data_noindex.csv\", \"stocks_data.csv\"],\n",
|
1579 |
+
" value=\"No file selected\"\n",
|
1580 |
+
" )\n",
|
1581 |
+
"\n",
|
1582 |
+
" gr.Markdown(\"## Search / Filter\")\n",
|
1583 |
+
" search_box = gr.Textbox(placeholder=\"Type to filter (e.g. 'AMZN')\")\n",
|
1584 |
+
" filter_checkbox = gr.CheckboxGroup(\n",
|
1585 |
+
" choices=[], value=[], label=\"Select which timeseries to show\"\n",
|
1586 |
+
" )\n",
|
1587 |
+
"\n",
|
1588 |
+
" with gr.Row():\n",
|
1589 |
+
" check_all_btn = gr.Button(\"✅ Check All\")\n",
|
1590 |
+
" uncheck_all_btn = gr.Button(\"❎ Uncheck All\")\n",
|
1591 |
+
"\n",
|
1592 |
+
" plot_button = gr.Button(\"▶️ Plot Forecasts\")\n",
|
1593 |
+
" errbox = gr.Textbox(label=\"Error Message\", interactive=False)\n",
|
1594 |
+
"\n",
|
1595 |
+
" # Right column: gallery + instructions\n",
|
1596 |
+
" with gr.Column():\n",
|
1597 |
+
" gr.Markdown(\"## Forecast Gallery\")\n",
|
1598 |
+
" gallery = gr.Gallery()\n",
|
1599 |
+
"\n",
|
1600 |
+
" # Instruction text below gallery\n",
|
1601 |
+
" gr.Markdown(\"## Instructions\")\n",
|
1602 |
+
" gr.Markdown(\n",
|
1603 |
+
" \"\"\"\n",
|
1604 |
+
" **How to format your data:**\n",
|
1605 |
+
" - Your file must be a table (CSV, XLS, XLSX, or Parquet).\n",
|
1606 |
+
" - **One row per timeseries.** Each row is treated as a separate series.\n",
|
1607 |
+
" - If you want to **name** each series, put the name as the first value in **every** row:\n",
|
1608 |
+
" - Example (CSV): \n",
|
1609 |
+
" `AAPL, 120.5, 121.0, 119.8, ...` \n",
|
1610 |
+
" `AMZN, 3300.0, 3310.5, 3295.2, ...` \n",
|
1611 |
+
" - In that case, the first column is not numeric, so it will be used as the series name.\n",
|
1612 |
+
" - If you do **not** want named series, simply leave out the first column entirely and have all values numeric:\n",
|
1613 |
+
" - Example: \n",
|
1614 |
+
" `120.5, 121.0, 119.8, ...` \n",
|
1615 |
+
" `3300.0, 3310.5, 3295.2, ...` \n",
|
1616 |
+
" - Then every row will be auto-named “Series 0, Series 1, …” in order.\n",
|
1617 |
+
" - **Consistency rule:** Either all rows have a non-numeric first entry for the name, or none do. Do not mix.\n",
|
1618 |
+
" - The rest of the columns (after the optional name) must be numeric data points for that series.\n",
|
1619 |
+
" - You can filter by typing in the search box. Then check or uncheck individual names before plotting.\n",
|
1620 |
+
" - Use “Check All” / “Uncheck All” to quickly select or deselect every series.\n",
|
1621 |
+
" - Finally, click **Plot Forecasts** to view the quantile forecast for each selected series.\n",
|
1622 |
+
" \"\"\"\n",
|
1623 |
+
" )\n",
|
1624 |
+
"\n",
|
1625 |
+
" names_state = gr.State([])\n",
|
1626 |
+
"\n",
|
1627 |
+
" # When file or preset changes, update names\n",
|
1628 |
+
" file_input.change(\n",
|
1629 |
+
" fn=extract_names_and_update,\n",
|
1630 |
+
" inputs=[file_input, preset_dropdown],\n",
|
1631 |
+
" outputs=[filter_checkbox, names_state]\n",
|
1632 |
+
" )\n",
|
1633 |
+
" preset_dropdown.change(\n",
|
1634 |
+
" fn=extract_names_and_update,\n",
|
1635 |
+
" inputs=[file_input, preset_dropdown],\n",
|
1636 |
+
" outputs=[filter_checkbox, names_state]\n",
|
1637 |
+
" )\n",
|
1638 |
+
"\n",
|
1639 |
+
" # When search term changes, filter names\n",
|
1640 |
+
" search_box.change(\n",
|
1641 |
+
" fn=filter_names,\n",
|
1642 |
+
" inputs=[search_box, names_state],\n",
|
1643 |
+
" outputs=filter_checkbox\n",
|
1644 |
+
" )\n",
|
1645 |
+
"\n",
|
1646 |
+
" # Check All / Uncheck All\n",
|
1647 |
+
" check_all_btn.click(fn=check_all, inputs=names_state, outputs=filter_checkbox)\n",
|
1648 |
+
" uncheck_all_btn.click(fn=uncheck_all, inputs=names_state, outputs=filter_checkbox)\n",
|
1649 |
+
"\n",
|
1650 |
+
" # Plot button\n",
|
1651 |
+
" plot_button.click(\n",
|
1652 |
+
" fn=display_filtered_forecast,\n",
|
1653 |
+
" inputs=[file_input, preset_dropdown, filter_checkbox],\n",
|
1654 |
+
" outputs=[gallery, errbox]\n",
|
1655 |
+
" )\n",
|
1656 |
+
"\n",
|
1657 |
+
"demo.launch()"
|
1658 |
+
]
|
1659 |
+
}
|
1660 |
+
],
|
1661 |
+
"metadata": {
|
1662 |
+
"kernelspec": {
|
1663 |
+
"display_name": "Python 3",
|
1664 |
+
"language": "python",
|
1665 |
+
"name": "python3"
|
1666 |
+
},
|
1667 |
+
"language_info": {
|
1668 |
+
"codemirror_mode": {
|
1669 |
+
"name": "ipython",
|
1670 |
+
"version": 3
|
1671 |
+
},
|
1672 |
+
"file_extension": ".py",
|
1673 |
+
"mimetype": "text/x-python",
|
1674 |
+
"name": "python",
|
1675 |
+
"nbconvert_exporter": "python",
|
1676 |
+
"pygments_lexer": "ipython3",
|
1677 |
+
"version": "3.11.11"
|
1678 |
+
}
|
1679 |
+
},
|
1680 |
+
"nbformat": 4,
|
1681 |
+
"nbformat_minor": 2
|
1682 |
+
}
|
tirex
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Subproject commit 341adb31e1f727d181682993db223bf249d9aa9a
|