Update trjgru.py
Browse files
trjgru.py
CHANGED
@@ -285,16 +285,25 @@ import h5py
|
|
285 |
|
286 |
model = build_combined_model() # Your original model building function
|
287 |
# Rebuild the model architecture
|
288 |
-
model
|
|
|
289 |
|
290 |
-
#
|
291 |
-
|
292 |
-
_ = model(dummy_input) # Forward pass to build all layers
|
293 |
|
294 |
-
|
295 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
296 |
|
297 |
-
model.load_weights(r"Trj_GRU.weights.h5")
|
298 |
|
299 |
|
300 |
def predict_trajgru(reduced_images_test,hov_m_test,test_vmax_3d,lat_test,lon_test,int_diff_test):
|
|
|
285 |
|
286 |
model = build_combined_model() # Your original model building function
|
287 |
# Rebuild the model architecture
|
288 |
+
# Step 1: Build the full combined model (with 6 inputs)
|
289 |
+
# model = build_combined_model()
|
290 |
|
291 |
+
# Step 2: Call the model once with dummy data to build the weights
|
292 |
+
# import tensorflow as tf
|
|
|
293 |
|
294 |
+
dummy_input = [
|
295 |
+
tf.random.normal((1, 8, 95, 95, 2)), # reduced_images_test
|
296 |
+
tf.random.normal((1, 95, 95, 8)), # hov_m_test
|
297 |
+
tf.random.normal((1, 8, 8, 1)), # test_vmax_3d
|
298 |
+
tf.random.normal((1, 8)), # lat_test
|
299 |
+
tf.random.normal((1, 8)), # lon_test
|
300 |
+
tf.random.normal((1, 9)), # other_scalar_inputs
|
301 |
+
]
|
302 |
+
_ = model(dummy_input) # Build model by doing one forward pass
|
303 |
+
|
304 |
+
# Step 3: Load weights
|
305 |
+
model.load_weights("Trj_GRU.weights.h5") # Make sure this matches the architecture
|
306 |
|
|
|
307 |
|
308 |
|
309 |
def predict_trajgru(reduced_images_test,hov_m_test,test_vmax_3d,lat_test,lon_test,int_diff_test):
|