Amite5h committed
Commit 06b2c37 · 1 Parent(s): 2eadf05

Update model.py

Files changed (1):
  1. model.py (+11 -11)
model.py CHANGED
@@ -280,25 +280,25 @@ def get_caption_model():
 
     cnn_model = CNN_Encoder()
 
-    caption_mode = ImageCaptioningModel(
+    caption_model = ImageCaptioningModel(
         cnn_model=cnn_model, encoder=encoder, decoder=decoder, image_aug=None,
     )
 
     def call_fn(batch, training):
         return batch
 
-    caption_mode.call = call_fn
+    caption_model.call = call_fn
     sample_x, sample_y = tf.random.normal((1, 299, 299, 3)), tf.zeros((1, 40))
 
-    caption_mode((sample_x, sample_y))
+    caption_model((sample_x, sample_y))
 
-    sample_img_embed = caption_mode.cnn_model(sample_x)
-    sample_enc_out = caption_mode.encoder(sample_img_embed, training=False)
-    caption_mode.decoder(sample_y, sample_enc_out, training=False)
+    sample_img_embed = caption_model.cnn_model(sample_x)
+    sample_enc_out = caption_model.encoder(sample_img_embed, training=False)
+    caption_model.decoder(sample_y, sample_enc_out, training=False)
 
-    caption_mode.load_weights('model.h5')
+    caption_model.load_weights('model.h5')
 
-    return caption_mode
+    return caption_model
 
 def load_image_from_path(img_path):
     img = tf.io.read_file(img_path)
@@ -317,14 +317,14 @@ def generate_caption(img_path, add_noise=False):
     img = (img - tf.reduce_min(img))/(tf.reduce_max(img) - tf.reduce_min(img))
 
     img = tf.expand_dims(img, axis=0)
-    img_embed = model.cnn_model(img)
-    img_encoded = model.encoder(img_embed, training=False)
+    img_embed = caption_model.cnn_model(img)
+    img_encoded = caption_model.encoder(img_embed, training=False)
 
     y_inp = '[start]'
     for i in range(MAX_LENGTH-1):
         tokenized = tokenizer([y_inp])[:, :-1]
         mask = tf.cast(tokenized != 0, tf.int32)
-        pred = model.decoder(
+        pred = caption_model.decoder(
            tokenized, img_encoded, training=False, mask=mask)
 
         pred_idx = np.argmax(pred[0, i, :])
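
The pattern the renamed code relies on is worth spelling out: ImageCaptioningModel is a subclassed Keras model, so it owns no variables until it has been called, which is why get_caption_model() runs a dummy forward pass through cnn_model, encoder, and decoder before load_weights('model.h5'). Below is a minimal sketch of that build-then-load pattern; TinyCaptioner and tiny.h5 are hypothetical stand-ins for the real ImageCaptioningModel and checkpoint.

import tensorflow as tf

# Hypothetical stand-in for ImageCaptioningModel: a subclassed model whose
# variables are only created once the first call builds its layers.
class TinyCaptioner(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.cnn = tf.keras.layers.Dense(8, name='cnn')          # stand-in for CNN_Encoder
        self.decoder = tf.keras.layers.Dense(4, name='decoder')  # stand-in for the decoder

    def call(self, inputs, training=False):
        return self.decoder(self.cnn(inputs))

# "Training" instance: build the weights with a dummy pass, then save them.
trained = TinyCaptioner()
_ = trained(tf.zeros((1, 16)))
trained.save_weights('tiny.h5')     # hypothetical checkpoint, like model.h5

# Inference instance: the same dummy pass creates every variable, and only
# then can load_weights map the checkpoint onto the model.
restored = TinyCaptioner()
_ = restored(tf.zeros((1, 16)))
restored.load_weights('tiny.h5')

In model.py itself, the identity call_fn appears to exist only so the top-level call caption_model((sample_x, sample_y)) succeeds without running the real training step, while the explicit cnn_model / encoder / decoder calls are what actually create the weights.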
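
The second hunk only renames the handle used inside the greedy decoding loop of generate_caption(): each step re-tokenizes the caption so far, runs the decoder once, and takes the arg-max over the vocabulary at position i, presumably until an end token or MAX_LENGTH is reached. A rough, self-contained sketch of that loop, with toy stand-ins (toy_tokenize, toy_decoder, a tiny vocab) for the script's real tokenizer and decoder:

import numpy as np

# Toy stand-ins; model.py uses its own tokenizer, vocabulary lookup, and
# transformer decoder instead of these.
MAX_LENGTH = 10
vocab = ['', '[UNK]', '[start]', '[end]', 'a', 'dog', 'on', 'grass']
word2idx = {w: i for i, w in enumerate(vocab)}

def toy_tokenize(text):
    return [word2idx.get(w, 1) for w in text.split()]

def toy_decoder(tokens, img_encoded):
    # Fake per-position logits over the toy vocabulary, seeded for repeatability.
    rng = np.random.default_rng(len(tokens))
    return rng.random((1, len(tokens), len(vocab)))

img_encoded = None  # would be caption_model.encoder(img_embed, training=False)

y_inp = '[start]'
for i in range(MAX_LENGTH - 1):
    tokenized = toy_tokenize(y_inp)
    pred = toy_decoder(tokenized, img_encoded)

    pred_idx = int(np.argmax(pred[0, i, :]))  # greedy: arg-max at position i
    pred_word = vocab[pred_idx]
    if pred_word == '[end]':                  # stop at the end token
        break
    y_inp += ' ' + pred_word

print(y_inp.replace('[start] ', ''))          # caption without the start marker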