nanduriprudhvi commited on
Commit
8928aba
·
verified ·
1 Parent(s): fd6f594

Update trjgru.py

Browse files
Files changed (1) hide show
  1. trjgru.py +16 -7
trjgru.py CHANGED
@@ -285,16 +285,25 @@ import h5py
285
 
286
  model = build_combined_model() # Your original model building function
287
  # Rebuild the model architecture
288
- model = build_tgru_model(input_shape=(8, 95, 95, 2))
 
289
 
290
- # Build the model by calling it once (to initialize all weights)
291
- dummy_input = tf.random.normal((1, 8, 95, 95, 2)) # batch_size=1
292
- _ = model(dummy_input) # Forward pass to build all layers
293
 
294
- # Now load the saved weights
295
- # model.load_weights("Trj_GRU.weights.h5")
 
 
 
 
 
 
 
 
 
 
296
 
297
- model.load_weights(r"Trj_GRU.weights.h5")
298
 
299
 
300
  def predict_trajgru(reduced_images_test,hov_m_test,test_vmax_3d,lat_test,lon_test,int_diff_test):
 
285
 
286
  model = build_combined_model() # Your original model building function
287
  # Rebuild the model architecture
288
+ # Step 1: Build the full combined model (with 6 inputs)
289
+ # model = build_combined_model()
290
 
291
+ # Step 2: Call the model once with dummy data to build the weights
292
+ # import tensorflow as tf
 
293
 
294
+ dummy_input = [
295
+ tf.random.normal((1, 8, 95, 95, 2)), # reduced_images_test
296
+ tf.random.normal((1, 95, 95, 8)), # hov_m_test
297
+ tf.random.normal((1, 8, 8, 1)), # test_vmax_3d
298
+ tf.random.normal((1, 8)), # lat_test
299
+ tf.random.normal((1, 8)), # lon_test
300
+ tf.random.normal((1, 9)), # other_scalar_inputs
301
+ ]
302
+ _ = model(dummy_input) # Build model by doing one forward pass
303
+
304
+ # Step 3: Load weights
305
+ model.load_weights("Trj_GRU.weights.h5") # Make sure this matches the architecture
306
 
 
307
 
308
 
309
  def predict_trajgru(reduced_images_test,hov_m_test,test_vmax_3d,lat_test,lon_test,int_diff_test):