kbdharun commited on
Commit
a106fa3
·
verified ·
1 Parent(s): 810dc65

cleanup: update files; feat: add prediction across all models

Browse files

Signed-off-by: K.B.Dharun Krishna <[email protected]>

Files changed (3) hide show
  1. .gitignore +174 -0
  2. app.py +430 -94
  3. requirements.txt +1 -0
.gitignore ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+
110
+ # pdm
111
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
+ #pdm.lock
113
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
+ # in version control.
115
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116
+ .pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121
+ __pypackages__/
122
+
123
+ # Celery stuff
124
+ celerybeat-schedule
125
+ celerybeat.pid
126
+
127
+ # SageMath parsed files
128
+ *.sage.py
129
+
130
+ # Environments
131
+ .env
132
+ .venv
133
+ env/
134
+ venv/
135
+ ENV/
136
+ env.bak/
137
+ venv.bak/
138
+
139
+ # Spyder project settings
140
+ .spyderproject
141
+ .spyproject
142
+
143
+ # Rope project settings
144
+ .ropeproject
145
+
146
+ # mkdocs documentation
147
+ /site
148
+
149
+ # mypy
150
+ .mypy_cache/
151
+ .dmypy.json
152
+ dmypy.json
153
+
154
+ # Pyre type checker
155
+ .pyre/
156
+
157
+ # pytype static type analyzer
158
+ .pytype/
159
+
160
+ # Cython debug symbols
161
+ cython_debug/
162
+
163
+ # PyCharm
164
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
167
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168
+ #.idea/
169
+
170
+ # Ruff stuff:
171
+ .ruff_cache/
172
+
173
+ # PyPI configuration file
174
+ .pypirc
app.py CHANGED
@@ -1,8 +1,14 @@
1
  import streamlit as st
2
  import numpy as np
 
3
  from PIL import Image
4
  import cv2
5
  from scipy.ndimage import gaussian_filter
 
 
 
 
 
6
 
7
  # ------------------ TC CENTERING UTILS ------------------
8
 
@@ -11,6 +17,99 @@ def find_tc_center(ir_image, smoothing_sigma=3):
11
  min_coords = np.unravel_index(np.argmin(smoothed_image), smoothed_image.shape)
12
  return min_coords[::-1] # Return as (x, y)
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  def extract_local_region(ir_image, center, region_size=95):
15
  h, w = ir_image.shape
16
  half_size = region_size // 2
@@ -46,10 +145,9 @@ def create_3d_vmax(vmax_2d_array):
46
  for i in range(vmax_2d_array.shape[0]):
47
  np.fill_diagonal(vmax_3d_array[i], vmax_2d_array[i])
48
 
49
- # Reshape to (N*10, 8, 8, 1) and remove the last element
50
  vmax_3d_array = vmax_3d_array.reshape(-1, 8, 8, 1)
51
- # Trim last element
52
-
53
  return vmax_3d_array
54
 
55
  def process_lat_values(data):
@@ -173,9 +271,44 @@ def compute_convective_core_masks(ir_data):
173
 
174
 
175
  # ------------------ Streamlit UI ------------------
176
- st.set_page_config(page_title="TCIR Daily Input", layout="wide")
 
 
 
 
 
177
 
178
- st.title("Tropical Cyclone U-Net Wind Speed (Intensity) Predictor")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
180
  ir_images = st.file_uploader("Upload 8 IR images", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
181
  pmw_images = st.file_uploader("Upload 8 PMW images", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
@@ -210,7 +343,12 @@ if csv_file is not None:
210
  vmax_values = np.array(vmax_values)
211
 
212
  st.success("CSV file loaded and processed successfully!")
213
- st.write(df.head())
 
 
 
 
 
214
 
215
  else:
216
  st.error("CSV file must contain 'Latitude', 'Longitude', and 'Vmax' columns.")
@@ -218,84 +356,206 @@ if csv_file is not None:
218
  st.error(f"Error reading CSV: {e}")
219
  else:
220
  st.warning("Please upload a CSV file.")
221
- st.header("Select Prediction Model")
222
- model_choice = st.selectbox(
223
- "Choose a model for prediction",
224
- ("ConvGRU", "ConvLSTM", "Traj-GRU","3DCNN","spatiotemporalLSTM","Unet_LSTM"),
225
- index=0
226
- )
227
- # ------------------ Process Button ------------------
228
- if st.button("Submit for Processing"):
229
 
230
- if len(ir_images) == 8 and len(pmw_images) == 8:
231
- # st.success("Starting preprocessing...")
232
- if model_choice == "Unet_LSTM":
233
- from unetlstm import predict_unetlstm
234
- model_predict_fn = predict_unetlstm
235
- elif model_choice == "ConvGRU":
236
- from gru_model import predict
237
- model_predict_fn = predict
238
- elif model_choice == "ConvLSTM":
239
- from convlstm import predict_lstm
240
- model_predict_fn = predict_lstm
241
- elif model_choice == "3DCNN":
242
- from cnn3d import predict_3dcnn
243
- model_predict_fn = predict_3dcnn
244
- elif model_choice == "Traj-GRU":
245
- from trjgru import predict_trajgru
246
- model_predict_fn = predict_trajgru
247
- elif model_choice == "spatiotemporalLSTM":
248
- from spaio_temp import predict_stlstm
249
- model_predict_fn = predict_stlstm
250
-
251
- ir_arrays = []
252
- pmw_arrays = []
253
- train_vmax_2d = reshape_vmax(np.array(vmax_values))
254
-
255
- train_vmax_3d= create_3d_vmax(train_vmax_2d)
256
-
257
- lat_processed = process_lat_values(lat_values)
258
- lon_processed = process_lon_values(lon_values)
259
-
260
- v_max_diff = calculate_intensity_difference(train_vmax_2d)
261
-
262
- for ir in ir_images:
263
- img = Image.open(ir).convert("L")
264
- arr = np.array(img).astype(np.float32)
265
- bt_arr = (arr / 255.0) * (310 - 190) + 190
266
- resized = cv2.resize(bt_arr, (95, 95), interpolation=cv2.INTER_CUBIC)
267
- ir_arrays.append(resized)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
 
269
- for pmw in pmw_images:
270
- img = Image.open(pmw).convert("L")
271
- arr = np.array(img).astype(np.float32) / 255.0
272
- resized = cv2.resize(arr, (95, 95), interpolation=cv2.INTER_CUBIC)
273
- pmw_arrays.append(resized)
274
- ir=np.array(ir_arrays)
275
- pmw=np.array(pmw_arrays)
276
-
277
- # Stack into (8, 95, 95)
278
- ir_seq = process_images(ir)
279
- pmw_seq = process_images(pmw)
280
-
281
-
282
- # For demonstration: create batches
283
- X_train_new = ir_seq.reshape((1, 8, 95, 95)) # Shape: (1, 8, 95, 95)
284
-
285
- cc_mask= compute_convective_core_masks(X_train_new)
286
- hov_m_train = generate_hovmoller(X_train_new)
287
- hov_m_train[np.isnan(hov_m_train)] = 0
288
- hov_m_train = hov_m_train.transpose(0, 2, 3, 1)
289
-
290
- cc_mask[np.isnan(cc_mask)] = 0
291
- cc_mask=cc_mask.reshape(1, 8, 95, 95, 1)
292
- i_images=cc_mask+ir_seq
293
- reduced_images = np.concatenate([i_images,pmw_seq ], axis=-1)
294
- reduced_images[np.isnan(reduced_images)] = 0
295
-
296
- if model_choice == "Unet_LSTM":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297
  import tensorflow as tf
298
-
299
  def tf_gradient_magnitude(images):
300
  # Sobel kernels
301
  sobel_x = tf.constant([[1, 0, -1], [2, 0, -2], [1, 0, -1]], dtype=tf.float32)
@@ -311,6 +571,7 @@ if st.button("Submit for Processing"):
311
  grad_mag = tf.sqrt(tf.square(gx) + tf.square(gy) + 1e-6)
312
 
313
  return tf.squeeze(grad_mag, -1).numpy()
 
314
  def GM_maps_prep(ir):
315
  GM_maps=[]
316
  for i in ir:
@@ -318,17 +579,92 @@ if st.button("Submit for Processing"):
318
  GM_maps.append(GM_map)
319
  GM_maps=np.array(GM_maps)
320
  return GM_maps
321
- ir_seq=ir_seq.reshape(8, 95, 95, 1)
322
- GM_maps = GM_maps_prep(ir_seq)
323
- print(GM_maps.shape)
324
- GM_maps=GM_maps.reshape(1, 8, 95, 95, 1)
325
- i_images=cc_mask+ir_seq+GM_maps
326
- reduced_images = np.concatenate([i_images,pmw_seq ], axis=-1)
327
- reduced_images[np.isnan(reduced_images)] = 0
328
- print(reduced_images.shape)
329
- y = model_predict_fn(reduced_images, hov_m_train, train_vmax_3d, lat_processed, lon_processed, v_max_diff)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330
  else:
331
- y = model_predict_fn(reduced_images, hov_m_train, train_vmax_3d, lat_processed, lon_processed, v_max_diff)
332
- st.write("Predicted Vmax:", y)
333
- else:
334
- st.error("Make sure you uploaded exactly 8 IR and 8 PMW images.")
 
1
  import streamlit as st
2
  import numpy as np
3
+ import pandas as pd
4
  from PIL import Image
5
  import cv2
6
  from scipy.ndimage import gaussian_filter
7
+ import tensorflow as tf
8
+ import matplotlib.pyplot as plt
9
+ import io
10
+ from matplotlib.figure import Figure
11
+ import base64
12
 
13
  # ------------------ TC CENTERING UTILS ------------------
14
 
 
17
  min_coords = np.unravel_index(np.argmin(smoothed_image), smoothed_image.shape)
18
  return min_coords[::-1] # Return as (x, y)
19
 
20
+ # Function to generate comparison chart
21
+ def generate_comparison_chart(models, mae_values, rmse_values, predicted_values=None):
22
+ # Calculate improvement percentages relative to the first model
23
+ baseline_mae = mae_values[0]
24
+ baseline_rmse = rmse_values[0]
25
+
26
+ mae_improvements = [0] + [((baseline_mae - val) / baseline_mae) * 100 for val in mae_values[1:]]
27
+ rmse_improvements = [0] + [((baseline_rmse - val) / baseline_rmse) * 100 for val in rmse_values[1:]]
28
+
29
+ # Create figure with subplots (2 or 3 depending on if we have predictions)
30
+ if predicted_values:
31
+ fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 8))
32
+ else:
33
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))
34
+
35
+ # Plot MAE
36
+ bars1 = ax1.bar(range(len(models)), mae_values, color='skyblue', edgecolor='black')
37
+ ax1.set_title('Mean Absolute Error (MAE)', fontsize=14, fontweight='bold')
38
+ ax1.set_ylabel('MAE (knots)', fontsize=12)
39
+ ax1.set_xticks(range(len(models)))
40
+ ax1.set_xticklabels(models, fontsize=12, rotation=45, ha='right')
41
+ ax1.grid(axis='y', linestyle='--', alpha=0.3, color='lightgray')
42
+ ax1.set_ylim(0, max(mae_values) * 1.2)
43
+
44
+ # Plot RMSE
45
+ bars2 = ax2.bar(range(len(models)), rmse_values, color='lightcoral', edgecolor='black')
46
+ ax2.set_title('Root Mean Square Error (RMSE)', fontsize=14, fontweight='bold')
47
+ ax2.set_ylabel('RMSE (knots)', fontsize=12)
48
+ ax2.set_xticks(range(len(models)))
49
+ ax2.set_xticklabels(models, fontsize=12, rotation=45, ha='right')
50
+ ax2.grid(axis='y', linestyle='--', alpha=0.3, color='lightgray')
51
+ ax2.set_ylim(0, max(rmse_values) * 1.2)
52
+
53
+ # Add values on top of the bars for MAE
54
+ for i, bar in enumerate(bars1):
55
+ height = bar.get_height()
56
+ ax1.text(bar.get_x() + bar.get_width()/2., height + 0.3,
57
+ f'{height:.2f}', ha='center', va='bottom', fontsize=12)
58
+
59
+ # Add improvement percentage for all except the first bar
60
+ if i > 0:
61
+ ax1.text(bar.get_x() + bar.get_width()/2., height/2,
62
+ f'↓{mae_improvements[i]:.1f}%', ha='center', va='center',
63
+ color='blue', fontsize=12, fontweight='bold')
64
+
65
+ # Add values on top of the bars for RMSE
66
+ for i, bar in enumerate(bars2):
67
+ height = bar.get_height()
68
+ ax2.text(bar.get_x() + bar.get_width()/2., height + 0.3,
69
+ f'{height:.2f}', ha='center', va='bottom', fontsize=12)
70
+
71
+ # Add improvement percentage for all except the first bar
72
+ if i > 0:
73
+ ax2.text(bar.get_x() + bar.get_width()/2., height/2,
74
+ f'↓{rmse_improvements[i]:.1f}%', ha='center', va='center',
75
+ color='darkred', fontsize=12, fontweight='bold')
76
+
77
+ # Add horizontal reference lines for best performance
78
+ min_mae = min(mae_values)
79
+ min_rmse = min(rmse_values)
80
+ ax1.axhline(y=min_mae, color='blue', linestyle='--', alpha=0.5)
81
+ ax2.axhline(y=min_rmse, color='red', linestyle='--', alpha=0.5)
82
+
83
+ # Add predictions comparison if provided
84
+ if predicted_values:
85
+ bars3 = ax3.bar(range(len(models)), predicted_values, color='lightgreen', edgecolor='black')
86
+ ax3.set_title('Predicted Vmax', fontsize=14, fontweight='bold')
87
+ ax3.set_ylabel('Wind Speed (knots)', fontsize=12)
88
+ ax3.set_xticks(range(len(models)))
89
+ ax3.set_xticklabels(models, fontsize=12, rotation=45, ha='right')
90
+ ax3.grid(axis='y', linestyle='--', alpha=0.3, color='lightgray')
91
+
92
+ # Add values on top of the bars for predictions
93
+ for i, bar in enumerate(bars3):
94
+ height = bar.get_height()
95
+ ax3.text(bar.get_x() + bar.get_width()/2., height + 0.3,
96
+ f'{height:.2f}', ha='center', va='bottom', fontsize=12)
97
+
98
+ # Add a label at the bottom explaining the reduction percentages
99
+ fig.text(0.5, 0.01, 'Note: Reduction percentages (↓%) are calculated relative to TCIP-Net (3DCNN)',
100
+ ha='center', fontsize=12, fontstyle='italic')
101
+
102
+ plt.tight_layout(rect=[0, 0.03, 1, 0.95])
103
+
104
+ return fig
105
+
106
+ # Function to convert matplotlib figure to Streamlit-compatible image
107
+ def fig_to_streamlit(fig):
108
+ buf = io.BytesIO()
109
+ fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
110
+ buf.seek(0)
111
+ return buf
112
+
113
  def extract_local_region(ir_image, center, region_size=95):
114
  h, w = ir_image.shape
115
  half_size = region_size // 2
 
145
  for i in range(vmax_2d_array.shape[0]):
146
  np.fill_diagonal(vmax_3d_array[i], vmax_2d_array[i])
147
 
148
+ # Reshape to (N*8, 8, 8, 1)
149
  vmax_3d_array = vmax_3d_array.reshape(-1, 8, 8, 1)
150
+ # Trim last element if needed (original comment, but not implemented)
 
151
  return vmax_3d_array
152
 
153
  def process_lat_values(data):
 
271
 
272
 
273
  # ------------------ Streamlit UI ------------------
274
+ # Configure the page with wide layout and custom title
275
+ st.set_page_config(
276
+ page_title="Tropical Cyclone U-Net Wind Speed Predictor",
277
+ layout="wide",
278
+ initial_sidebar_state="expanded"
279
+ )
280
 
281
+ # Main title with emoji and styling
282
+ st.markdown("<h1 style='text-align: center;'>🌀 Tropical Cyclone U-Net Wind Speed (Intensity) Predictor</h1><br>", unsafe_allow_html=True)
283
+
284
+ # Authors section with ORCID links
285
+ st.markdown("""
286
+ <div style='text-align: center;'>
287
+ <p>
288
+ <b>By:</b>
289
+ <a href="https://orcid.org/0009-0006-0342-429X" target="_blank" style="text-decoration: none">Dharun Krishna K B</a>,
290
+ <a href="https://orcid.org/0009-0008-3214-8065" target="_blank" style="text-decoration: none">Nanduri Prudhvi Reddy</a> and
291
+ <a href="https://orcid.org/0009-0006-9052-3623" target="_blank" style="text-decoration: none">Ravipati Venkata Madan Mohan</a>; School of Computing.<br>
292
+ <b>Under the guidance of:</b>
293
+ <a href="https://orcid.org/0000-0003-1969-3559" target="_blank" style="text-decoration: none">Dr. Gowri L</a>,
294
+ Assistant Professor, School of Computing.<br>
295
+ SASTRA Deemed University, Thanjavur, Tamil Nadu, India.<br><br>
296
+ <b>For:</b>
297
+ Main project titled <i>"Tropical Cyclone Intensity Prediction Using Deep Learning Models"</i><br>
298
+ May 2025
299
+ </p>
300
+ </div>
301
+ """, unsafe_allow_html=True)
302
+
303
+ # Add a divider before the main content
304
+ st.markdown('<div class="divider"></div>', unsafe_allow_html=True)
305
+
306
+ # Add spacing
307
+ st.markdown("<br>", unsafe_allow_html=True)
308
+
309
+ # App description
310
+ st.info('''The *Tropical Cyclone Wind Speed Predictor interface* enables the prediction of maximum sustained wind speeds of tropical cyclones (in knots) using IR and PMW imagery, along with physical attributes from the past 24 hours, while also facilitating comparison between state-of-the-art models and our proposed model.
311
+ ''')
312
 
313
  ir_images = st.file_uploader("Upload 8 IR images", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
314
  pmw_images = st.file_uploader("Upload 8 PMW images", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
 
343
  vmax_values = np.array(vmax_values)
344
 
345
  st.success("CSV file loaded and processed successfully!")
346
+
347
+ # Display the dataframe in a scrollable container
348
+ st.markdown("<h4>Preview of uploaded data:</h4>", unsafe_allow_html=True)
349
+ preview_df = df.head(10).reset_index(drop=True)
350
+ preview_df.index += 1 # Shift index to start from 1
351
+ st.dataframe(preview_df, height=200)
352
 
353
  else:
354
  st.error("CSV file must contain 'Latitude', 'Longitude', and 'Vmax' columns.")
 
356
  st.error(f"Error reading CSV: {e}")
357
  else:
358
  st.warning("Please upload a CSV file.")
 
 
 
 
 
 
 
 
359
 
360
+ # Define data for ablation study
361
+ ablation_data = {
362
+ "Model": [
363
+ "TCIP-Net (3DCNN)",
364
+ "TCIP-Net (ST-LSTM)",
365
+ "TCIP-Net (ConvLSTM)",
366
+ "TCIP-Net (TrajGRU)",
367
+ "TCIP-Net (ConvGRU)",
368
+ "TCUWSP-Net (Proposed)"
369
+ ],
370
+ "RMSE": [12.63, 12.52, 12.36, 12.24, 11.17, 8.6549],
371
+ "MAE": [10.15, 10.12, 9.97, 9.93, 8.92, 6.309]
372
+ }
373
+
374
+ # Improved Prediction Model Section with better UI
375
+ st.markdown("<br>", unsafe_allow_html=True)
376
+ st.markdown("<h2 style='text-align: center;'>Select Prediction Model</h2>", unsafe_allow_html=True)
377
+
378
+ # Create columns for better layout
379
+ col1, col2, col3 = st.columns([1, 2, 1])
380
+
381
+ with col2:
382
+ model_choice = st.selectbox(
383
+ "Choose a model for prediction",
384
+ ("TCIP-Net ConvGRU", "TCIP-Net ConvLSTM", "TCIP-Net Traj-GRU", "TCIP-Net 3DCNN", "TCIP-Net Spatio-temporal LSTM", "TCUWSP-Net (Proposed Model)"),
385
+ index=0
386
+ )
387
+
388
+ # Center-aligned, more attractive submit button
389
+ st.markdown("<br>", unsafe_allow_html=True)
390
+ col_btn1, col_btn2 = st.columns(2)
391
+ with col_btn1:
392
+ submit_button = st.button("Predict Intensity", use_container_width=True)
393
+ with col_btn2:
394
+ all_models_button = st.button("Predict Intensity for All Models", use_container_width=True) # ------------------ Process Single Model Button ------------------
395
+ if submit_button:
396
+ if len(ir_images) == 8 and len(pmw_images) == 8:
397
+ # st.success("Starting preprocessing...")
398
+ if model_choice == "TCUWSP-Net (Proposed Model)":
399
+ from unetlstm import predict_unetlstm
400
+ model_predict_fn = predict_unetlstm
401
+ elif model_choice == "TCIP-Net ConvGRU":
402
+ from gru_model import predict
403
+ model_predict_fn = predict
404
+ elif model_choice == "TCIP-Net ConvLSTM":
405
+ from convlstm import predict_lstm
406
+ model_predict_fn = predict_lstm
407
+ elif model_choice == "TCIP-Net 3DCNN":
408
+ from cnn3d import predict_3dcnn
409
+ model_predict_fn = predict_3dcnn
410
+ elif model_choice == "TCIP-Net Traj-GRU":
411
+ from trjgru import predict_trajgru
412
+ model_predict_fn = predict_trajgru
413
+ elif model_choice == "TCIP-Net Spatio-temporal LSTM":
414
+ from spaio_temp import predict_stlstm
415
+ model_predict_fn = predict_stlstm
416
+
417
+ ir_arrays = []
418
+ pmw_arrays = []
419
+ train_vmax_2d = reshape_vmax(np.array(vmax_values))
420
+
421
+ train_vmax_3d= create_3d_vmax(train_vmax_2d)
422
+
423
+ lat_processed = process_lat_values(lat_values)
424
+ lon_processed = process_lon_values(lon_values)
425
+
426
+ v_max_diff = calculate_intensity_difference(train_vmax_2d)
427
+
428
+ for ir in ir_images:
429
+ img = Image.open(ir).convert("L")
430
+ arr = np.array(img).astype(np.float32)
431
+ bt_arr = (arr / 255.0) * (310 - 190) + 190
432
+ resized = cv2.resize(bt_arr, (95, 95), interpolation=cv2.INTER_CUBIC)
433
+ ir_arrays.append(resized)
434
+
435
+ for pmw in pmw_images:
436
+ img = Image.open(pmw).convert("L")
437
+ arr = np.array(img).astype(np.float32) / 255.0
438
+ resized = cv2.resize(arr, (95, 95), interpolation=cv2.INTER_CUBIC)
439
+ pmw_arrays.append(resized)
440
+ ir=np.array(ir_arrays)
441
+ pmw=np.array(pmw_arrays)
442
+
443
+ # Stack into (8, 95, 95)
444
+ ir_seq = process_images(ir)
445
+ pmw_seq = process_images(pmw)
446
+
447
+
448
+ # For demonstration: create batches
449
+ X_train_new = ir_seq.reshape((1, 8, 95, 95)) # Shape: (1, 8, 95, 95)
450
+
451
+ cc_mask= compute_convective_core_masks(X_train_new)
452
+ hov_m_train = generate_hovmoller(X_train_new)
453
+ hov_m_train[np.isnan(hov_m_train)] = 0
454
+ hov_m_train = hov_m_train.transpose(0, 2, 3, 1)
455
+
456
+ cc_mask[np.isnan(cc_mask)] = 0
457
+ cc_mask=cc_mask.reshape(1, 8, 95, 95, 1)
458
+ i_images=cc_mask+ir_seq
459
+ reduced_images = np.concatenate([i_images,pmw_seq ], axis=-1)
460
+ reduced_images[np.isnan(reduced_images)] = 0
461
 
462
+ if model_choice == "Unet_LSTM":
463
+ import tensorflow as tf
464
+
465
+ def tf_gradient_magnitude(images):
466
+ # Sobel kernels
467
+ sobel_x = tf.constant([[1, 0, -1], [2, 0, -2], [1, 0, -1]], dtype=tf.float32)
468
+ sobel_y = tf.constant([[1, 2, 1], [0, 0, 0], [-1, -2, -1]], dtype=tf.float32)
469
+ sobel_x = tf.reshape(sobel_x, [3, 3, 1, 1])
470
+ sobel_y = tf.reshape(sobel_y, [3, 3, 1, 1])
471
+
472
+ images = tf.convert_to_tensor(images, dtype=tf.float32)
473
+ images = tf.expand_dims(images, -1)
474
+
475
+ gx = tf.nn.conv2d(images, sobel_x, strides=1, padding='SAME')
476
+ gy = tf.nn.conv2d(images, sobel_y, strides=1, padding='SAME')
477
+ grad_mag = tf.sqrt(tf.square(gx) + tf.square(gy) + 1e-6)
478
+
479
+ return tf.squeeze(grad_mag, -1).numpy()
480
+ def GM_maps_prep(ir):
481
+ GM_maps=[]
482
+ for i in ir:
483
+ GM_map = tf_gradient_magnitude(i)
484
+ GM_maps.append(GM_map)
485
+ GM_maps=np.array(GM_maps)
486
+ return GM_maps
487
+ ir_seq=ir_seq.reshape(8, 95, 95, 1)
488
+ GM_maps = GM_maps_prep(ir_seq)
489
+ print(GM_maps.shape)
490
+ GM_maps=GM_maps.reshape(1, 8, 95, 95, 1)
491
+ i_images=cc_mask+ir_seq+GM_maps
492
+ reduced_images = np.concatenate([i_images,pmw_seq ], axis=-1)
493
+ reduced_images[np.isnan(reduced_images)] = 0
494
+ print(reduced_images.shape)
495
+ y = model_predict_fn(reduced_images, hov_m_train, train_vmax_3d, lat_processed, lon_processed, v_max_diff)
496
+ else:
497
+ y = model_predict_fn(reduced_images, hov_m_train, train_vmax_3d, lat_processed, lon_processed, v_max_diff)
498
+ st.write("Predicted Maximum Sustained Wind Speed [Vmax] (in knots):", y)
499
+ else:
500
+ st.error("Make sure you uploaded exactly 8 IR and 8 PMW images.")
501
+
502
+ # ------------------ Process All Models Button ------------------
503
+ if all_models_button:
504
+ if len(ir_images) == 8 and len(pmw_images) == 8:
505
+ st.info("Running predictions for all models... This may take a moment.")
506
+
507
+ # Store all model names and prediction functions
508
+ all_model_names = [
509
+ "TCIP-Net (3DCNN)",
510
+ "TCIP-Net (ST-LSTM)",
511
+ "TCIP-Net (ConvLSTM)",
512
+ "TCIP-Net (TrajGRU)",
513
+ "TCIP-Net (ConvGRU)",
514
+ "TCUWSP-Net (Proposed)"
515
+ ]
516
+
517
+ # Process input data once for all models
518
+ ir_arrays = []
519
+ pmw_arrays = []
520
+ train_vmax_2d = reshape_vmax(np.array(vmax_values))
521
+ train_vmax_3d = create_3d_vmax(train_vmax_2d)
522
+ lat_processed = process_lat_values(lat_values)
523
+ lon_processed = process_lon_values(lon_values)
524
+ v_max_diff = calculate_intensity_difference(train_vmax_2d)
525
+
526
+ for ir in ir_images:
527
+ img = Image.open(ir).convert("L")
528
+ arr = np.array(img).astype(np.float32)
529
+ bt_arr = (arr / 255.0) * (310 - 190) + 190
530
+ resized = cv2.resize(bt_arr, (95, 95), interpolation=cv2.INTER_CUBIC)
531
+ ir_arrays.append(resized)
532
+
533
+ for pmw in pmw_images:
534
+ img = Image.open(pmw).convert("L")
535
+ arr = np.array(img).astype(np.float32) / 255.0
536
+ resized = cv2.resize(arr, (95, 95), interpolation=cv2.INTER_CUBIC)
537
+ pmw_arrays.append(resized)
538
+
539
+ ir = np.array(ir_arrays)
540
+ pmw = np.array(pmw_arrays)
541
+
542
+ ir_seq = process_images(ir)
543
+ pmw_seq = process_images(pmw)
544
+
545
+ X_train_new = ir_seq.reshape((1, 8, 95, 95))
546
+ cc_mask = compute_convective_core_masks(X_train_new)
547
+ hov_m_train = generate_hovmoller(X_train_new)
548
+ hov_m_train[np.isnan(hov_m_train)] = 0
549
+ hov_m_train = hov_m_train.transpose(0, 2, 3, 1)
550
+
551
+ cc_mask[np.isnan(cc_mask)] = 0
552
+ cc_mask = cc_mask.reshape(1, 8, 95, 95, 1)
553
+ i_images = cc_mask + ir_seq
554
+ reduced_images = np.concatenate([i_images, pmw_seq], axis=-1)
555
+ reduced_images[np.isnan(reduced_images)] = 0
556
+
557
+ # Special processing for Unet_LSTM model if needed
558
  import tensorflow as tf
 
559
  def tf_gradient_magnitude(images):
560
  # Sobel kernels
561
  sobel_x = tf.constant([[1, 0, -1], [2, 0, -2], [1, 0, -1]], dtype=tf.float32)
 
571
  grad_mag = tf.sqrt(tf.square(gx) + tf.square(gy) + 1e-6)
572
 
573
  return tf.squeeze(grad_mag, -1).numpy()
574
+
575
  def GM_maps_prep(ir):
576
  GM_maps=[]
577
  for i in ir:
 
579
  GM_maps.append(GM_map)
580
  GM_maps=np.array(GM_maps)
581
  return GM_maps
582
+
583
+ # For Unet_LSTM model
584
+ ir_seq_reshaped = ir_seq.reshape(8, 95, 95, 1)
585
+ GM_maps = GM_maps_prep(ir_seq_reshaped)
586
+ GM_maps = GM_maps.reshape(1, 8, 95, 95, 1)
587
+ i_images_unet = cc_mask + ir_seq_reshaped + GM_maps
588
+ reduced_images_unet = np.concatenate([i_images_unet, pmw_seq], axis=-1)
589
+ reduced_images_unet[np.isnan(reduced_images_unet)] = 0
590
+
591
+ # Run predictions for all models
592
+ predictions = []
593
+ progress_bar = st.progress(0)
594
+
595
+ # Import all prediction functions
596
+ from cnn3d import predict_3dcnn
597
+ from spaio_temp import predict_stlstm
598
+ from convlstm import predict_lstm
599
+ from trjgru import predict_trajgru
600
+ from gru_model import predict
601
+ from unetlstm import predict_unetlstm
602
+
603
+ prediction_functions = [
604
+ predict_3dcnn, # 3DCNN
605
+ predict_stlstm, # ST-LSTM
606
+ predict_lstm, # ConvLSTM
607
+ predict_trajgru, # TrajGRU
608
+ predict, # ConvGRU
609
+ predict_unetlstm # TCUWSP-Net
610
+ ]
611
+
612
+ # Run predictions
613
+ for i, predict_fn in enumerate(prediction_functions):
614
+ progress_bar.progress((i) / len(prediction_functions))
615
+
616
+ # Special case for TCUWSP-Net (Proposed Model)
617
+ if i == 5: # TCUWSP-Net index
618
+ y = predict_fn(reduced_images_unet, hov_m_train, train_vmax_3d, lat_processed, lon_processed, v_max_diff)
619
+ else:
620
+ y = predict_fn(reduced_images, hov_m_train, train_vmax_3d, lat_processed, lon_processed, v_max_diff)
621
+
622
+ predictions.append(float(y))
623
+
624
+ progress_bar.progress(1.0)
625
+
626
+ # Create results DataFrame
627
+ results_data = {
628
+ "Model": all_model_names,
629
+ "RMSE": ablation_data["RMSE"],
630
+ "MAE": ablation_data["MAE"],
631
+ "Predicted Vmax (kt)": predictions
632
+ }
633
+
634
+ results_df = pd.DataFrame(results_data)
635
+
636
+ # Show DataFrame
637
+ st.subheader("Prediction Results from All Models")
638
+ st.dataframe(results_df, use_container_width=True)
639
+
640
+ # Generate and display comparison chart
641
+ st.subheader("Visual Comparison of Models")
642
+
643
+ # Prepare data for visualization
644
+ plot_model_names = [name.replace(" ", "\n") for name in all_model_names]
645
+ mae_values = results_df["MAE"].tolist()
646
+ rmse_values = results_df["RMSE"].tolist()
647
+ predicted_values = results_df["Predicted Vmax (kt)"].tolist()
648
+
649
+ # Generate figure
650
+ fig = generate_comparison_chart(plot_model_names, mae_values, rmse_values, predicted_values)
651
+
652
+ # Display figure
653
+ st.pyplot(fig)
654
+
655
+ # Add some interpretation
656
+ st.subheader("Interpretation")
657
+ st.write("""
658
+ - **RMSE and MAE**: Lower values indicate better model performance.
659
+ - **Percentage Improvements**: Show reduction in error compared to the baseline TCIP-Net (3DCNN) model.
660
+ - **Predicted Vmax**: The current intensity prediction for the tropical cyclone based on the provided imagery and historical data.
661
+ """)
662
+
663
+ # Highlight best model
664
+ best_model_idx = rmse_values.index(min(rmse_values))
665
+ best_model = all_model_names[best_model_idx]
666
+ best_prediction = predicted_values[best_model_idx]
667
+
668
+ st.success(f"🌟 Best performing model: **{best_model}** with RMSE: **{min(rmse_values):.2f} kt** and predicted intensity: **{best_prediction:.2f} kt**")
669
  else:
670
+ st.error("Make sure you uploaded exactly 8 IR and 8 PMW images.")
 
 
 
requirements.txt CHANGED
@@ -21,6 +21,7 @@ jsonschema==4.23.0
21
  jsonschema-specifications==2024.10.1
22
  keras==3.9.2
23
  libclang==18.1.1
 
24
  Markdown==3.8
25
  markdown-it-py==3.0.0
26
  MarkupSafe==3.0.2
 
21
  jsonschema-specifications==2024.10.1
22
  keras==3.9.2
23
  libclang==18.1.1
24
+ matplotlib==3.10.3
25
  Markdown==3.8
26
  markdown-it-py==3.0.0
27
  MarkupSafe==3.0.2