Amite5h commited on
Commit
6d19d50
·
1 Parent(s): 367c682

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +27 -26
model.py CHANGED
@@ -272,7 +272,33 @@ class ImageCaptioningModel(tf.keras.Model):
272
  @property
273
  def metrics(self):
274
  return [self.loss_tracker, self.acc_tracker]
275
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
 
277
  def load_image_from_path(img_path):
278
  img = tf.io.read_file(img_path)
@@ -312,29 +338,4 @@ def generate_caption(img_path, add_noise=False):
312
  y_inp = y_inp.replace('[start] ', '')
313
  return y_inp
314
 
315
- def get_caption_model():
316
- encoder = TransformerEncoderLayer(EMBEDDING_DIM, 1)
317
- decoder = TransformerDecoderLayer(EMBEDDING_DIM, UNITS, 8)
318
-
319
- cnn_model = CNN_Encoder()
320
-
321
- caption_mode = ImageCaptioningModel(
322
- cnn_model=cnn_model, encoder=encoder, decoder=decoder, image_aug=None,
323
- )
324
-
325
- def call_fn(batch, training):
326
- return batch
327
-
328
- caption_mode.call = call_fn
329
- sample_x, sample_y = tf.random.normal((1, 299, 299, 3)), tf.zeros((1, 40))
330
-
331
- caption_mode((sample_x, sample_y))
332
-
333
- sample_img_embed = caption_mode.cnn_model(sample_x)
334
- sample_enc_out = caption_mode.encoder(sample_img_embed, training=False)
335
- caption_mode.decoder(sample_y, sample_enc_out, training=False)
336
-
337
- caption_mode.load_weights('model.h5')
338
-
339
- return caption_mode
340
 
 
272
  @property
273
  def metrics(self):
274
  return [self.loss_tracker, self.acc_tracker]
275
+
276
+
277
+ def get_caption_model():
278
+ encoder = TransformerEncoderLayer(EMBEDDING_DIM, 1)
279
+ decoder = TransformerDecoderLayer(EMBEDDING_DIM, UNITS, 8)
280
+
281
+ cnn_model = CNN_Encoder()
282
+
283
+ caption_mode = ImageCaptioningModel(
284
+ cnn_model=cnn_model, encoder=encoder, decoder=decoder, image_aug=None,
285
+ )
286
+
287
+ def call_fn(batch, training):
288
+ return batch
289
+
290
+ caption_mode.call = call_fn
291
+ sample_x, sample_y = tf.random.normal((1, 299, 299, 3)), tf.zeros((1, 40))
292
+
293
+ caption_mode((sample_x, sample_y))
294
+
295
+ sample_img_embed = caption_mode.cnn_model(sample_x)
296
+ sample_enc_out = caption_mode.encoder(sample_img_embed, training=False)
297
+ caption_mode.decoder(sample_y, sample_enc_out, training=False)
298
+
299
+ caption_mode.load_weights('model.h5')
300
+
301
+ return caption_mode
302
 
303
  def load_image_from_path(img_path):
304
  img = tf.io.read_file(img_path)
 
338
  y_inp = y_inp.replace('[start] ', '')
339
  return y_inp
340
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
341