annie08 commited on
Commit
df0cfed
·
1 Parent(s): 0896dfd
Files changed (3) hide show
  1. app.py +53 -17
  2. requirements.txt +0 -1
  3. tokenizer.pkl +3 -0
app.py CHANGED
@@ -1,34 +1,70 @@
1
  import tensorflow as tf
2
  import numpy as np
 
3
  from PIL import Image
4
  import gradio as gr
 
 
 
5
 
6
- # CNN+LSTM model loaded
 
 
 
 
7
  model = tf.keras.models.load_model("model.h5")
8
 
9
- # Define the preprocessing function for the image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  def preprocess_image(image):
11
- # Resize and normalize image to match model's expected input shape
12
- image = image.resize((224, 224)) # Modify size based on your model
13
- image_array = np.array(image) / 255.0 # Normalize
14
- image_array = np.expand_dims(image_array, axis=0) # Add batch dimension
15
- return image_array
16
 
17
- # Define the function that generates a caption from the image
18
  def generate_caption(image):
19
- preprocessed_image = preprocess_image(image)
20
 
21
- # Generate a caption from the model
22
- # Note: Adjust this if your model requires a sequence start token or has a decoding loop
23
- caption_tokens = model.predict(preprocessed_image)
 
 
 
 
 
 
 
 
 
 
24
 
25
- return caption_tokens
 
 
26
 
27
- # Define the Gradio interface
28
  iface = gr.Interface(
29
- fn=generate_caption, # The function that generates captions
30
- inputs=gr.Image(type="pil"), # Accept an image input
31
- outputs="text", # Output a text caption
32
  title="Image Captioning Model",
33
  description="Upload an image, and the model will generate a caption describing it."
34
  )
 
1
  import tensorflow as tf
2
  import numpy as np
3
+ import pickle
4
  from PIL import Image
5
  import gradio as gr
6
+ from tensorflow.keras.models import Model
7
+ from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2, preprocess_input
8
+ from tensorflow.keras.preprocessing.sequence import pad_sequences
9
 
10
+ # Load MobileNetV2 model for feature extraction
11
+ mobilenet_model = MobileNetV2(weights="imagenet")
12
+ mobilenet_model = Model(inputs=mobilenet_model.inputs, outputs=mobilenet_model.layers[-2].output)
13
+
14
+ # Load the trained captioning model
15
  model = tf.keras.models.load_model("model.h5")
16
 
17
+ # Load the tokenizer
18
+ with open("tokenizer.pkl", "rb") as tokenizer_file:
19
+ tokenizer = pickle.load(tokenizer_file)
20
+
21
+ # Set maximum caption length and start/end tokens
22
+ max_caption_length = 33 # Adjust based on your model's training
23
+ start_token = "startseq"
24
+ end_token = "endseq"
25
+
26
+ # Define a function to get a word from an index
27
+ def get_word_from_index(index, tokenizer):
28
+ for word, idx in tokenizer.word_index.items():
29
+ if idx == index:
30
+ return word
31
+ return None
32
+
33
+ # Preprocess image and extract features
34
  def preprocess_image(image):
35
+ image = image.resize((224, 224))
36
+ image_array = np.array(image)
37
+ image_array = np.expand_dims(image_array, axis=0)
38
+ image_array = preprocess_input(image_array)
39
+ return mobilenet_model.predict(image_array, verbose=0)
40
 
41
+ # Generate caption from the image features
42
  def generate_caption(image):
43
+ image_features = preprocess_image(image)
44
 
45
+ caption = start_token
46
+ for _ in range(max_caption_length):
47
+ sequence = tokenizer.texts_to_sequences([caption])[0]
48
+ sequence = pad_sequences([sequence], maxlen=max_caption_length)
49
+
50
+ yhat = model.predict([image_features, sequence], verbose=0)
51
+ predicted_index = np.argmax(yhat)
52
+ predicted_word = get_word_from_index(predicted_index, tokenizer)
53
+
54
+ # If no valid word or end token is predicted, stop generation
55
+ if predicted_word is None or predicted_word == end_token:
56
+ break
57
+ caption += " " + predicted_word
58
 
59
+ # Remove start and end tokens for final output
60
+ final_caption = caption.replace(start_token, "").replace(end_token, "").strip()
61
+ return final_caption
62
 
63
+ # Define Gradio interface
64
  iface = gr.Interface(
65
+ fn=generate_caption, # Function to generate caption
66
+ inputs=gr.Image(type="pil"), # Input an image
67
+ outputs="text", # Output a text caption
68
  title="Image Captioning Model",
69
  description="Upload an image, and the model will generate a caption describing it."
70
  )
requirements.txt CHANGED
@@ -3,6 +3,5 @@ gradio==5.3.0
3
  pandas==2.0.3
4
  Pillow==9.5.0
5
  torch==2.1.0
6
- keras==2.15.0
7
  nltk==3.8.1
8
  numpy==1.23.5
 
3
  pandas==2.0.3
4
  Pillow==9.5.0
5
  torch==2.1.0
 
6
  nltk==3.8.1
7
  numpy==1.23.5
tokenizer.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:287c34d4b6f6db43caaba67747a6674377e26f366e19f1c821a5f5be5067a178
3
+ size 347726