|
import tensorflow as tf |
|
import numpy as np |
|
import pickle |
|
from PIL import Image |
|
import gradio as gr |
|
from tensorflow.keras.models import Model |
|
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2, preprocess_input |
|
from tensorflow.keras.preprocessing.sequence import pad_sequences |
|
|
|
|
|
# Headless MobileNetV2 image encoder: ImageNet weights, classifier head
# removed, global average pooling — yields one 1280-d feature vector per
# image. (The previous Model(inputs=..., outputs=...) re-wrap was an
# identity wrapper and added nothing, so it is gone.)
mobilenet_model = MobileNetV2(weights="imagenet", include_top=False, pooling='avg')
|
|
|
|
|
# Trained captioning decoder loaded from disk; generate_caption feeds it
# [image_features, padded_token_sequence] — presumably trained against the
# same tokenizer and 1280-d MobileNetV2 features (confirm with training code).
model = tf.keras.models.load_model("model_9.h5")
|
|
|
|
|
# Word tokenizer fitted during training (maps words <-> integer indices).
# NOTE(review): pickle.load can execute arbitrary code — only ship
# tokenizer.pkl from a trusted source.
with open("tokenizer.pkl", "rb") as tokenizer_file:

    tokenizer = pickle.load(tokenizer_file)
|
|
|
|
|
# Decoding parameters: max_caption_length bounds generation and is also the
# padding width fed to the model; start/end tokens delimit every caption.
max_caption_length = 33  # presumably the max sequence length used in training — confirm

start_token = "startseq"

end_token = "endseq"
|
|
|
|
|
def get_word_from_index(index, tokenizer):
    """Return the vocabulary word mapped to *index*, or None if absent.

    Uses the tokenizer's prebuilt reverse map ``index_word`` (maintained by
    keras Tokenizer) for an O(1) lookup; falls back to scanning
    ``word_index`` for tokenizer-like objects that lack it.
    """
    index_word = getattr(tokenizer, "index_word", None)
    if index_word:
        return index_word.get(index)
    # Fallback: O(vocab) scan over the forward mapping.
    for word, idx in tokenizer.word_index.items():
        if idx == index:
            return word
    return None
|
|
|
|
|
def preprocess_image(image):
    """Extract MobileNetV2 features from a PIL image.

    Returns the encoder's pooled feature array (one 1280-d vector for the
    single input image).
    """
    # Force 3-channel RGB before resizing: user uploads may be RGBA or
    # grayscale, which would otherwise give the network a tensor with the
    # wrong channel count.
    image = image.convert("RGB").resize((224, 224))
    image_array = np.expand_dims(np.array(image), axis=0)
    # MobileNetV2's own preprocessing (scales pixel values for the model).
    image_array = preprocess_input(image_array)
    return mobilenet_model.predict(image_array, verbose=0)
|
|
|
|
|
def generate_caption(image):
    """Produce a text caption for a PIL image via greedy decoding.

    Encodes the image with MobileNetV2, then repeatedly asks the caption
    model for the most probable next word, stopping at the end token, an
    unknown index, or the maximum caption length.
    """
    # Single 1280-d feature vector for the image, shaped for the model.
    features = preprocess_image(image).reshape((1, 1280))

    # Greedy decode: grow the caption one word per step.
    caption = start_token
    for _ in range(max_caption_length):
        encoded = pad_sequences(
            [tokenizer.texts_to_sequences([caption])[0]],
            maxlen=max_caption_length,
        )
        probabilities = model.predict([features, encoded], verbose=0)
        next_word = get_word_from_index(np.argmax(probabilities), tokenizer)

        # Halt on an out-of-vocabulary index or the explicit end marker.
        if next_word is None or next_word == end_token:
            break
        caption = f"{caption} {next_word}"

    # Drop the control tokens before handing the caption to the UI.
    return caption.replace(start_token, "").replace(end_token, "").strip()
|
|
|
|
|
# Gradio web UI: takes an uploaded image (as PIL) and shows the generated
# caption as plain text.
iface = gr.Interface(

    fn=generate_caption,

    inputs=gr.Image(type="pil"),

    outputs="text",

    title="Image Captioning Model",

    description="Upload an image, and the model will generate a caption describing it."

)

# Start the local web server (blocks until the app is stopped).
iface.launch()
|
|