cleanup: update files; feat: add prediction across all models
Browse filesSigned-off-by: K.B.Dharun Krishna <[email protected]>
- .gitignore +174 -0
- app.py +430 -94
- requirements.txt +1 -0
.gitignore
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# UV
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
#uv.lock
|
102 |
+
|
103 |
+
# poetry
|
104 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
105 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
106 |
+
# commonly ignored for libraries.
|
107 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
108 |
+
#poetry.lock
|
109 |
+
|
110 |
+
# pdm
|
111 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
112 |
+
#pdm.lock
|
113 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
114 |
+
# in version control.
|
115 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
116 |
+
.pdm.toml
|
117 |
+
.pdm-python
|
118 |
+
.pdm-build/
|
119 |
+
|
120 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
121 |
+
__pypackages__/
|
122 |
+
|
123 |
+
# Celery stuff
|
124 |
+
celerybeat-schedule
|
125 |
+
celerybeat.pid
|
126 |
+
|
127 |
+
# SageMath parsed files
|
128 |
+
*.sage.py
|
129 |
+
|
130 |
+
# Environments
|
131 |
+
.env
|
132 |
+
.venv
|
133 |
+
env/
|
134 |
+
venv/
|
135 |
+
ENV/
|
136 |
+
env.bak/
|
137 |
+
venv.bak/
|
138 |
+
|
139 |
+
# Spyder project settings
|
140 |
+
.spyderproject
|
141 |
+
.spyproject
|
142 |
+
|
143 |
+
# Rope project settings
|
144 |
+
.ropeproject
|
145 |
+
|
146 |
+
# mkdocs documentation
|
147 |
+
/site
|
148 |
+
|
149 |
+
# mypy
|
150 |
+
.mypy_cache/
|
151 |
+
.dmypy.json
|
152 |
+
dmypy.json
|
153 |
+
|
154 |
+
# Pyre type checker
|
155 |
+
.pyre/
|
156 |
+
|
157 |
+
# pytype static type analyzer
|
158 |
+
.pytype/
|
159 |
+
|
160 |
+
# Cython debug symbols
|
161 |
+
cython_debug/
|
162 |
+
|
163 |
+
# PyCharm
|
164 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
165 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
166 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
167 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
168 |
+
#.idea/
|
169 |
+
|
170 |
+
# Ruff stuff:
|
171 |
+
.ruff_cache/
|
172 |
+
|
173 |
+
# PyPI configuration file
|
174 |
+
.pypirc
|
app.py
CHANGED
@@ -1,8 +1,14 @@
|
|
1 |
import streamlit as st
|
2 |
import numpy as np
|
|
|
3 |
from PIL import Image
|
4 |
import cv2
|
5 |
from scipy.ndimage import gaussian_filter
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
# ------------------ TC CENTERING UTILS ------------------
|
8 |
|
@@ -11,6 +17,99 @@ def find_tc_center(ir_image, smoothing_sigma=3):
|
|
11 |
min_coords = np.unravel_index(np.argmin(smoothed_image), smoothed_image.shape)
|
12 |
return min_coords[::-1] # Return as (x, y)
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
def extract_local_region(ir_image, center, region_size=95):
|
15 |
h, w = ir_image.shape
|
16 |
half_size = region_size // 2
|
@@ -46,10 +145,9 @@ def create_3d_vmax(vmax_2d_array):
|
|
46 |
for i in range(vmax_2d_array.shape[0]):
|
47 |
np.fill_diagonal(vmax_3d_array[i], vmax_2d_array[i])
|
48 |
|
49 |
-
# Reshape to (N*
|
50 |
vmax_3d_array = vmax_3d_array.reshape(-1, 8, 8, 1)
|
51 |
-
|
52 |
-
|
53 |
return vmax_3d_array
|
54 |
|
55 |
def process_lat_values(data):
|
@@ -173,9 +271,44 @@ def compute_convective_core_masks(ir_data):
|
|
173 |
|
174 |
|
175 |
# ------------------ Streamlit UI ------------------
|
176 |
-
|
|
|
|
|
|
|
|
|
|
|
177 |
|
178 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
|
180 |
ir_images = st.file_uploader("Upload 8 IR images", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
|
181 |
pmw_images = st.file_uploader("Upload 8 PMW images", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
|
@@ -210,7 +343,12 @@ if csv_file is not None:
|
|
210 |
vmax_values = np.array(vmax_values)
|
211 |
|
212 |
st.success("CSV file loaded and processed successfully!")
|
213 |
-
|
|
|
|
|
|
|
|
|
|
|
214 |
|
215 |
else:
|
216 |
st.error("CSV file must contain 'Latitude', 'Longitude', and 'Vmax' columns.")
|
@@ -218,84 +356,206 @@ if csv_file is not None:
|
|
218 |
st.error(f"Error reading CSV: {e}")
|
219 |
else:
|
220 |
st.warning("Please upload a CSV file.")
|
221 |
-
st.header("Select Prediction Model")
|
222 |
-
model_choice = st.selectbox(
|
223 |
-
"Choose a model for prediction",
|
224 |
-
("ConvGRU", "ConvLSTM", "Traj-GRU","3DCNN","spatiotemporalLSTM","Unet_LSTM"),
|
225 |
-
index=0
|
226 |
-
)
|
227 |
-
# ------------------ Process Button ------------------
|
228 |
-
if st.button("Submit for Processing"):
|
229 |
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
268 |
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
297 |
import tensorflow as tf
|
298 |
-
|
299 |
def tf_gradient_magnitude(images):
|
300 |
# Sobel kernels
|
301 |
sobel_x = tf.constant([[1, 0, -1], [2, 0, -2], [1, 0, -1]], dtype=tf.float32)
|
@@ -311,6 +571,7 @@ if st.button("Submit for Processing"):
|
|
311 |
grad_mag = tf.sqrt(tf.square(gx) + tf.square(gy) + 1e-6)
|
312 |
|
313 |
return tf.squeeze(grad_mag, -1).numpy()
|
|
|
314 |
def GM_maps_prep(ir):
|
315 |
GM_maps=[]
|
316 |
for i in ir:
|
@@ -318,17 +579,92 @@ if st.button("Submit for Processing"):
|
|
318 |
GM_maps.append(GM_map)
|
319 |
GM_maps=np.array(GM_maps)
|
320 |
return GM_maps
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
GM_maps=
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
330 |
else:
|
331 |
-
|
332 |
-
st.write("Predicted Vmax:", y)
|
333 |
-
else:
|
334 |
-
st.error("Make sure you uploaded exactly 8 IR and 8 PMW images.")
|
|
|
1 |
import streamlit as st
|
2 |
import numpy as np
|
3 |
+
import pandas as pd
|
4 |
from PIL import Image
|
5 |
import cv2
|
6 |
from scipy.ndimage import gaussian_filter
|
7 |
+
import tensorflow as tf
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
import io
|
10 |
+
from matplotlib.figure import Figure
|
11 |
+
import base64
|
12 |
|
13 |
# ------------------ TC CENTERING UTILS ------------------
|
14 |
|
|
|
17 |
min_coords = np.unravel_index(np.argmin(smoothed_image), smoothed_image.shape)
|
18 |
return min_coords[::-1] # Return as (x, y)
|
19 |
|
20 |
+
# Function to generate comparison chart
|
21 |
+
def generate_comparison_chart(models, mae_values, rmse_values, predicted_values=None):
|
22 |
+
# Calculate improvement percentages relative to the first model
|
23 |
+
baseline_mae = mae_values[0]
|
24 |
+
baseline_rmse = rmse_values[0]
|
25 |
+
|
26 |
+
mae_improvements = [0] + [((baseline_mae - val) / baseline_mae) * 100 for val in mae_values[1:]]
|
27 |
+
rmse_improvements = [0] + [((baseline_rmse - val) / baseline_rmse) * 100 for val in rmse_values[1:]]
|
28 |
+
|
29 |
+
# Create figure with subplots (2 or 3 depending on if we have predictions)
|
30 |
+
if predicted_values:
|
31 |
+
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 8))
|
32 |
+
else:
|
33 |
+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))
|
34 |
+
|
35 |
+
# Plot MAE
|
36 |
+
bars1 = ax1.bar(range(len(models)), mae_values, color='skyblue', edgecolor='black')
|
37 |
+
ax1.set_title('Mean Absolute Error (MAE)', fontsize=14, fontweight='bold')
|
38 |
+
ax1.set_ylabel('MAE (knots)', fontsize=12)
|
39 |
+
ax1.set_xticks(range(len(models)))
|
40 |
+
ax1.set_xticklabels(models, fontsize=12, rotation=45, ha='right')
|
41 |
+
ax1.grid(axis='y', linestyle='--', alpha=0.3, color='lightgray')
|
42 |
+
ax1.set_ylim(0, max(mae_values) * 1.2)
|
43 |
+
|
44 |
+
# Plot RMSE
|
45 |
+
bars2 = ax2.bar(range(len(models)), rmse_values, color='lightcoral', edgecolor='black')
|
46 |
+
ax2.set_title('Root Mean Square Error (RMSE)', fontsize=14, fontweight='bold')
|
47 |
+
ax2.set_ylabel('RMSE (knots)', fontsize=12)
|
48 |
+
ax2.set_xticks(range(len(models)))
|
49 |
+
ax2.set_xticklabels(models, fontsize=12, rotation=45, ha='right')
|
50 |
+
ax2.grid(axis='y', linestyle='--', alpha=0.3, color='lightgray')
|
51 |
+
ax2.set_ylim(0, max(rmse_values) * 1.2)
|
52 |
+
|
53 |
+
# Add values on top of the bars for MAE
|
54 |
+
for i, bar in enumerate(bars1):
|
55 |
+
height = bar.get_height()
|
56 |
+
ax1.text(bar.get_x() + bar.get_width()/2., height + 0.3,
|
57 |
+
f'{height:.2f}', ha='center', va='bottom', fontsize=12)
|
58 |
+
|
59 |
+
# Add improvement percentage for all except the first bar
|
60 |
+
if i > 0:
|
61 |
+
ax1.text(bar.get_x() + bar.get_width()/2., height/2,
|
62 |
+
f'↓{mae_improvements[i]:.1f}%', ha='center', va='center',
|
63 |
+
color='blue', fontsize=12, fontweight='bold')
|
64 |
+
|
65 |
+
# Add values on top of the bars for RMSE
|
66 |
+
for i, bar in enumerate(bars2):
|
67 |
+
height = bar.get_height()
|
68 |
+
ax2.text(bar.get_x() + bar.get_width()/2., height + 0.3,
|
69 |
+
f'{height:.2f}', ha='center', va='bottom', fontsize=12)
|
70 |
+
|
71 |
+
# Add improvement percentage for all except the first bar
|
72 |
+
if i > 0:
|
73 |
+
ax2.text(bar.get_x() + bar.get_width()/2., height/2,
|
74 |
+
f'↓{rmse_improvements[i]:.1f}%', ha='center', va='center',
|
75 |
+
color='darkred', fontsize=12, fontweight='bold')
|
76 |
+
|
77 |
+
# Add horizontal reference lines for best performance
|
78 |
+
min_mae = min(mae_values)
|
79 |
+
min_rmse = min(rmse_values)
|
80 |
+
ax1.axhline(y=min_mae, color='blue', linestyle='--', alpha=0.5)
|
81 |
+
ax2.axhline(y=min_rmse, color='red', linestyle='--', alpha=0.5)
|
82 |
+
|
83 |
+
# Add predictions comparison if provided
|
84 |
+
if predicted_values:
|
85 |
+
bars3 = ax3.bar(range(len(models)), predicted_values, color='lightgreen', edgecolor='black')
|
86 |
+
ax3.set_title('Predicted Vmax', fontsize=14, fontweight='bold')
|
87 |
+
ax3.set_ylabel('Wind Speed (knots)', fontsize=12)
|
88 |
+
ax3.set_xticks(range(len(models)))
|
89 |
+
ax3.set_xticklabels(models, fontsize=12, rotation=45, ha='right')
|
90 |
+
ax3.grid(axis='y', linestyle='--', alpha=0.3, color='lightgray')
|
91 |
+
|
92 |
+
# Add values on top of the bars for predictions
|
93 |
+
for i, bar in enumerate(bars3):
|
94 |
+
height = bar.get_height()
|
95 |
+
ax3.text(bar.get_x() + bar.get_width()/2., height + 0.3,
|
96 |
+
f'{height:.2f}', ha='center', va='bottom', fontsize=12)
|
97 |
+
|
98 |
+
# Add a label at the bottom explaining the reduction percentages
|
99 |
+
fig.text(0.5, 0.01, 'Note: Reduction percentages (↓%) are calculated relative to TCIP-Net (3DCNN)',
|
100 |
+
ha='center', fontsize=12, fontstyle='italic')
|
101 |
+
|
102 |
+
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
|
103 |
+
|
104 |
+
return fig
|
105 |
+
|
106 |
+
# Function to convert matplotlib figure to Streamlit-compatible image
|
107 |
+
def fig_to_streamlit(fig):
|
108 |
+
buf = io.BytesIO()
|
109 |
+
fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
|
110 |
+
buf.seek(0)
|
111 |
+
return buf
|
112 |
+
|
113 |
def extract_local_region(ir_image, center, region_size=95):
|
114 |
h, w = ir_image.shape
|
115 |
half_size = region_size // 2
|
|
|
145 |
for i in range(vmax_2d_array.shape[0]):
|
146 |
np.fill_diagonal(vmax_3d_array[i], vmax_2d_array[i])
|
147 |
|
148 |
+
# Reshape to (N*8, 8, 8, 1)
|
149 |
vmax_3d_array = vmax_3d_array.reshape(-1, 8, 8, 1)
|
150 |
+
# Trim last element if needed (original comment, but not implemented)
|
|
|
151 |
return vmax_3d_array
|
152 |
|
153 |
def process_lat_values(data):
|
|
|
271 |
|
272 |
|
273 |
# ------------------ Streamlit UI ------------------
|
274 |
+
# Configure the page with wide layout and custom title
|
275 |
+
st.set_page_config(
|
276 |
+
page_title="Tropical Cyclone U-Net Wind Speed Predictor",
|
277 |
+
layout="wide",
|
278 |
+
initial_sidebar_state="expanded"
|
279 |
+
)
|
280 |
|
281 |
+
# Main title with emoji and styling
|
282 |
+
st.markdown("<h1 style='text-align: center;'>🌀 Tropical Cyclone U-Net Wind Speed (Intensity) Predictor</h1><br>", unsafe_allow_html=True)
|
283 |
+
|
284 |
+
# Authors section with ORCID links
|
285 |
+
st.markdown("""
|
286 |
+
<div style='text-align: center;'>
|
287 |
+
<p>
|
288 |
+
<b>By:</b>
|
289 |
+
<a href="https://orcid.org/0009-0006-0342-429X" target="_blank" style="text-decoration: none">Dharun Krishna K B</a>,
|
290 |
+
<a href="https://orcid.org/0009-0008-3214-8065" target="_blank" style="text-decoration: none">Nanduri Prudhvi Reddy</a> and
|
291 |
+
<a href="https://orcid.org/0009-0006-9052-3623" target="_blank" style="text-decoration: none">Ravipati Venkata Madan Mohan</a>; School of Computing.<br>
|
292 |
+
<b>Under the guidance of:</b>
|
293 |
+
<a href="https://orcid.org/0000-0003-1969-3559" target="_blank" style="text-decoration: none">Dr. Gowri L</a>,
|
294 |
+
Assistant Professor, School of Computing.<br>
|
295 |
+
SASTRA Deemed University, Thanjavur, Tamil Nadu, India.<br><br>
|
296 |
+
<b>For:</b>
|
297 |
+
Main project titled <i>"Tropical Cyclone Intensity Prediction Using Deep Learning Models"</i><br>
|
298 |
+
May 2025
|
299 |
+
</p>
|
300 |
+
</div>
|
301 |
+
""", unsafe_allow_html=True)
|
302 |
+
|
303 |
+
# Add a divider before the main content
|
304 |
+
st.markdown('<div class="divider"></div>', unsafe_allow_html=True)
|
305 |
+
|
306 |
+
# Add spacing
|
307 |
+
st.markdown("<br>", unsafe_allow_html=True)
|
308 |
+
|
309 |
+
# App description
|
310 |
+
st.info('''The *Tropical Cyclone Wind Speed Predictor interface* enables the prediction of maximum sustained wind speeds of tropical cyclones (in knots) using IR and PMW imagery, along with physical attributes from the past 24 hours, while also facilitating comparison between state-of-the-art models and our proposed model.
|
311 |
+
''')
|
312 |
|
313 |
ir_images = st.file_uploader("Upload 8 IR images", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
|
314 |
pmw_images = st.file_uploader("Upload 8 PMW images", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
|
|
|
343 |
vmax_values = np.array(vmax_values)
|
344 |
|
345 |
st.success("CSV file loaded and processed successfully!")
|
346 |
+
|
347 |
+
# Display the dataframe in a scrollable container
|
348 |
+
st.markdown("<h4>Preview of uploaded data:</h4>", unsafe_allow_html=True)
|
349 |
+
preview_df = df.head(10).reset_index(drop=True)
|
350 |
+
preview_df.index += 1 # Shift index to start from 1
|
351 |
+
st.dataframe(preview_df, height=200)
|
352 |
|
353 |
else:
|
354 |
st.error("CSV file must contain 'Latitude', 'Longitude', and 'Vmax' columns.")
|
|
|
356 |
st.error(f"Error reading CSV: {e}")
|
357 |
else:
|
358 |
st.warning("Please upload a CSV file.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
359 |
|
360 |
+
# Define data for ablation study
|
361 |
+
ablation_data = {
|
362 |
+
"Model": [
|
363 |
+
"TCIP-Net (3DCNN)",
|
364 |
+
"TCIP-Net (ST-LSTM)",
|
365 |
+
"TCIP-Net (ConvLSTM)",
|
366 |
+
"TCIP-Net (TrajGRU)",
|
367 |
+
"TCIP-Net (ConvGRU)",
|
368 |
+
"TCUWSP-Net (Proposed)"
|
369 |
+
],
|
370 |
+
"RMSE": [12.63, 12.52, 12.36, 12.24, 11.17, 8.6549],
|
371 |
+
"MAE": [10.15, 10.12, 9.97, 9.93, 8.92, 6.309]
|
372 |
+
}
|
373 |
+
|
374 |
+
# Improved Prediction Model Section with better UI
|
375 |
+
st.markdown("<br>", unsafe_allow_html=True)
|
376 |
+
st.markdown("<h2 style='text-align: center;'>Select Prediction Model</h2>", unsafe_allow_html=True)
|
377 |
+
|
378 |
+
# Create columns for better layout
|
379 |
+
col1, col2, col3 = st.columns([1, 2, 1])
|
380 |
+
|
381 |
+
with col2:
|
382 |
+
model_choice = st.selectbox(
|
383 |
+
"Choose a model for prediction",
|
384 |
+
("TCIP-Net ConvGRU", "TCIP-Net ConvLSTM", "TCIP-Net Traj-GRU", "TCIP-Net 3DCNN", "TCIP-Net Spatio-temporal LSTM", "TCUWSP-Net (Proposed Model)"),
|
385 |
+
index=0
|
386 |
+
)
|
387 |
+
|
388 |
+
# Center-aligned, more attractive submit button
|
389 |
+
st.markdown("<br>", unsafe_allow_html=True)
|
390 |
+
col_btn1, col_btn2 = st.columns(2)
|
391 |
+
with col_btn1:
|
392 |
+
submit_button = st.button("Predict Intensity", use_container_width=True)
|
393 |
+
with col_btn2:
|
394 |
+
all_models_button = st.button("Predict Intensity for All Models", use_container_width=True) # ------------------ Process Single Model Button ------------------
|
395 |
+
if submit_button:
|
396 |
+
if len(ir_images) == 8 and len(pmw_images) == 8:
|
397 |
+
# st.success("Starting preprocessing...")
|
398 |
+
if model_choice == "TCUWSP-Net (Proposed Model)":
|
399 |
+
from unetlstm import predict_unetlstm
|
400 |
+
model_predict_fn = predict_unetlstm
|
401 |
+
elif model_choice == "TCIP-Net ConvGRU":
|
402 |
+
from gru_model import predict
|
403 |
+
model_predict_fn = predict
|
404 |
+
elif model_choice == "TCIP-Net ConvLSTM":
|
405 |
+
from convlstm import predict_lstm
|
406 |
+
model_predict_fn = predict_lstm
|
407 |
+
elif model_choice == "TCIP-Net 3DCNN":
|
408 |
+
from cnn3d import predict_3dcnn
|
409 |
+
model_predict_fn = predict_3dcnn
|
410 |
+
elif model_choice == "TCIP-Net Traj-GRU":
|
411 |
+
from trjgru import predict_trajgru
|
412 |
+
model_predict_fn = predict_trajgru
|
413 |
+
elif model_choice == "TCIP-Net Spatio-temporal LSTM":
|
414 |
+
from spaio_temp import predict_stlstm
|
415 |
+
model_predict_fn = predict_stlstm
|
416 |
+
|
417 |
+
ir_arrays = []
|
418 |
+
pmw_arrays = []
|
419 |
+
train_vmax_2d = reshape_vmax(np.array(vmax_values))
|
420 |
+
|
421 |
+
train_vmax_3d= create_3d_vmax(train_vmax_2d)
|
422 |
+
|
423 |
+
lat_processed = process_lat_values(lat_values)
|
424 |
+
lon_processed = process_lon_values(lon_values)
|
425 |
+
|
426 |
+
v_max_diff = calculate_intensity_difference(train_vmax_2d)
|
427 |
+
|
428 |
+
for ir in ir_images:
|
429 |
+
img = Image.open(ir).convert("L")
|
430 |
+
arr = np.array(img).astype(np.float32)
|
431 |
+
bt_arr = (arr / 255.0) * (310 - 190) + 190
|
432 |
+
resized = cv2.resize(bt_arr, (95, 95), interpolation=cv2.INTER_CUBIC)
|
433 |
+
ir_arrays.append(resized)
|
434 |
+
|
435 |
+
for pmw in pmw_images:
|
436 |
+
img = Image.open(pmw).convert("L")
|
437 |
+
arr = np.array(img).astype(np.float32) / 255.0
|
438 |
+
resized = cv2.resize(arr, (95, 95), interpolation=cv2.INTER_CUBIC)
|
439 |
+
pmw_arrays.append(resized)
|
440 |
+
ir=np.array(ir_arrays)
|
441 |
+
pmw=np.array(pmw_arrays)
|
442 |
+
|
443 |
+
# Stack into (8, 95, 95)
|
444 |
+
ir_seq = process_images(ir)
|
445 |
+
pmw_seq = process_images(pmw)
|
446 |
+
|
447 |
+
|
448 |
+
# For demonstration: create batches
|
449 |
+
X_train_new = ir_seq.reshape((1, 8, 95, 95)) # Shape: (1, 8, 95, 95)
|
450 |
+
|
451 |
+
cc_mask= compute_convective_core_masks(X_train_new)
|
452 |
+
hov_m_train = generate_hovmoller(X_train_new)
|
453 |
+
hov_m_train[np.isnan(hov_m_train)] = 0
|
454 |
+
hov_m_train = hov_m_train.transpose(0, 2, 3, 1)
|
455 |
+
|
456 |
+
cc_mask[np.isnan(cc_mask)] = 0
|
457 |
+
cc_mask=cc_mask.reshape(1, 8, 95, 95, 1)
|
458 |
+
i_images=cc_mask+ir_seq
|
459 |
+
reduced_images = np.concatenate([i_images,pmw_seq ], axis=-1)
|
460 |
+
reduced_images[np.isnan(reduced_images)] = 0
|
461 |
|
462 |
+
if model_choice == "Unet_LSTM":
|
463 |
+
import tensorflow as tf
|
464 |
+
|
465 |
+
def tf_gradient_magnitude(images):
|
466 |
+
# Sobel kernels
|
467 |
+
sobel_x = tf.constant([[1, 0, -1], [2, 0, -2], [1, 0, -1]], dtype=tf.float32)
|
468 |
+
sobel_y = tf.constant([[1, 2, 1], [0, 0, 0], [-1, -2, -1]], dtype=tf.float32)
|
469 |
+
sobel_x = tf.reshape(sobel_x, [3, 3, 1, 1])
|
470 |
+
sobel_y = tf.reshape(sobel_y, [3, 3, 1, 1])
|
471 |
+
|
472 |
+
images = tf.convert_to_tensor(images, dtype=tf.float32)
|
473 |
+
images = tf.expand_dims(images, -1)
|
474 |
+
|
475 |
+
gx = tf.nn.conv2d(images, sobel_x, strides=1, padding='SAME')
|
476 |
+
gy = tf.nn.conv2d(images, sobel_y, strides=1, padding='SAME')
|
477 |
+
grad_mag = tf.sqrt(tf.square(gx) + tf.square(gy) + 1e-6)
|
478 |
+
|
479 |
+
return tf.squeeze(grad_mag, -1).numpy()
|
480 |
+
def GM_maps_prep(ir):
|
481 |
+
GM_maps=[]
|
482 |
+
for i in ir:
|
483 |
+
GM_map = tf_gradient_magnitude(i)
|
484 |
+
GM_maps.append(GM_map)
|
485 |
+
GM_maps=np.array(GM_maps)
|
486 |
+
return GM_maps
|
487 |
+
ir_seq=ir_seq.reshape(8, 95, 95, 1)
|
488 |
+
GM_maps = GM_maps_prep(ir_seq)
|
489 |
+
print(GM_maps.shape)
|
490 |
+
GM_maps=GM_maps.reshape(1, 8, 95, 95, 1)
|
491 |
+
i_images=cc_mask+ir_seq+GM_maps
|
492 |
+
reduced_images = np.concatenate([i_images,pmw_seq ], axis=-1)
|
493 |
+
reduced_images[np.isnan(reduced_images)] = 0
|
494 |
+
print(reduced_images.shape)
|
495 |
+
y = model_predict_fn(reduced_images, hov_m_train, train_vmax_3d, lat_processed, lon_processed, v_max_diff)
|
496 |
+
else:
|
497 |
+
y = model_predict_fn(reduced_images, hov_m_train, train_vmax_3d, lat_processed, lon_processed, v_max_diff)
|
498 |
+
st.write("Predicted Maximum Sustained Wind Speed [Vmax] (in knots):", y)
|
499 |
+
else:
|
500 |
+
st.error("Make sure you uploaded exactly 8 IR and 8 PMW images.")
|
501 |
+
|
502 |
+
# ------------------ Process All Models Button ------------------
|
503 |
+
if all_models_button:
|
504 |
+
if len(ir_images) == 8 and len(pmw_images) == 8:
|
505 |
+
st.info("Running predictions for all models... This may take a moment.")
|
506 |
+
|
507 |
+
# Store all model names and prediction functions
|
508 |
+
all_model_names = [
|
509 |
+
"TCIP-Net (3DCNN)",
|
510 |
+
"TCIP-Net (ST-LSTM)",
|
511 |
+
"TCIP-Net (ConvLSTM)",
|
512 |
+
"TCIP-Net (TrajGRU)",
|
513 |
+
"TCIP-Net (ConvGRU)",
|
514 |
+
"TCUWSP-Net (Proposed)"
|
515 |
+
]
|
516 |
+
|
517 |
+
# Process input data once for all models
|
518 |
+
ir_arrays = []
|
519 |
+
pmw_arrays = []
|
520 |
+
train_vmax_2d = reshape_vmax(np.array(vmax_values))
|
521 |
+
train_vmax_3d = create_3d_vmax(train_vmax_2d)
|
522 |
+
lat_processed = process_lat_values(lat_values)
|
523 |
+
lon_processed = process_lon_values(lon_values)
|
524 |
+
v_max_diff = calculate_intensity_difference(train_vmax_2d)
|
525 |
+
|
526 |
+
for ir in ir_images:
|
527 |
+
img = Image.open(ir).convert("L")
|
528 |
+
arr = np.array(img).astype(np.float32)
|
529 |
+
bt_arr = (arr / 255.0) * (310 - 190) + 190
|
530 |
+
resized = cv2.resize(bt_arr, (95, 95), interpolation=cv2.INTER_CUBIC)
|
531 |
+
ir_arrays.append(resized)
|
532 |
+
|
533 |
+
for pmw in pmw_images:
|
534 |
+
img = Image.open(pmw).convert("L")
|
535 |
+
arr = np.array(img).astype(np.float32) / 255.0
|
536 |
+
resized = cv2.resize(arr, (95, 95), interpolation=cv2.INTER_CUBIC)
|
537 |
+
pmw_arrays.append(resized)
|
538 |
+
|
539 |
+
ir = np.array(ir_arrays)
|
540 |
+
pmw = np.array(pmw_arrays)
|
541 |
+
|
542 |
+
ir_seq = process_images(ir)
|
543 |
+
pmw_seq = process_images(pmw)
|
544 |
+
|
545 |
+
X_train_new = ir_seq.reshape((1, 8, 95, 95))
|
546 |
+
cc_mask = compute_convective_core_masks(X_train_new)
|
547 |
+
hov_m_train = generate_hovmoller(X_train_new)
|
548 |
+
hov_m_train[np.isnan(hov_m_train)] = 0
|
549 |
+
hov_m_train = hov_m_train.transpose(0, 2, 3, 1)
|
550 |
+
|
551 |
+
cc_mask[np.isnan(cc_mask)] = 0
|
552 |
+
cc_mask = cc_mask.reshape(1, 8, 95, 95, 1)
|
553 |
+
i_images = cc_mask + ir_seq
|
554 |
+
reduced_images = np.concatenate([i_images, pmw_seq], axis=-1)
|
555 |
+
reduced_images[np.isnan(reduced_images)] = 0
|
556 |
+
|
557 |
+
# Special processing for Unet_LSTM model if needed
|
558 |
import tensorflow as tf
|
|
|
559 |
def tf_gradient_magnitude(images):
|
560 |
# Sobel kernels
|
561 |
sobel_x = tf.constant([[1, 0, -1], [2, 0, -2], [1, 0, -1]], dtype=tf.float32)
|
|
|
571 |
grad_mag = tf.sqrt(tf.square(gx) + tf.square(gy) + 1e-6)
|
572 |
|
573 |
return tf.squeeze(grad_mag, -1).numpy()
|
574 |
+
|
575 |
def GM_maps_prep(ir):
|
576 |
GM_maps=[]
|
577 |
for i in ir:
|
|
|
579 |
GM_maps.append(GM_map)
|
580 |
GM_maps=np.array(GM_maps)
|
581 |
return GM_maps
|
582 |
+
|
583 |
+
# For Unet_LSTM model
|
584 |
+
ir_seq_reshaped = ir_seq.reshape(8, 95, 95, 1)
|
585 |
+
GM_maps = GM_maps_prep(ir_seq_reshaped)
|
586 |
+
GM_maps = GM_maps.reshape(1, 8, 95, 95, 1)
|
587 |
+
i_images_unet = cc_mask + ir_seq_reshaped + GM_maps
|
588 |
+
reduced_images_unet = np.concatenate([i_images_unet, pmw_seq], axis=-1)
|
589 |
+
reduced_images_unet[np.isnan(reduced_images_unet)] = 0
|
590 |
+
|
591 |
+
# Run predictions for all models
|
592 |
+
predictions = []
|
593 |
+
progress_bar = st.progress(0)
|
594 |
+
|
595 |
+
# Import all prediction functions
|
596 |
+
from cnn3d import predict_3dcnn
|
597 |
+
from spaio_temp import predict_stlstm
|
598 |
+
from convlstm import predict_lstm
|
599 |
+
from trjgru import predict_trajgru
|
600 |
+
from gru_model import predict
|
601 |
+
from unetlstm import predict_unetlstm
|
602 |
+
|
603 |
+
prediction_functions = [
|
604 |
+
predict_3dcnn, # 3DCNN
|
605 |
+
predict_stlstm, # ST-LSTM
|
606 |
+
predict_lstm, # ConvLSTM
|
607 |
+
predict_trajgru, # TrajGRU
|
608 |
+
predict, # ConvGRU
|
609 |
+
predict_unetlstm # TCUWSP-Net
|
610 |
+
]
|
611 |
+
|
612 |
+
# Run predictions
|
613 |
+
for i, predict_fn in enumerate(prediction_functions):
|
614 |
+
progress_bar.progress((i) / len(prediction_functions))
|
615 |
+
|
616 |
+
# Special case for TCUWSP-Net (Proposed Model)
|
617 |
+
if i == 5: # TCUWSP-Net index
|
618 |
+
y = predict_fn(reduced_images_unet, hov_m_train, train_vmax_3d, lat_processed, lon_processed, v_max_diff)
|
619 |
+
else:
|
620 |
+
y = predict_fn(reduced_images, hov_m_train, train_vmax_3d, lat_processed, lon_processed, v_max_diff)
|
621 |
+
|
622 |
+
predictions.append(float(y))
|
623 |
+
|
624 |
+
progress_bar.progress(1.0)
|
625 |
+
|
626 |
+
# Create results DataFrame
|
627 |
+
results_data = {
|
628 |
+
"Model": all_model_names,
|
629 |
+
"RMSE": ablation_data["RMSE"],
|
630 |
+
"MAE": ablation_data["MAE"],
|
631 |
+
"Predicted Vmax (kt)": predictions
|
632 |
+
}
|
633 |
+
|
634 |
+
results_df = pd.DataFrame(results_data)
|
635 |
+
|
636 |
+
# Show DataFrame
|
637 |
+
st.subheader("Prediction Results from All Models")
|
638 |
+
st.dataframe(results_df, use_container_width=True)
|
639 |
+
|
640 |
+
# Generate and display comparison chart
|
641 |
+
st.subheader("Visual Comparison of Models")
|
642 |
+
|
643 |
+
# Prepare data for visualization
|
644 |
+
plot_model_names = [name.replace(" ", "\n") for name in all_model_names]
|
645 |
+
mae_values = results_df["MAE"].tolist()
|
646 |
+
rmse_values = results_df["RMSE"].tolist()
|
647 |
+
predicted_values = results_df["Predicted Vmax (kt)"].tolist()
|
648 |
+
|
649 |
+
# Generate figure
|
650 |
+
fig = generate_comparison_chart(plot_model_names, mae_values, rmse_values, predicted_values)
|
651 |
+
|
652 |
+
# Display figure
|
653 |
+
st.pyplot(fig)
|
654 |
+
|
655 |
+
# Add some interpretation
|
656 |
+
st.subheader("Interpretation")
|
657 |
+
st.write("""
|
658 |
+
- **RMSE and MAE**: Lower values indicate better model performance.
|
659 |
+
- **Percentage Improvements**: Show reduction in error compared to the baseline TCIP-Net (3DCNN) model.
|
660 |
+
- **Predicted Vmax**: The current intensity prediction for the tropical cyclone based on the provided imagery and historical data.
|
661 |
+
""")
|
662 |
+
|
663 |
+
# Highlight best model
|
664 |
+
best_model_idx = rmse_values.index(min(rmse_values))
|
665 |
+
best_model = all_model_names[best_model_idx]
|
666 |
+
best_prediction = predicted_values[best_model_idx]
|
667 |
+
|
668 |
+
st.success(f"🌟 Best performing model: **{best_model}** with RMSE: **{min(rmse_values):.2f} kt** and predicted intensity: **{best_prediction:.2f} kt**")
|
669 |
else:
|
670 |
+
st.error("Make sure you uploaded exactly 8 IR and 8 PMW images.")
|
|
|
|
|
|
requirements.txt
CHANGED
@@ -21,6 +21,7 @@ jsonschema==4.23.0
|
|
21 |
jsonschema-specifications==2024.10.1
|
22 |
keras==3.9.2
|
23 |
libclang==18.1.1
|
|
|
24 |
Markdown==3.8
|
25 |
markdown-it-py==3.0.0
|
26 |
MarkupSafe==3.0.2
|
|
|
21 |
jsonschema-specifications==2024.10.1
|
22 |
keras==3.9.2
|
23 |
libclang==18.1.1
|
24 |
+
matplotlib==3.10.3
|
25 |
Markdown==3.8
|
26 |
markdown-it-py==3.0.0
|
27 |
MarkupSafe==3.0.2
|