Update app.py
app.py CHANGED
@@ -1,172 +1,316 @@
 import gradio as gr
 import torch
-
-from …
 import numpy as np
 from PIL import Image
 import os
 import warnings

-# Suppress warnings for a cleaner …
 warnings.filterwarnings("ignore")

-# --- …
-# …
-…
-ALL_CHANNELS = [f"AIA {ch} Å" for ch in AIA_CHANNELS] + [f"HMI {ch}" for ch in HMI_CHANNELS]

 @gr.cache
-def …
     """
-    Downloads …
     """
-    …
-    os.makedirs(model_dir, exist_ok=True)
-    os.makedirs(data_dir, exist_ok=True)
-
-    # Download the model weights and test data from Hugging Face
-    checkpoint_path = hf_hub_download(
         repo_id="nasa-ibm-ai4science/Surya-1.0",
-        …
-    )
-    …
-    model = Surya(
-        img_size=4096,
-        patch_size=16,
-        in_chans=13,
-        embed_dim=1280,
-        spectral_blocks=2,
-        attention_blocks=8,
     )

-    # …
-    …
-    test_label = test_data["label"]  # Ground truth for comparison

-    …

-def run_forecast(channel_name, progress=gr.Progress()):
     """
-    …
     """
-    …

-    # …
-    …

-    # Shape: [batch, channels, time, height, width] -> select channel, last time step
-    input_slice = test_input[0, channel_index, -1, :, :]
-    input_image = tensor_to_image(input_slice)

-    …
-    label_image = tensor_to_image(label_slice)

-    …

-# --- 5. …
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
     gr.Markdown(
         """
         <div align="center">
-        # ☀️ Surya: …
-        …
         </div>
         """
     )
-
     with gr.Row():
         channel_selector = gr.Dropdown(
-            choices=…
-            value=…
             label="🛰️ Select SDO Instrument Channel",
             info="Choose which solar observation channel to visualize."
         )

-    …
     with gr.Row():
         with gr.Column():
             gr.Markdown("### ⬅️ Final Input Image")
-            gr.Markdown("The last …")
-            input_display = gr.Image(label="Input …")
         with gr.Column():
             gr.Markdown("### 🔮 Model's Forecast")
-            gr.Markdown("…")
-            prediction_display = gr.Image(label="…")
         with gr.Column():
             gr.Markdown("### ✅ Ground Truth")
-            gr.Markdown("What the Sun *actually* looked like at …")
-            label_display = gr.Image(label="…")
-
-    gr.Markdown(
-        "--- \n"
-        "**Note:** The first time you run a forecast, the app will download the 366M-parameter model (~1.4 GB) and test data. Subsequent runs will be much faster. "
-        "The images are downscaled for display in this demo. "
-        "For more information, visit the [Surya Hugging Face Repository](https://huggingface.co/nasa-ibm-ai4science/Surya-1.0)."
-    )

     run_button.click(
-        fn=…
         inputs=[channel_selector],
-        outputs=[input_display, prediction_display, label_display]
     )

 if __name__ == "__main__":
     demo.launch(debug=True)
 import gradio as gr
 import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from huggingface_hub import snapshot_download
+import yaml
+import matplotlib.pyplot as plt  # needed for plt.get_cmap() in create_visualizations
 import numpy as np
 from PIL import Image
+import sunpy.visualization.colormaps as sunpy_cm
 import os
+import glob
 import warnings
+import logging

+# --- Suppress verbose logging and warnings for a cleaner UI ---
 warnings.filterwarnings("ignore")
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)

+# --- Dependencies from the Surya Repository ---
+# NOTE: To make this script self-contained, the required classes and functions
+# from the 'surya' library are included directly here.
+# In a full installation, these would be imported.

+from surya_dependencies import (
+    HelioSpectFormer,
+    HelioNetCDFDataset,
+    build_scalers,
+    custom_collate_fn,
+    inverse_transform_single_channel,
+    SDO_CHANNELS,
+    AIA_CHANNELS,
+    HMI_CHANNELS,
+)
+
+# --- Global Cache for Model and Data ---
+# We use a simple dictionary to act as a cache to avoid reloading.
+APP_CACHE = {
+    "model": None,
+    "config": None,
+    "scalers": None,
+    "dataset": None,
+    "dataloader": None,
+    "device": "cuda" if torch.cuda.is_available() else "cpu",
+}
+
+
# --- 1. Setup and Data Download ---
|
48 |
@gr.cache
|
49 |
+
def setup_environment_and_download_data():
|
50 |
"""
|
51 |
+
Downloads all necessary files from Hugging Face: model, config, scalers, and validation data.
|
52 |
+
Also creates the necessary index file for the dataset loader.
|
53 |
+
This function is cached by Gradio to run only once.
|
54 |
"""
|
55 |
+
logger.info("Setting up environment. This will run only once.")
|
56 |
+
local_dir = "data/Surya-1.0"
|
57 |
+
# Download model, config, and scalers
|
58 |
+
snapshot_download(
|
|
|
|
|
|
|
|
|
|
|
59 |
repo_id="nasa-ibm-ai4science/Surya-1.0",
|
60 |
+
local_dir=local_dir,
|
61 |
+
allow_patterns=["config.yaml", "scalers.yaml", "surya.366m.v1.pt"],
|
62 |
)
|
63 |
+
|
64 |
+
# Download validation data
|
65 |
+
data_dir = "data/Surya-1.0_validation_data"
|
66 |
+
snapshot_download(
|
67 |
+
repo_id="nasa-ibm-ai4science/Surya-1.0_validation_data",
|
68 |
+
repo_type="dataset",
|
69 |
+
local_dir=data_dir,
|
70 |
+
allow_patterns="20140107_1[5-9]??.nc",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
)
|
72 |
|
73 |
+
# The test script requires an index file. We'll create it dynamically.
|
74 |
+
index_dir = "data/test_indices"
|
75 |
+
os.makedirs(index_dir, exist_ok=True)
|
76 |
+
index_file_path = os.path.join(index_dir, "test_surya_index.csv")
|
77 |
+
|
78 |
+
with open(index_file_path, "w") as f:
|
79 |
+
f.write("path\n")
|
80 |
+
# Find the downloaded NetCDF files and write their paths to the index
|
81 |
+
search_path = os.path.join(data_dir, "**", "*.nc")
|
82 |
+
for nc_file in sorted(glob.glob(search_path, recursive=True)):
|
83 |
+
f.write(f"{nc_file}\n")
|
84 |
+
logger.info(f"Created index file at {index_file_path}")
|
85 |
+
return index_file_path, local_dir
|
86 |
+
|
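+# A note on the allow_patterns glob above: "20140107_1[5-9]??.nc" matches
+# files whose HHMM timestamp falls between 1500 and 1959 on 2014-01-07, i.e.
+# the five-hour window of snapshots this demo forecasts from.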
+# --- 2. Model and Data Loading ---
+def load_essentials(model_dir):
+    """Loads config, scalers, and the model into the APP_CACHE."""
+    if APP_CACHE["model"] is None:
+        logger.info("Loading config, scalers, and model for the first time...")
+        # Load config
+        with open(os.path.join(model_dir, "config.yaml")) as fp:
+            config = yaml.safe_load(fp)
+        APP_CACHE["config"] = config

+        # Build scalers for data normalization
+        with open(os.path.join(model_dir, "scalers.yaml"), "r") as fp:
+            scalers_info = yaml.safe_load(fp)
+        APP_CACHE["scalers"] = build_scalers(info=scalers_info)

+        # Initialize model from config
+        model = HelioSpectFormer(
+            img_size=config["model"]["img_size"],
+            patch_size=config["model"]["patch_size"],
+            in_chans=len(config["data"]["sdo_channels"]),
+            embed_dim=config["model"]["embed_dim"],
+            time_embedding={"type": "linear", "time_dim": len(config["data"]["time_delta_input_minutes"])},
+            depth=config["model"]["depth"],
+            n_spectral_blocks=config["model"]["n_spectral_blocks"],
+            num_heads=config["model"]["num_heads"],
+            mlp_ratio=config["model"]["mlp_ratio"],
+            drop_rate=config["model"]["drop_rate"],
+            dtype=torch.bfloat16,
+            window_size=config["model"]["window_size"],
+            dp_rank=config["model"]["dp_rank"],
+            learned_flow=config["model"]["learned_flow"],
+            use_latitude_in_learned_flow=config["model"]["learned_flow"],
+            init_weights=False,
+            checkpoint_layers=list(range(config["model"]["depth"])),
+            rpe=config["model"]["rpe"],
+            ensemble=config["model"]["ensemble"],
+            finetune=config["model"]["finetune"],
+        )
+
+        # Load pre-trained weights
+        path_weights = os.path.join(model_dir, "surya.366m.v1.pt")
+        weights = torch.load(path_weights, map_location=torch.device(APP_CACHE["device"]))
+        model.load_state_dict(weights, strict=True)
+        model.to(APP_CACHE["device"])
+        model.eval()
+
+        n_params = sum(p.numel() for p in model.parameters()) / 1e6
+        logger.info(f"Surya FM: {n_params:.2f}M parameters loaded to {APP_CACHE['device']}.")
+        APP_CACHE["model"] = model

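+# Size sanity check: 366M parameters at float32 (4 bytes each) is roughly
+# 1.46 GB, consistent with the ~1.4 GB checkpoint size quoted in the earlier
+# version of this app.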
+def get_dataloader(index_path):
+    """Initializes and returns a DataLoader for the validation data."""
+    if APP_CACHE["dataloader"] is None:
+        logger.info("Initializing dataset and dataloader...")
+        config = APP_CACHE["config"]
+        dataset = HelioNetCDFDataset(
+            index_path=index_path,
+            time_delta_input_minutes=config["data"]["time_delta_input_minutes"],
+            time_delta_target_minutes=config["data"]["time_delta_target_minutes"],
+            n_input_timestamps=len(config["data"]["time_delta_input_minutes"]),
+            rollout_steps=1,
+            channels=config["data"]["sdo_channels"],
+            scalers=APP_CACHE["scalers"],
+            phase="valid",  # Important: ensure no random augmentations
+        )
+        dataloader = DataLoader(
+            dataset, shuffle=False, batch_size=1, num_workers=2,
+            pin_memory=True, drop_last=False, collate_fn=custom_collate_fn,
+        )
+        APP_CACHE["dataloader"] = dataloader
+        APP_CACHE["dataset"] = dataset  # Also cache dataset for transformation info
+    return APP_CACHE["dataloader"]
+
+
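+# With shuffle=False and batch_size=1, the first batch drawn from this
+# dataloader is always the same validation sample, which keeps repeated
+# forecasts deterministic.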
+# --- 3. Core Inference and Visualization Logic ---
+def run_model_inference():
     """
+    Performs a single prediction step using the loaded model and dataloader.
+    Returns the raw input, prediction, and ground truth tensors.
     """
+    model = APP_CACHE["model"]
+    dataloader = APP_CACHE["dataloader"]
+    device = APP_CACHE["device"]
+
+    # Get the first (and only) batch of data from the validation set
+    batch_data, batch_metadata = next(iter(dataloader))

+    logger.info("Running inference on the validation batch...")
+    with torch.no_grad():
+        # Prepare input batch for the model
+        input_batch = {key: batch_data[key].to(device) for key in ["ts", "time_delta_input"]}
+        # Run model prediction
+        with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16):
+            prediction_tensor = model(input_batch)

+    # Get the input and target tensors for comparison
+    input_tensor = input_batch["ts"].to(dtype=torch.float32).cpu()
+    target_tensor = batch_data["forecast"].cpu()
+    prediction_tensor = prediction_tensor.to(dtype=torch.float32).cpu()
+
+    logger.info("Inference complete.")
+    return input_tensor, prediction_tensor, target_tensor

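+# Tensor shapes assumed by the indexing in create_visualizations below:
+# input [batch, channel, time, H, W], prediction [batch, channel, H, W],
+# target [batch, channel, rollout_step, H, W].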
+def create_visualizations(channel_name, input_tensor, prediction_tensor, target_tensor):
     """
+    Takes raw tensors and a channel name, applies inverse transformation,
+    and converts them to displayable PIL Images.
     """
+    if input_tensor is None:
+        return None, None, None, "Please run the forecast first."
+
+    logger.info(f"Creating visualization for channel: {channel_name}")
+    c_idx = SDO_CHANNELS.index(channel_name)
+    dataset = APP_CACHE["dataset"]
+    means, stds, epsilons, sl_scale_factors = dataset.transformation_inputs()

+    # --- Denormalize data for visualization ---
+    # Final input image given to the model (last in sequence)
+    input_slice = inverse_transform_single_channel(
+        input_tensor[0, c_idx, -1, :, :].numpy(),
+        mean=means[c_idx], std=stds[c_idx], epsilon=epsilons[c_idx], sl_scale_factor=sl_scale_factors[c_idx]
+    )
+    # Model's prediction
+    pred_slice = inverse_transform_single_channel(
+        prediction_tensor[0, c_idx, :, :].numpy(),
+        mean=means[c_idx], std=stds[c_idx], epsilon=epsilons[c_idx], sl_scale_factor=sl_scale_factors[c_idx]
+    )
+    # Ground truth image
+    target_slice = inverse_transform_single_channel(
+        target_tensor[0, c_idx, 0, :, :].numpy(),
+        mean=means[c_idx], std=stds[c_idx], epsilon=epsilons[c_idx], sl_scale_factor=sl_scale_factors[c_idx]
+    )
+
+    # --- Convert to images ---
+    # Use a shared color scale for better comparison, clipped at 99.5th percentile
+    vmax = np.quantile(np.concatenate([input_slice, pred_slice, target_slice]), 0.995)
+
+    # Determine colormap from channel name
+    cmap_name = f"sdoaia{channel_name.replace('aia', '')}" if 'aia' in channel_name else 'hmimag'
+    cmap = plt.get_cmap(sunpy_cm.cmlist.get(cmap_name, 'gray'))

+    def to_pil(data, vmin=0, vmax=vmax, cmap=cmap):
+        data_clipped = np.clip(data, vmin, vmax)
+        data_norm = (data_clipped - vmin) / (vmax - vmin)
+        return Image.fromarray((cmap(data_norm)[:, :, :3] * 255).astype(np.uint8))

+    return to_pil(input_slice), to_pil(pred_slice), to_pil(target_slice), f"Displaying forecast for {channel_name}"

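+# Colormap lookup above: "aia171" becomes sunpy's "sdoaia171"; channels
+# without "aia" in the name use the "hmimag" magnetogram colormap, and any
+# name missing from sunpy_cm.cmlist falls back to plain grayscale.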
+# --- 4. Gradio Controller Functions ---
+def forecast_controller(channel_name, progress=gr.Progress()):
+    """
+    Main function triggered by the 'Generate' button. Orchestrates the entire pipeline.
+    """
+    progress(0, desc="Downloading model and data (first launch only)...")
+    index_path, model_dir = setup_environment_and_download_data()
+
+    progress(0.4, desc="Loading model and building data pipeline...")
+    load_essentials(model_dir)
+    get_dataloader(index_path)
+
+    progress(0.7, desc=f"Running inference on {APP_CACHE['device']}...")
+    input_t, pred_t, target_t = run_model_inference()

+    progress(0.9, desc="Creating visualizations...")
+    img_in, img_pred, img_target, status = create_visualizations(channel_name, input_t, pred_t, target_t)

+    return img_in, img_pred, img_target, status, input_t, pred_t, target_t
+

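+# The raw tensors are returned alongside the rendered images so they can be
+# stored in gr.State below; switching channels then re-renders from these
+# cached tensors instead of re-running inference.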
+# --- 5. Gradio UI Layout ---
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
+    # Hidden state variables to store the raw tensors after inference
+    state_input = gr.State()
+    state_prediction = gr.State()
+    state_target = gr.State()
+
     gr.Markdown(
         """
         <div align="center">
+        # ☀️ Surya: Live Model Demo ☀️
+        ### An Interactive Interface for NASA's Heliophysics Foundation Model
+        This demo runs the **actual** Surya model on its official validation data for **2014-01-07**.
+        Click the button to generate a forecast, then use the dropdown to explore the results across different SDO instrument channels.
         </div>
         """
     )
+
     with gr.Row():
         channel_selector = gr.Dropdown(
+            choices=SDO_CHANNELS,
+            value="aia171",
             label="🛰️ Select SDO Instrument Channel",
             info="Choose which solar observation channel to visualize."
         )
+        run_button = gr.Button("🔮 Generate Forecast for 2014-01-07", variant="primary", scale=2)

+    status_box = gr.Textbox(label="Status", interactive=False, value="Ready. Press 'Generate Forecast' to start.")
+
     with gr.Row():
         with gr.Column():
             gr.Markdown("### ⬅️ Final Input Image")
+            gr.Markdown("The last observation shown to the model (T-1).")
+            input_display = gr.Image(label="Input", height=512, width=512, interactive=False)
         with gr.Column():
             gr.Markdown("### 🔮 Model's Forecast")
+            gr.Markdown("Surya's prediction for the next timestep (T+0).")
+            prediction_display = gr.Image(label="Prediction", height=512, width=512, interactive=False)
         with gr.Column():
             gr.Markdown("### ✅ Ground Truth")
+            gr.Markdown("What the Sun *actually* looked like at T+0.")
+            label_display = gr.Image(label="Ground Truth", height=512, width=512, interactive=False)

+    # --- Event Handlers ---
     run_button.click(
+        fn=forecast_controller,
         inputs=[channel_selector],
+        outputs=[input_display, prediction_display, label_display, status_box, state_input, state_prediction, state_target]
+    )
+
+    channel_selector.change(
+        fn=create_visualizations,
+        inputs=[channel_selector, state_input, state_prediction, state_target],
+        outputs=[input_display, prediction_display, label_display, status_box]
     )

 if __name__ == "__main__":
+    # The 'surya_dependencies.py' file must be in the same directory as this script.
+    if not os.path.exists("surya_dependencies.py"):
+        raise FileNotFoundError("The required 'surya_dependencies.py' file is missing. Please download it from the provided source.")
     demo.launch(debug=True)
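+# Likely runtime dependencies for this Space, inferred from the imports above
+# rather than from an official list: gradio, torch, huggingface_hub, pyyaml,
+# numpy, pillow, sunpy, matplotlib, plus a NetCDF reader (e.g. netCDF4 or
+# xarray) for the .nc validation files.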