import tensorflow as tf
import numpy as np
import pickle
from PIL import Image
import gradio as gr
from tensorflow.keras.models import Model
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2, preprocess_input
from tensorflow.keras.preprocessing.sequence import pad_sequences
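
# Note: the tensorflow.keras.preprocessing module is deprecated in newer TF
# releases; tf.keras.utils.pad_sequences is the suggested replacement if the
# import above breaks on your TF version.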

# Load MobileNetV2 and drop its classification head: the penultimate
# (global average pooling) layer provides the image feature vector
mobilenet_model = MobileNetV2(weights="imagenet")
mobilenet_model = Model(inputs=mobilenet_model.inputs, outputs=mobilenet_model.layers[-2].output)
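# For 224x224 RGB input, this extractor should map a (1, 224, 224, 3) batch
# to a (1, 1280) feature vector (MobileNetV2's pooled output width).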

# Load the trained captioning model
model = tf.keras.models.load_model("model.h5")

# Load the tokenizer fitted on the training captions
with open("tokenizer.pkl", "rb") as tokenizer_file:
    tokenizer = pickle.load(tokenizer_file)

# Set the maximum caption length and the start/end tokens
max_caption_length = 33  # must match the sequence length used during training
start_token = "startseq"
end_token = "endseq"
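# Note: 33 is a training-time constant. If the training captions were available,
# it could be derived rather than hard-coded (hypothetical sketch; train_captions
# is not part of this repo):
#     max_caption_length = max(len(caption.split()) for caption in train_captions)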

# Map a predicted index back to its word in the tokenizer's vocabulary
def get_word_from_index(index, tokenizer):
    for word, idx in tokenizer.word_index.items():
        if idx == index:
            return word
    return None
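
# Note: recent Keras Tokenizer versions also expose a reverse index_word dict,
# so the linear scan above could become a direct, behavior-equivalent lookup:
#     predicted_word = tokenizer.index_word.get(predicted_index)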

# Preprocess an image and extract its MobileNetV2 features
def preprocess_image(image):
    image = image.convert("RGB")  # guard against grayscale or RGBA uploads
    image = image.resize((224, 224))
    image_array = np.array(image)
    image_array = np.expand_dims(image_array, axis=0)  # add a batch dimension
    image_array = preprocess_input(image_array)  # scale pixel values to [-1, 1]
    return mobilenet_model.predict(image_array, verbose=0)

# Generate a caption for an image by greedy word-by-word decoding
def generate_caption(image):
    image_features = preprocess_image(image)
    caption = start_token
    for _ in range(max_caption_length):
        # Encode the caption so far and pad it to the model's input length
        sequence = tokenizer.texts_to_sequences([caption])[0]
        sequence = pad_sequences([sequence], maxlen=max_caption_length)
        # Predict the next word and take the most likely index
        yhat = model.predict([image_features, sequence], verbose=0)
        predicted_index = np.argmax(yhat)
        predicted_word = get_word_from_index(predicted_index, tokenizer)
        # Stop if no valid word or the end token is predicted
        if predicted_word is None or predicted_word == end_token:
            break
        caption += " " + predicted_word
    # Remove the start and end tokens for the final output
    final_caption = caption.replace(start_token, "").replace(end_token, "").strip()
    return final_caption
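
# Local smoke test (hypothetical; assumes an example.jpg next to app.py):
#     print(generate_caption(Image.open("example.jpg")))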
# Define Gradio interface
iface = gr.Interface(
fn=generate_caption, # Function to generate caption
inputs=gr.Image(type="pil"), # Input an image
outputs="text", # Output a text caption
title="Image Captioning Model",
description="Upload an image, and the model will generate a caption describing it."
)
iface.launch()
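# On Hugging Face Spaces the app is served automatically; when running locally,
# iface.launch(share=True) would also create a temporary public link.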