nanduriprudhvi committed on
Commit
d9fdf57
·
verified ·
1 Parent(s): c5de1d5

Update spaio_temp.py

Browse files
Files changed (1) hide show
  1. spaio_temp.py +326 -326
spaio_temp.py CHANGED
@@ -1,327 +1,327 @@
1
- import tensorflow as tf
2
- from tensorflow.keras import layers, models # type: ignore
3
- import numpy as np
4
-
5
-
6
- class SpatiotemporalLSTMCell(layers.Layer):
7
- """
8
- SpatiotemporalLSTMCell: A custom LSTM cell that captures both spatial and temporal dependencies.
9
- It extends the traditional LSTM by adding a memory state (m_t) that focuses on spatial correlations.
10
- """
11
- def __init__(self, filters, kernel_size, **kwargs):
12
- super().__init__(**kwargs)
13
- self.filters = filters # Number of output filters in the convolution
14
- self.kernel_size = kernel_size # Size of the convolutional kernel
15
-
16
- # Convolutional components for standard LSTM operations
17
- self.conv_xg = layers.Conv2D(filters, kernel_size, padding="same", activation="tanh") # For cell input
18
- self.conv_xi = layers.Conv2D(filters, kernel_size, padding="same", activation="sigmoid") # For input gate
19
- self.conv_xf = layers.Conv2D(filters, kernel_size, padding="same", activation="sigmoid") # For forget gate
20
- self.conv_xo = layers.Conv2D(filters, kernel_size, padding="same", activation="sigmoid") # For output gate
21
-
22
- # Convolutional components for spatiotemporal memory operations
23
- self.conv_xg_st = layers.Conv2D(filters, kernel_size, padding="same", activation="tanh") # For ST cell input
24
- self.conv_xi_st = layers.Conv2D(filters, kernel_size, padding="same", activation="sigmoid") # For ST input gate
25
- self.conv_xf_st = layers.Conv2D(filters, kernel_size, padding="same", activation="sigmoid") # For ST forget gate
26
-
27
- # Fusion layer to combine the cell state and spatiotemporal memory
28
- self.conv_fusion = layers.Conv2D(filters, (1, 1), padding="same") # 1x1 conv for dimensionality reduction
29
-
30
- def call(self, inputs, states):
31
- """
32
- Forward pass of the spatiotemporal LSTM cell.
33
-
34
- Args:
35
- inputs: Input tensor of shape [batch_size, height, width, channels]
36
- states: List of previous states [h_t-1, c_t-1, m_t-1]
37
- h_t-1: previous hidden state
38
- c_t-1: previous cell state
39
- m_t-1: previous spatiotemporal memory
40
- """
41
- prev_h, prev_c, prev_m = states
42
-
43
- # Standard LSTM operations
44
- g_t = self.conv_xg(inputs) + self.conv_xg(prev_h) # Cell input activation
45
- i_t = self.conv_xi(inputs) + self.conv_xi(prev_h) # Input gate
46
- f_t = self.conv_xf(inputs) + self.conv_xf(prev_h) # Forget gate
47
- o_t = self.conv_xo(inputs) + self.conv_xo(prev_h) # Output gate
48
-
49
- # Cell state update - bug detected: should use prev_c instead of self.conv_xo(prev_h)
50
- c_t = tf.sigmoid(f_t) * self.conv_xo(prev_h) + tf.sigmoid(i_t) * tf.tanh(g_t)
51
-
52
- # Spatiotemporal memory operations
53
- g_t_st = self.conv_xg_st(inputs) + self.conv_xg_st(prev_m) # ST cell input
54
- i_t_st = self.conv_xi_st(inputs) + self.conv_xi_st(prev_m) # ST input gate
55
- f_t_st = self.conv_xf_st(inputs) + self.conv_xf_st(prev_m) # ST forget gate
56
-
57
- # Spatiotemporal memory update - bug detected: should use prev_m directly instead of self.conv_xf_st(prev_m)
58
- m_t = tf.sigmoid(f_t_st) * self.conv_xf_st(prev_m) + tf.sigmoid(i_t_st) * tf.tanh(g_t_st)
59
-
60
- # Hidden state update by fusing cell state and spatiotemporal memory
61
- h_t = tf.sigmoid(o_t) * tf.tanh(self.conv_fusion(tf.concat([c_t, m_t], axis=-1)))
62
-
63
- return h_t, [h_t, c_t, m_t] # Return the hidden state and all updated states
64
-
65
- class SpatiotemporalLSTM(layers.Layer):
66
- """
67
- SpatiotemporalLSTM: Custom layer that applies the SpatiotemporalLSTMCell to a sequence of inputs.
68
- This processes 3D data with spatial and temporal dimensions.
69
- """
70
- def __init__(self, filters, kernel_size, **kwargs):
71
- super().__init__(**kwargs)
72
- self.cell = SpatiotemporalLSTMCell(filters, kernel_size)
73
-
74
- def call(self, inputs):
75
- """
76
- Forward pass of the SpatiotemporalLSTM layer.
77
-
78
- Args:
79
- inputs: Input tensor of shape [batch_size, time_steps, height, width, channels]
80
- """
81
- batch_size = tf.shape(inputs)[0]
82
- time_steps = inputs.shape[1]
83
- height = inputs.shape[2]
84
- width = inputs.shape[3]
85
- channels = inputs.shape[4]
86
-
87
- # Initialize states with zeros
88
- h_t = tf.zeros((batch_size, height, width, channels)) # Hidden state
89
- c_t = tf.zeros((batch_size, height, width, channels)) # Cell state
90
- m_t = tf.zeros((batch_size, height, width, channels)) # Spatiotemporal memory
91
-
92
- outputs = []
93
- # Process sequence step by step
94
- for t in range(time_steps):
95
- # Apply the cell to the current time step and previous states
96
- h_t, [h_t, c_t, m_t] = self.cell(inputs[:, t], [h_t[:,:,:,:inputs.shape[4]],
97
- c_t[:,:,:,:inputs.shape[4]],
98
- m_t[:,:,:,:inputs.shape[4]]])
99
- outputs.append(h_t)
100
-
101
- # Stack outputs along time dimension
102
- return tf.stack(outputs, axis=1)
103
-
104
- def build_st_lstm_model(input_shape=(8, 95, 95, 2)):
105
- """
106
- Build a complete spatiotemporal LSTM model for sequence processing of spatial data.
107
-
108
- Args:
109
- input_shape: Tuple of (time_steps, height, width, channels)
110
-
111
- Returns:
112
- A Keras model with spatiotemporal LSTM layers
113
- """
114
- # Create input layer with fixed batch size
115
- input_tensor = layers.Input(shape=input_shape, batch_size=16)
116
-
117
- # First spatiotemporal LSTM block
118
- st_lstm_layer = SpatiotemporalLSTM(filters=32, kernel_size=(3, 3))
119
- x = st_lstm_layer(input_tensor)
120
- x = layers.Conv3D(filters=32, kernel_size=(3, 3, 3), padding='same', activation='relu')(x)
121
- x = layers.MaxPooling3D(pool_size=(2, 2, 2), strides=(2, 2, 2), padding='same')(x)
122
-
123
- # Second spatiotemporal LSTM block
124
- st_lstm_layer = SpatiotemporalLSTM(filters=64, kernel_size=(3, 3))
125
- x = st_lstm_layer(x)
126
- x = layers.Conv3D(filters=64, kernel_size=(3, 3, 3), padding='same', activation='relu')(x)
127
- x = layers.MaxPooling3D(pool_size=(2, 2, 2), strides=(2, 2, 2), padding='same')(x)
128
-
129
- # Third spatiotemporal LSTM block
130
- st_lstm_layer = SpatiotemporalLSTM(filters=128, kernel_size=(3, 3))
131
- x = st_lstm_layer(x)
132
- x = layers.Conv3D(filters=128, kernel_size=(3, 3, 3), padding='same', activation='relu')(x)
133
- x = layers.MaxPooling3D(pool_size=(2, 2, 2), strides=(2, 2, 2), padding='same')(x)
134
-
135
- # Flatten and prepare for output layers (not included in this model)
136
- x = layers.Flatten()(x)
137
-
138
- # Create and return the model
139
- model = models.Model(inputs=input_tensor, outputs=x)
140
- return model
141
-
142
- def radial_structure_subnet(input_shape):
143
- """
144
- Creates the subnet for extracting TC radial structure features using a five-branch CNN design with 2D convolutions.
145
-
146
- Parameters:
147
- - input_shape: tuple, shape of the input data (e.g., (95, 95, 3))
148
-
149
- Returns:
150
- - model: tf.keras.Model, the radial structure subnet model
151
- """
152
-
153
- input_tensor = layers.Input(shape=input_shape)
154
-
155
- # Divide input data into four quadrants (NW, NE, SW, SE)
156
- # Assuming the input shape is (batch_size, height, width, channels)
157
-
158
- # Quadrant extraction - using slicing to separate quadrants
159
- nw_quadrant = input_tensor[:, :input_shape[0]//2, :input_shape[1]//2, :]
160
- ne_quadrant = input_tensor[:, :input_shape[0]//2, input_shape[1]//2:, :]
161
- sw_quadrant = input_tensor[:, input_shape[0]//2:, :input_shape[1]//2, :]
162
- se_quadrant = input_tensor[:, input_shape[0]//2:, input_shape[1]//2:, :]
163
-
164
-
165
- target_height = max(input_shape[0]//2, input_shape[0] - input_shape[0]//2) # 48
166
- target_width = max(input_shape[1]//2, input_shape[1] - input_shape[1]//2) # 48
167
-
168
- # Padding the quadrants to match the target size (48, 48)
169
- nw_quadrant = layers.ZeroPadding2D(padding=((0, target_height - nw_quadrant.shape[1]),
170
- (0, target_width - nw_quadrant.shape[2])))(nw_quadrant)
171
- ne_quadrant = layers.ZeroPadding2D(padding=((0, target_height - ne_quadrant.shape[1]),
172
- (0, target_width - ne_quadrant.shape[2])))(ne_quadrant)
173
- sw_quadrant = layers.ZeroPadding2D(padding=((0, target_height - sw_quadrant.shape[1]),
174
- (0, target_width - sw_quadrant.shape[2])))(sw_quadrant)
175
- se_quadrant = layers.ZeroPadding2D(padding=((0, target_height - se_quadrant.shape[1]),
176
- (0, target_width - se_quadrant.shape[2])))(se_quadrant)
177
-
178
- print(nw_quadrant.shape)
179
- print(ne_quadrant.shape)
180
- print(sw_quadrant.shape)
181
- print(se_quadrant.shape)
182
- # Main branch (processing the entire structure)
183
- main_branch = layers.Conv2D(filters=8, kernel_size=(3, 3), padding='same', activation='relu')(input_tensor)
184
- y=layers.MaxPool2D()(main_branch)
185
-
186
- y = layers.ZeroPadding2D(padding=((0, target_height - y.shape[1]),
187
- (0, target_width - y.shape[2])))(y)
188
- # Side branches (processing the individual quadrants)
189
- nw_branch = layers.Conv2D(filters=8, kernel_size=(3, 3), padding='same', activation='relu')(nw_quadrant)
190
- ne_branch = layers.Conv2D(filters=8, kernel_size=(3, 3), padding='same', activation='relu')(ne_quadrant)
191
- sw_branch = layers.Conv2D(filters=8, kernel_size=(3, 3), padding='same', activation='relu')(sw_quadrant)
192
- se_branch = layers.Conv2D(filters=8, kernel_size=(3, 3), padding='same', activation='relu')(se_quadrant)
193
-
194
- # Apply padding to the side branches to match the dimensions of the main branch
195
- # nw_branch = layers.UpSampling2D(size=(2, 2), interpolation='nearest')(nw_branch)
196
- # ne_branch = layers.UpSampling2D(size=(2, 2), interpolation='nearest')(ne_branch)
197
- # sw_branch = layers.UpSampling2D(size=(2, 2), interpolation='nearest')(sw_branch)
198
- # se_branch = layers.UpSampling2D(size=(2, 2), interpolation='nearest')(se_branch)
199
-
200
- # Fusion operations (concatenate the outputs from the main branch and side branches)
201
- fusion = layers.concatenate([y, nw_branch, ne_branch, sw_branch, se_branch], axis=-1)
202
-
203
- # Additional convolution layer to combine the fused features
204
- x = layers.Conv2D(filters=16, kernel_size=(3, 3), padding='same', activation='relu')(fusion)
205
- x=layers.MaxPool2D(pool_size=(2, 2))(x)
206
- # Final dense layer for further processing
207
- nw_branch = layers.Conv2D(filters=16, kernel_size=(3, 3), padding='same', activation='relu')(nw_branch)
208
-
209
- ne_branch = layers.Conv2D(filters=16, kernel_size=(3, 3), padding='same', activation='relu')(ne_branch)
210
- sw_branch = layers.Conv2D(filters=16, kernel_size=(3, 3), padding='same', activation='relu')(sw_branch)
211
- se_branch = layers.Conv2D(filters=16, kernel_size=(3, 3), padding='same', activation='relu')(se_branch)
212
- nw_branch = layers.MaxPool2D(pool_size=(2, 2))(nw_branch)
213
- ne_branch = layers.MaxPool2D(pool_size=(2, 2))(ne_branch)
214
- sw_branch = layers.MaxPool2D(pool_size=(2, 2))(sw_branch)
215
- se_branch = layers.MaxPool2D(pool_size=(2, 2))(se_branch)
216
-
217
- fusion = layers.concatenate([x, nw_branch, ne_branch, sw_branch, se_branch], axis=-1)
218
- x = layers.Conv2D(filters=32, kernel_size=(3, 3), padding='same', activation='relu')(fusion)
219
- x=layers.MaxPool2D(pool_size=(2, 2))(x)
220
-
221
- nw_branch = layers.Conv2D(filters=32, kernel_size=(3, 3), padding='same', activation='relu')(nw_branch)
222
-
223
- ne_branch = layers.Conv2D(filters=32, kernel_size=(3, 3), padding='same', activation='relu')(ne_branch)
224
- sw_branch = layers.Conv2D(filters=32, kernel_size=(3, 3), padding='same', activation='relu')(sw_branch)
225
- se_branch = layers.Conv2D(filters=32, kernel_size=(3, 3), padding='same', activation='relu')(se_branch)
226
- nw_branch = layers.MaxPool2D(pool_size=(2, 2))(nw_branch)
227
- ne_branch = layers.MaxPool2D(pool_size=(2, 2))(ne_branch)
228
- sw_branch = layers.MaxPool2D(pool_size=(2, 2))(sw_branch)
229
- se_branch = layers.MaxPool2D(pool_size=(2, 2))(se_branch)
230
-
231
- fusion = layers.concatenate([x, nw_branch, ne_branch, sw_branch, se_branch], axis=-1)
232
- x = layers.Conv2D(filters=32, kernel_size=(3, 3), activation='relu')(fusion)
233
- x=layers.Conv2D(filters=32, kernel_size=(3, 3), activation=None)(x)
234
- # Create and return the model
235
- x=layers.Flatten()(x)
236
- model = models.Model(inputs=input_tensor, outputs=x)
237
- return model
238
-
239
- # Define input shape (batch_size, height, width, channels)
240
- # input_shape = (95, 95, 8) # Example input shape (95x95 spatial resolution, 3 channels)
241
-
242
- # # Build the model
243
- # model = radial_structure_subnet(input_shape)
244
-
245
- # # Model summary
246
- # model.summary()
247
-
248
- def build_cnn_model(input_shape=(8, 8, 1)):
249
- # Define the input layer
250
- input_tensor = layers.Input(shape=input_shape)
251
-
252
- # Convolutional layer
253
- x = layers.Conv2D(64, (3, 3), padding='same')(input_tensor)
254
- x = layers.BatchNormalization()(x)
255
- x = layers.ReLU()(x)
256
-
257
- # Flatten layer
258
- x = layers.Flatten()(x)
259
-
260
- # Create the model
261
- model = models.Model(inputs=input_tensor, outputs=x)
262
-
263
- return model
264
-
265
- from tensorflow.keras import layers, models, Input # type: ignore
266
-
267
- def build_combined_model():
268
- # Define input shapes
269
- input_shape_3d = (8, 95, 95, 2)
270
- input_shape_radial = (95, 95, 8)
271
- input_shape_cnn = (8, 8, 1)
272
-
273
- input_shape_latitude = (8,)
274
- input_shape_longitude = (8,)
275
- input_shape_other = (9,)
276
-
277
- # Build individual models
278
- model_3d = build_st_lstm_model(input_shape=input_shape_3d)
279
- model_radial = radial_structure_subnet(input_shape=input_shape_radial)
280
- model_cnn = build_cnn_model(input_shape=input_shape_cnn)
281
-
282
- # Define new inputs
283
- input_latitude = Input(shape=input_shape_latitude ,name="latitude_input")
284
- input_longitude = Input(shape=input_shape_longitude, name="longitude_input")
285
- input_other = Input(shape=input_shape_other, name="other_input")
286
-
287
- # Flatten the additional inputs
288
- flat_latitude = layers.Dense(32,activation='relu')(input_latitude)
289
- flat_longitude = layers.Dense(32,activation='relu')(input_longitude)
290
- flat_other = layers.Dense(64,activation='relu')(input_other)
291
-
292
- # Combine all outputs
293
- combined = layers.concatenate([
294
- model_3d.output,
295
- model_radial.output,
296
- model_cnn.output,
297
- flat_latitude,
298
- flat_longitude,
299
- flat_other
300
- ])
301
-
302
- # Add dense layers for final processing
303
- x = layers.Dense(128, activation='relu')(combined)
304
- x = layers.Dense(1, activation=None)(x)
305
-
306
- # Create the final model
307
- final_model = models.Model(
308
- inputs=[model_3d.input, model_radial.input, model_cnn.input,
309
- input_latitude, input_longitude, input_other ],
310
- outputs=x
311
- )
312
-
313
- return final_model
314
-
315
- import h5py
316
- with h5py.File(r"E:\1MAIN PROJECT\tf_env\spatio_tempral_LSTM.h5", 'r') as f:
317
- print(f.attrs.get('keras_version'))
318
- print(f.attrs.get('backend'))
319
- print("Model layers:", list(f['model_weights'].keys()))
320
-
321
- model = build_combined_model() # Your original model building function
322
- model.load_weights(r"E:\1MAIN PROJECT\tf_env\spatio_tempral_LSTM.h5")
323
-
324
-
325
- def predict_stlstm(reduced_images_test,hov_m_test,test_vmax_3d,lat_test,lon_test,int_diff_test):
326
- y=model.predict([reduced_images_test,hov_m_test,test_vmax_3d,lat_test,lon_test,int_diff_test ])
327
  return y
 
1
+ import tensorflow as tf
2
+ from tensorflow.keras import layers, models # type: ignore
3
+ import numpy as np
4
+
5
+
6
class SpatiotemporalLSTMCell(layers.Layer):
    """Convolutional LSTM-style cell with an extra spatiotemporal memory state.

    The cell owns eight ``Conv2D`` layers: four for the standard gates
    (g, i, f, o), three for the spatiotemporal memory gates (g, i, f) and a
    1x1 fusion convolution that merges the cell state and the memory.

    Each gate convolution is applied both to the step input and to a previous
    state, so ``inputs`` and the tensors in ``states`` are expected to have
    the same number of channels (the enclosing ``SpatiotemporalLSTM`` slices
    the states to the input channel count before calling the cell).

    NOTE(review): the forward pass deviates from a textbook LSTM in three ways:
      * ``prev_c`` is never read; the forget gate scales ``conv_xo(prev_h)``.
      * the memory forget gate scales ``conv_xf_st(prev_m)``, not ``prev_m``.
      * the gate convolutions already apply tanh/sigmoid, and the update
        formulas apply ``tf.sigmoid`` / ``tf.tanh`` a second time.
    This module loads trained weights into a model built from this cell, so
    the arithmetic is deliberately left unchanged here; changing it would
    alter the predictions obtained from those weights.
    """

    def __init__(self, filters, kernel_size, **kwargs):
        """Create the cell.

        Args:
            filters: Number of output filters of every convolution.
            kernel_size: Kernel size of the gate convolutions.
            **kwargs: Forwarded to ``layers.Layer``.
        """
        super().__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size

        # Gate convolutions for the standard LSTM path.
        self.conv_xg = layers.Conv2D(filters, kernel_size, padding="same", activation="tanh")
        self.conv_xi = layers.Conv2D(filters, kernel_size, padding="same", activation="sigmoid")
        self.conv_xf = layers.Conv2D(filters, kernel_size, padding="same", activation="sigmoid")
        self.conv_xo = layers.Conv2D(filters, kernel_size, padding="same", activation="sigmoid")

        # Gate convolutions for the spatiotemporal memory path.
        self.conv_xg_st = layers.Conv2D(filters, kernel_size, padding="same", activation="tanh")
        self.conv_xi_st = layers.Conv2D(filters, kernel_size, padding="same", activation="sigmoid")
        self.conv_xf_st = layers.Conv2D(filters, kernel_size, padding="same", activation="sigmoid")

        # 1x1 convolution that fuses [c_t, m_t] back down to `filters` channels.
        self.conv_fusion = layers.Conv2D(filters, (1, 1), padding="same")

    def call(self, inputs, states):
        """Run one time step.

        Args:
            inputs: Step input; the gate convolutions are 2D, so a
                [batch, height, width, channels] tensor is expected.
            states: Sequence ``[h_prev, c_prev, m_prev]`` (previous hidden
                state, cell state and spatiotemporal memory). ``c_prev`` is
                accepted for interface compatibility but is not used.

        Returns:
            Tuple ``(h_t, [h_t, c_t, m_t])``.
        """
        prev_h, _, prev_m = states

        # Standard gates: the same convolution sees the input and prev_h.
        g_t = self.conv_xg(inputs) + self.conv_xg(prev_h)
        i_t = self.conv_xi(inputs) + self.conv_xi(prev_h)
        f_t = self.conv_xf(inputs) + self.conv_xf(prev_h)
        o_from_x = self.conv_xo(inputs)
        o_from_h = self.conv_xo(prev_h)  # reused in the cell-state update below
        o_t = o_from_x + o_from_h

        # Cell state update (see class NOTE: uses conv_xo(prev_h), not prev_c).
        c_t = tf.sigmoid(f_t) * o_from_h + tf.sigmoid(i_t) * tf.tanh(g_t)

        # Spatiotemporal memory gates, driven by the input and prev_m.
        g_t_st = self.conv_xg_st(inputs) + self.conv_xg_st(prev_m)
        i_t_st = self.conv_xi_st(inputs) + self.conv_xi_st(prev_m)
        f_st_from_m = self.conv_xf_st(prev_m)  # reused in the memory update below
        f_t_st = self.conv_xf_st(inputs) + f_st_from_m

        # Memory update (see class NOTE: uses conv_xf_st(prev_m), not prev_m).
        m_t = tf.sigmoid(f_t_st) * f_st_from_m + tf.sigmoid(i_t_st) * tf.tanh(g_t_st)

        # Hidden state: output gate applied to the fused cell state and memory.
        fused = self.conv_fusion(tf.concat([c_t, m_t], axis=-1))
        h_t = tf.sigmoid(o_t) * tf.tanh(fused)

        return h_t, [h_t, c_t, m_t]

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super().get_config()
        config.update({"filters": self.filters, "kernel_size": self.kernel_size})
        return config
64
+
65
class SpatiotemporalLSTM(layers.Layer):
    """Unroll a ``SpatiotemporalLSTMCell`` over the time axis of a 5D tensor.

    The three recurrent states start as zeros with the same channel count as
    the input. After every step the states returned by the cell are sliced
    back to the input channel count before being fed to the next step.
    """

    def __init__(self, filters, kernel_size, **kwargs):
        """Create the layer.

        Args:
            filters: Number of filters passed to the wrapped cell.
            kernel_size: Kernel size passed to the wrapped cell.
            **kwargs: Forwarded to ``layers.Layer``.
        """
        super().__init__(**kwargs)
        self.cell = SpatiotemporalLSTMCell(filters, kernel_size)

    def call(self, inputs):
        """Apply the cell to every time step.

        Args:
            inputs: Tensor of shape [batch, time_steps, height, width, channels].
                ``time_steps`` must be statically known because the sequence
                is unrolled with a Python loop.

        Returns:
            The per-step hidden states stacked along axis 1.

        Raises:
            ValueError: If the time dimension of ``inputs`` is not static.
        """
        time_steps = inputs.shape[1]
        if time_steps is None:
            raise ValueError(
                "SpatiotemporalLSTM requires a statically known time dimension; "
                "got inputs with shape {}".format(inputs.shape)
            )
        channels = inputs.shape[4]

        # The batch size may be dynamic, so it is read from the runtime shape.
        batch_size = tf.shape(inputs)[0]
        state_shape = (batch_size, inputs.shape[2], inputs.shape[3], channels)

        h_t = tf.zeros(state_shape)  # hidden state
        c_t = tf.zeros(state_shape)  # cell state
        m_t = tf.zeros(state_shape)  # spatiotemporal memory

        outputs = []
        for t in range(time_steps):
            # Keep only the first `channels` channels of each state so the
            # cell sees states with the same channel count as its input.
            step_states = [
                h_t[:, :, :, :channels],
                c_t[:, :, :, :channels],
                m_t[:, :, :, :channels],
            ]
            _, (h_t, c_t, m_t) = self.cell(inputs[:, t], step_states)
            outputs.append(h_t)

        return tf.stack(outputs, axis=1)

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super().get_config()
        config.update({
            "filters": self.cell.filters,
            "kernel_size": self.cell.kernel_size,
        })
        return config
103
+
104
def build_st_lstm_model(input_shape=(8, 95, 95, 2), batch_size=16):
    """Build the spatiotemporal LSTM feature extractor.

    The network is three identical stages with 32, 64 and 128 filters. Each
    stage is a ``SpatiotemporalLSTM``, a 3x3x3 ``Conv3D`` with ReLU and a
    2x2x2 max pooling with 'same' padding. The result is flattened; no output
    head is attached.

    Args:
        input_shape: Tuple (time_steps, height, width, channels).
        batch_size: Static batch size declared on the input layer. Defaults
            to 16, the value that was previously hard-coded. Pass ``None``
            to leave the batch dimension unspecified.

    Returns:
        A Keras ``Model`` mapping the 5D input to a flat feature vector.
    """
    input_tensor = layers.Input(shape=input_shape, batch_size=batch_size)

    x = input_tensor
    # Layers are created in the same order as the original hand-unrolled
    # version so that auto-generated layer names and weight order stay stable.
    for filters in (32, 64, 128):
        st_lstm_layer = SpatiotemporalLSTM(filters=filters, kernel_size=(3, 3))
        x = st_lstm_layer(x)
        x = layers.Conv3D(filters=filters, kernel_size=(3, 3, 3), padding='same', activation='relu')(x)
        x = layers.MaxPooling3D(pool_size=(2, 2, 2), strides=(2, 2, 2), padding='same')(x)

    x = layers.Flatten()(x)

    return models.Model(inputs=input_tensor, outputs=x)
141
+
142
def radial_structure_subnet(input_shape):
    """Build the five-branch 2D CNN (whole field plus four quadrants).

    The input is split into NW, NE, SW and SE quadrants, each zero-padded on
    the bottom/right to a common size. A main branch processes the whole
    input and is max-pooled and padded to that same size. The main path and
    the four quadrant branches are then fused by channel concatenation three
    times, with convolutions and pooling in between. The output is flattened.

    Args:
        input_shape: Tuple (height, width, channels) of the input, without
            the batch dimension, e.g. (95, 95, 8).

    Returns:
        A Keras ``Model`` mapping the input to a flat feature vector.
    """
    input_tensor = layers.Input(shape=input_shape)

    half_h = input_shape[0] // 2
    half_w = input_shape[1] // 2

    # Common spatial size of all branches: the larger half along each axis
    # (for an odd dimension the second half is one pixel larger).
    target_height = max(half_h, input_shape[0] - half_h)
    target_width = max(half_w, input_shape[1] - half_w)

    def pad_to_target(tensor):
        """Zero-pad bottom/right so the tensor is target_height x target_width."""
        return layers.ZeroPadding2D(
            padding=((0, target_height - tensor.shape[1]),
                     (0, target_width - tensor.shape[2]))
        )(tensor)

    def conv_relu(filters, tensor):
        """3x3 'same' convolution followed by ReLU."""
        return layers.Conv2D(filters=filters, kernel_size=(3, 3), padding='same', activation='relu')(tensor)

    # Quadrants in the fixed order NW, NE, SW, SE. That order, and the order in
    # which layers are created below, is kept from the original implementation
    # so weight loading by layer order is unaffected.
    quadrants = [
        input_tensor[:, :half_h, :half_w, :],
        input_tensor[:, :half_h, half_w:, :],
        input_tensor[:, half_h:, :half_w, :],
        input_tensor[:, half_h:, half_w:, :],
    ]
    quadrants = [pad_to_target(quadrant) for quadrant in quadrants]

    # Main branch: whole field, pooled to roughly half size, then padded so it
    # matches the quadrant size.
    main_branch = conv_relu(8, input_tensor)
    y = layers.MaxPool2D()(main_branch)
    y = pad_to_target(y)

    # Stage 1: 8-filter convolutions on each quadrant, then fuse with main.
    branches = [conv_relu(8, quadrant) for quadrant in quadrants]
    fusion = layers.concatenate([y] + branches, axis=-1)
    x = conv_relu(16, fusion)
    x = layers.MaxPool2D(pool_size=(2, 2))(x)

    # Stage 2: 16-filter convolutions + pooling on each quadrant branch.
    branches = [conv_relu(16, branch) for branch in branches]
    branches = [layers.MaxPool2D(pool_size=(2, 2))(branch) for branch in branches]
    fusion = layers.concatenate([x] + branches, axis=-1)
    x = conv_relu(32, fusion)
    x = layers.MaxPool2D(pool_size=(2, 2))(x)

    # Stage 3: 32-filter convolutions + pooling on each quadrant branch.
    branches = [conv_relu(32, branch) for branch in branches]
    branches = [layers.MaxPool2D(pool_size=(2, 2))(branch) for branch in branches]
    fusion = layers.concatenate([x] + branches, axis=-1)

    # Head: two unpadded 3x3 convolutions (the second one linear), then flatten.
    x = layers.Conv2D(filters=32, kernel_size=(3, 3), activation='relu')(fusion)
    x = layers.Conv2D(filters=32, kernel_size=(3, 3), activation=None)(x)
    x = layers.Flatten()(x)

    return models.Model(inputs=input_tensor, outputs=x)
238
+
239
+ # Define input shape (height, width, channels); the batch dimension is not part of the tuple
240
+ # input_shape = (95, 95, 8) # Example input shape (95x95 spatial resolution, 8 channels)
241
+
242
+ # # Build the model
243
+ # model = radial_structure_subnet(input_shape)
244
+
245
+ # # Model summary
246
+ # model.summary()
247
+
248
def build_cnn_model(input_shape=(8, 8, 1)):
    """Build a single-convolution feature extractor.

    The model applies a 64-filter 3x3 'same' convolution, batch
    normalization and ReLU, then flattens the result.

    Args:
        input_shape: Tuple (height, width, channels) of the input, without
            the batch dimension.

    Returns:
        A Keras ``Model`` mapping the input to a flat feature vector.
    """
    cnn_input = layers.Input(shape=input_shape)

    conv_out = layers.Conv2D(64, (3, 3), padding='same')(cnn_input)
    normalized = layers.BatchNormalization()(conv_out)
    activated = layers.ReLU()(normalized)
    flat_features = layers.Flatten()(activated)

    return models.Model(inputs=cnn_input, outputs=flat_features)
264
+
265
+ from tensorflow.keras import layers, models, Input # type: ignore
266
+
267
def build_combined_model():
    """Assemble the full six-input regression model.

    Three sub-models (spatiotemporal LSTM, radial-structure CNN and the small
    CNN) are built with fixed input shapes. Three vector inputs (latitude,
    longitude, other) are each projected by a ReLU ``Dense`` layer. All six
    feature vectors are concatenated and passed through ``Dense(128, relu)``
    and a final linear ``Dense(1)``.

    Returns:
        A Keras ``Model`` whose inputs are, in order: the ST-LSTM input, the
        radial-subnet input, the CNN input, latitude, longitude and other.
        Its output is a single unit per sample.
    """
    # Sub-models are built first and in this order; the construction order of
    # every layer below is kept as in the original implementation.
    st_lstm_branch = build_st_lstm_model(input_shape=(8, 95, 95, 2))
    radial_branch = radial_structure_subnet(input_shape=(95, 95, 8))
    cnn_branch = build_cnn_model(input_shape=(8, 8, 1))

    # Auxiliary vector inputs.
    latitude_in = Input(shape=(8,), name="latitude_input")
    longitude_in = Input(shape=(8,), name="longitude_input")
    other_in = Input(shape=(9,), name="other_input")

    # Dense projections of the auxiliary inputs.
    latitude_features = layers.Dense(32, activation='relu')(latitude_in)
    longitude_features = layers.Dense(32, activation='relu')(longitude_in)
    other_features = layers.Dense(64, activation='relu')(other_in)

    # Concatenate every branch into one feature vector.
    merged = layers.concatenate([
        st_lstm_branch.output,
        radial_branch.output,
        cnn_branch.output,
        latitude_features,
        longitude_features,
        other_features,
    ])

    # Regression head: one hidden layer and a linear output unit.
    hidden = layers.Dense(128, activation='relu')(merged)
    prediction = layers.Dense(1, activation=None)(hidden)

    all_inputs = [
        st_lstm_branch.input,
        radial_branch.input,
        cnn_branch.input,
        latitude_in,
        longitude_in,
        other_in,
    ]
    return models.Model(inputs=all_inputs, outputs=prediction)
314
+
315
+ import h5py
316
# Path of the HDF5 checkpoint, resolved against the current working directory.
WEIGHTS_PATH = r"spatio_tempral_LSTM.h5"

# Print basic checkpoint metadata before loading it.
with h5py.File(WEIGHTS_PATH, 'r') as f:
    print(f.attrs.get('keras_version'))
    print(f.attrs.get('backend'))
    # Files written by `model.save` nest the layer groups under 'model_weights';
    # weights-only files are expected to keep them at the root, so fall back to
    # the root group instead of raising KeyError.
    weights_group = f['model_weights'] if 'model_weights' in f else f
    print("Model layers:", list(weights_group.keys()))

# Rebuild the architecture in code, then load the stored weights into it.
model = build_combined_model()
model.load_weights(WEIGHTS_PATH)
323
+
324
+
325
def predict_stlstm(reduced_images_test, hov_m_test, test_vmax_3d, lat_test, lon_test, int_diff_test):
    """Run the module-level ``model`` on six inputs and return its predictions.

    The arguments are passed to ``model.predict`` as one list, in the order
    given. That order has to match the input order of the combined model
    (ST-LSTM input, radial-subnet input, CNN input, latitude, longitude,
    other).

    Returns:
        Whatever ``model.predict`` returns for these inputs.
    """
    model_inputs = [
        reduced_images_test,
        hov_m_test,
        test_vmax_3d,
        lat_test,
        lon_test,
        int_diff_test,
    ]
    return model.predict(model_inputs)