import tensorflow as tf
import numpy as np
import pickle
from PIL import Image
import gradio as gr
from tensorflow.keras.models import Model
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2, preprocess_input
from tensorflow.keras.preprocessing.sequence import pad_sequences

# Load MobileNetV2 model for feature extraction (with pooling and no top layer)
mobilenet_model = MobileNetV2(weights="imagenet", include_top=False, pooling='avg')
mobilenet_model = Model(inputs=mobilenet_model.inputs, outputs=mobilenet_model.output)

# Load the trained captioning model
model = tf.keras.models.load_model("model.h5")

# Load the tokenizer
with open("tokenizer.pkl", "rb") as tokenizer_file:
    tokenizer = pickle.load(tokenizer_file)

# Set maximum caption length and start/end tokens
max_caption_length = 34  # Adjust based on your model's training
start_token = "startseq"
end_token = "endseq"

# Define a function to get a word from an index
def get_word_from_index(index, tokenizer):
    for word, idx in tokenizer.word_index.items():
        if idx == index:
            return word
    return None

# Preprocess image and extract features
def preprocess_image(image):
    image = image.convert("RGB")  # Ensure a 3-channel RGB image (guards against grayscale/RGBA uploads)
    image = image.resize((224, 224))  # Resize image to 224x224 for MobileNetV2
    image_array = np.array(image)
    image_array = np.expand_dims(image_array, axis=0)  # Add batch dimension
    image_array = preprocess_input(image_array)  # Normalize image for MobileNetV2
    return mobilenet_model.predict(image_array, verbose=0)  # Extract features

# Generate caption from the image features
def generate_caption(image):
    # Extract image features using MobileNetV2
    image_features = preprocess_image(image)
    
    # Reshape to the captioning model's expected image input shape (1, 1280);
    # MobileNetV2 with average pooling yields a 1280-dimensional feature vector
    image_features = image_features.reshape((1, 1280))
    
    caption = start_token
    for _ in range(max_caption_length):
        sequence = tokenizer.texts_to_sequences([caption])[0]  # Convert caption to sequence
        sequence = pad_sequences([sequence], maxlen=max_caption_length)  # Pad sequence

        # Predict the next word in the sequence
        yhat = model.predict([image_features, sequence], verbose=0)
        predicted_index = np.argmax(yhat)  # Get the index of the predicted word
        predicted_word = get_word_from_index(predicted_index, tokenizer)
        
        # If no valid word or end token is predicted, stop generation
        if predicted_word is None or predicted_word == end_token:
            break
        caption += " " + predicted_word

    # Remove start and end tokens for final output
    final_caption = caption.replace(start_token, "").replace(end_token, "").strip()
    return final_caption

# Define Gradio interface
iface = gr.Interface(
    fn=generate_caption,                   # Function to generate caption
    inputs=gr.Image(type="pil"),           # Input an image
    outputs="text",                        # Output a text caption
    title="Image Captioning Model",
    description="Upload an image, and the model will generate a caption describing it."
)

iface.launch()