# app.py
import os

# Set these environment variables before TensorFlow is imported so they take effect:
# suppress verbose C++ logging and disable GPU usage so TensorFlow does not attempt
# to load CUDA libraries on a CPU-only host.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image, ImageDraw, ImageFont

# Load labels (one class name per line)
with open('tensorflow/labels.txt', 'r') as f:
    labels = f.read().splitlines()

# Function to load the frozen TensorFlow graph
def load_frozen_graph(pb_file_path):
    with tf.io.gfile.GFile(pb_file_path, 'rb') as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
    return graph

# Load the TensorFlow model
MODEL_DIR = 'tensorflow/'
MODEL_PATH = os.path.join(MODEL_DIR, 'model.pb')
graph = load_frozen_graph(MODEL_PATH)
sess = tf.compat.v1.Session(graph=graph)
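# The frozen GraphDef is a TF1-style model, so inference runs through this
# persistent tf.compat.v1.Session (created once at startup and reused for
# every request) rather than TF2 eager execution.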
# Define tensor names based on your model's outputs
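# (These names appear to follow the Azure Custom Vision TensorFlow export
# convention: an 'image_tensor' input plus 'detected_boxes', 'detected_classes',
# and 'detected_scores' outputs. Adjust them if your model was exported differently.)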
input_tensor = graph.get_tensor_by_name('image_tensor:0')
detected_boxes = graph.get_tensor_by_name('detected_boxes:0')
detected_classes = graph.get_tensor_by_name('detected_classes:0')
detected_scores = graph.get_tensor_by_name('detected_scores:0')
# Define the target size based on your model's expected input
TARGET_WIDTH = 320
TARGET_HEIGHT = 320

def preprocess_image(image):
    """
    Preprocess the input image:
    - Ensure a 3-channel RGB image
    - Resize to the target dimensions
    - Normalize pixel values
    - Convert RGB to BGR if required by the model
    - Add a batch dimension
    """
    image = image.convert('RGB')  # Handle grayscale or RGBA uploads
    image = image.resize((TARGET_WIDTH, TARGET_HEIGHT))
    image_np = np.array(image).astype(np.float32)
    image_np = image_np / 255.0  # Normalize to [0, 1]
    image_np = image_np[:, :, (2, 1, 0)]  # Convert RGB to BGR if required
    image_np = np.expand_dims(image_np, axis=0)  # Add batch dimension
    return image_np

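# preprocess_image returns a batch of shape (1, TARGET_HEIGHT, TARGET_WIDTH, 3)
# in BGR channel order with pixel values scaled to [0, 1].
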
def draw_boxes(image, boxes, classes, scores, threshold=0.5):
    """
    Draw bounding boxes and labels on the image.

    Args:
        image (PIL.Image): The original image.
        boxes (np.array): Array of bounding boxes.
        classes (np.array): Array of class IDs.
        scores (np.array): Array of confidence scores.
        threshold (float): Confidence threshold to filter detections.

    Returns:
        PIL.Image: Annotated image.
    """
    draw = ImageDraw.Draw(image)
    try:
        font = ImageFont.truetype("arial.ttf", 15)
    except IOError:
        font = ImageFont.load_default()
    for box, cls, score in zip(boxes[0], classes[0], scores[0]):
        if score < threshold:
            continue
        # Convert box coordinates from normalized to absolute
        ymin, xmin, ymax, xmax = box
        left = xmin * image.width
        right = xmax * image.width
        top = ymin * image.height
        bottom = ymax * image.height
        # Draw rectangle
        draw.rectangle([(left, top), (right, bottom)], outline="red", width=2)
        # Prepare label
        label = f"{labels[int(cls) - 1]}: {score:.2f}"
        # Calculate text size using textbbox
        text_bbox = draw.textbbox((0, 0), label, font=font)
        text_width = text_bbox[2] - text_bbox[0]
        text_height = text_bbox[3] - text_bbox[1]
        # Draw label background
        draw.rectangle([(left, top - text_height - 4), (left + text_width + 4, top)], fill="red")
        # Draw text
        draw.text((left + 2, top - text_height - 2), label, fill="white", font=font)
    return image

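# Note: draw_boxes assumes the boxes are normalized to [0, 1] and ordered
# (ymin, xmin, ymax, xmax); if your model emits (xmin, ymin, xmax, ymax)
# instead, swap the unpacking in the loop above accordingly.
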
def predict(image):
    """
    Perform inference on the input image and return the annotated image.

    Args:
        image (PIL.Image): Uploaded image.

    Returns:
        PIL.Image: Annotated image with bounding boxes and labels.
    """
    try:
        # Preprocess the image
        input_array = preprocess_image(image)
        # Run inference
        boxes, classes, scores = sess.run(
            [detected_boxes, detected_classes, detected_scores],
            feed_dict={input_tensor: input_array}
        )
        # Annotate the image with bounding boxes and labels
        annotated_image = draw_boxes(image.copy(), boxes, classes, scores, threshold=0.5)
        return annotated_image
    except Exception as e:
        # Log the error and return a placeholder image so the UI stays responsive
        print(f"Prediction failed: {e}")
        error_image = Image.new('RGB', (500, 500), color=(255, 0, 0))
        draw = ImageDraw.Draw(error_image)
        try:
            font = ImageFont.truetype("arial.ttf", 20)
        except IOError:
            font = ImageFont.load_default()
        error_text = "Error during prediction."
        # Calculate text size using textbbox
        text_bbox = draw.textbbox((0, 0), error_text, font=font)
        text_width = text_bbox[2] - text_bbox[0]
        text_height = text_bbox[3] - text_bbox[1]
        # Draw a black background box and center the text on it
        draw.rectangle(
            [
                ((500 - text_width) / 2 - 10, (500 - text_height) / 2 - 10),
                ((500 + text_width) / 2 + 10, (500 + text_height) / 2 + 10)
            ],
            fill="black"
        )
        draw.text(
            ((500 - text_width) / 2, (500 - text_height) / 2),
            error_text,
            fill="white",
            font=font
        )
        return error_image

# Define the Gradio interface
title = "JunkWaxHero 🦸‍♂️ - Baseball Card Set Identifier"
description = "Upload an image of a baseball card, and JunkWaxHero will identify the set it belongs to with high accuracy."

# Verify that example images exist to prevent FileNotFoundError
example_images = ["examples/card1.jpg", "examples/card2.jpg", "examples/card3.jpg"]
valid_examples = [img for img in example_images if os.path.exists(img)]
if not valid_examples:
    valid_examples = None  # Omit the examples section if none of the files exist

iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title=title,
    description=description,
    examples=valid_examples,
    flagging_mode="never"  # Disable flagging ('flagging_mode' replaces 'allow_flagging' in newer Gradio releases)
)

if __name__ == "__main__":
    iface.launch()