annie08 committed on
Commit
d9a9b93
·
1 Parent(s): f10eb97

changed shape

Browse files
Files changed (1) hide show
  1. app.py +12 -8
app.py CHANGED
@@ -11,7 +11,6 @@ from tensorflow.keras.preprocessing.sequence import pad_sequences
11
  mobilenet_model = MobileNetV2(weights="imagenet", include_top=False, pooling='avg')
12
  mobilenet_model = Model(inputs=mobilenet_model.inputs, outputs=mobilenet_model.output)
13
 
14
-
15
  # Load the trained captioning model
16
  model = tf.keras.models.load_model("model_9.h5")
17
 
@@ -33,23 +32,28 @@ def get_word_from_index(index, tokenizer):
33
 
34
  # Preprocess image and extract features
35
  def preprocess_image(image):
36
- image = image.resize((224, 224))
37
  image_array = np.array(image)
38
- image_array = np.expand_dims(image_array, axis=0)
39
- image_array = preprocess_input(image_array)
40
- return mobilenet_model.predict(image_array, verbose=0)
41
 
42
  # Generate caption from the image features
43
  def generate_caption(image):
 
44
  image_features = preprocess_image(image)
45
 
 
 
 
46
  caption = start_token
47
  for _ in range(max_caption_length):
48
- sequence = tokenizer.texts_to_sequences([caption])[0]
49
- sequence = pad_sequences([sequence], maxlen=max_caption_length)
50
 
 
51
  yhat = model.predict([image_features, sequence], verbose=0)
52
- predicted_index = np.argmax(yhat)
53
  predicted_word = get_word_from_index(predicted_index, tokenizer)
54
 
55
  # If no valid word or end token is predicted, stop generation
 
11
  mobilenet_model = MobileNetV2(weights="imagenet", include_top=False, pooling='avg')
12
  mobilenet_model = Model(inputs=mobilenet_model.inputs, outputs=mobilenet_model.output)
13
 
 
14
  # Load the trained captioning model
15
  model = tf.keras.models.load_model("model_9.h5")
16
 
 
32
 
33
  # Preprocess image and extract features
34
  def preprocess_image(image):
35
+ image = image.resize((224, 224)) # Resize image to 224x224 for MobileNetV2
36
  image_array = np.array(image)
37
+ image_array = np.expand_dims(image_array, axis=0) # Add batch dimension
38
+ image_array = preprocess_input(image_array) # Normalize image for MobileNetV2
39
+ return mobilenet_model.predict(image_array, verbose=0) # Extract features
40
 
41
  # Generate caption from the image features
42
  def generate_caption(image):
43
+ # Extract image features using MobileNetV2
44
  image_features = preprocess_image(image)
45
 
46
+ # Reshape to match the expected input shape for the captioning model (1, 2048)
47
+ image_features = image_features.reshape((1, 2048))
48
+
49
  caption = start_token
50
  for _ in range(max_caption_length):
51
+ sequence = tokenizer.texts_to_sequences([caption])[0] # Convert caption to sequence
52
+ sequence = pad_sequences([sequence], maxlen=max_caption_length) # Pad sequence
53
 
54
+ # Predict the next word in the sequence
55
  yhat = model.predict([image_features, sequence], verbose=0)
56
+ predicted_index = np.argmax(yhat) # Get the index of the predicted word
57
  predicted_word = get_word_from_index(predicted_index, tokenizer)
58
 
59
  # If no valid word or end token is predicted, stop generation