revamped
Browse files
- app.py +53 -17
- requirements.txt +0 -1
- tokenizer.pkl +3 -0
app.py
CHANGED
@@ -1,34 +1,70 @@
import tensorflow as tf
import numpy as np
+import pickle
from PIL import Image
import gradio as gr
+from tensorflow.keras.models import Model
+from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2, preprocess_input
+from tensorflow.keras.preprocessing.sequence import pad_sequences

-#
+# Load MobileNetV2 model for feature extraction
+mobilenet_model = MobileNetV2(weights="imagenet")
+mobilenet_model = Model(inputs=mobilenet_model.inputs, outputs=mobilenet_model.layers[-2].output)
+
+# Load the trained captioning model
model = tf.keras.models.load_model("model.h5")

-#
+# Load the tokenizer
+with open("tokenizer.pkl", "rb") as tokenizer_file:
+    tokenizer = pickle.load(tokenizer_file)
+
+# Set maximum caption length and start/end tokens
+max_caption_length = 33  # Adjust based on your model's training
+start_token = "startseq"
+end_token = "endseq"
+
+# Define a function to get a word from an index
+def get_word_from_index(index, tokenizer):
+    for word, idx in tokenizer.word_index.items():
+        if idx == index:
+            return word
+    return None
+
+# Preprocess image and extract features
def preprocess_image(image):
-
-
-    image_array = np.
-    image_array =
-    return image_array
+    image = image.resize((224, 224))
+    image_array = np.array(image)
+    image_array = np.expand_dims(image_array, axis=0)
+    image_array = preprocess_input(image_array)
+    return mobilenet_model.predict(image_array, verbose=0)

-#
+# Generate caption from the image features
def generate_caption(image):
-
+    image_features = preprocess_image(image)

-
-
-
+    caption = start_token
+    for _ in range(max_caption_length):
+        sequence = tokenizer.texts_to_sequences([caption])[0]
+        sequence = pad_sequences([sequence], maxlen=max_caption_length)
+
+        yhat = model.predict([image_features, sequence], verbose=0)
+        predicted_index = np.argmax(yhat)
+        predicted_word = get_word_from_index(predicted_index, tokenizer)
+
+        # If no valid word or end token is predicted, stop generation
+        if predicted_word is None or predicted_word == end_token:
+            break
+        caption += " " + predicted_word

-
+    # Remove start and end tokens for final output
+    final_caption = caption.replace(start_token, "").replace(end_token, "").strip()
+    return final_caption

-# Define
+# Define Gradio interface
iface = gr.Interface(
-    fn=generate_caption,
-    inputs=gr.Image(type="pil"),
-    outputs="text",
+    fn=generate_caption,  # Function to generate caption
+    inputs=gr.Image(type="pil"),  # Input an image
+    outputs="text",  # Output a text caption
    title="Image Captioning Model",
    description="Upload an image, and the model will generate a caption describing it."
)
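The revamped app.py ends with the gr.Interface definition shown above. A minimal smoke-test sketch for running it locally, assuming model.h5 and tokenizer.pkl sit next to app.py, that these lines are appended to the end of the file, and that the image path is a placeholder:

# Hypothetical local smoke test (not part of this commit).
from PIL import Image

test_image = Image.open("example.jpg").convert("RGB")  # placeholder image path
print(generate_caption(test_image))  # prints the generated caption

iface.launch()  # start the Gradio UI when running app.py directly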
requirements.txt
CHANGED
@@ -3,6 +3,5 @@ gradio==5.3.0
pandas==2.0.3
Pillow==9.5.0
torch==2.1.0
-keras==2.15.0
nltk==3.8.1
numpy==1.23.5
tokenizer.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:287c34d4b6f6db43caaba67747a6674377e26f366e19f1c821a5f5be5067a178
+size 347726
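tokenizer.pkl is committed as a Git LFS pointer, so only its hash and size appear above. For context, a hedged sketch of how a pickle consumed this way is typically produced during training, assuming a Keras Tokenizer fitted on startseq/endseq-wrapped captions (the caption list below is a placeholder):

import pickle
from tensorflow.keras.preprocessing.text import Tokenizer

train_captions = ["startseq a dog runs on the beach endseq"]  # placeholder training captions

tokenizer = Tokenizer()
tokenizer.fit_on_texts(train_captions)  # builds the word_index that app.py reads

with open("tokenizer.pkl", "wb") as f:
    pickle.dump(tokenizer, f)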