# app.py
import gradio as gr
import tensorflow as tf
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import os

# Suppress TensorFlow logging for cleaner logs
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Disable GPU usage explicitly to prevent TensorFlow from attempting to access GPU libraries
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

# Load labels
with open('tensorflow/labels.txt', 'r') as f:
    labels = f.read().splitlines()

# Function to load the frozen TensorFlow graph
def load_frozen_graph(pb_file_path):
    with tf.io.gfile.GFile(pb_file_path, 'rb') as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
    return graph

# Load the TensorFlow model
MODEL_DIR = 'tensorflow/'
MODEL_PATH = os.path.join(MODEL_DIR, 'model.pb')
graph = load_frozen_graph(MODEL_PATH)
sess = tf.compat.v1.Session(graph=graph)

# Define tensor names based on your model's outputs
input_tensor = graph.get_tensor_by_name('image_tensor:0')
detected_boxes = graph.get_tensor_by_name('detected_boxes:0')
detected_classes = graph.get_tensor_by_name('detected_classes:0')
detected_scores = graph.get_tensor_by_name('detected_scores:0')

# Define the target size based on your model's expected input
TARGET_WIDTH = 320
TARGET_HEIGHT = 320

def preprocess_image(image):
    """
    Preprocess the input image:
    - Resize to target dimensions
    - Convert to numpy array
    - Normalize pixel values
    - Convert RGB to BGR if required by the model
    """
    image = image.resize((TARGET_WIDTH, TARGET_HEIGHT))
    image_np = np.array(image).astype(np.float32)
    image_np = image_np / 255.0  # Normalize to [0,1]
    image_np = image_np[:, :, (2, 1, 0)]  # Convert RGB to BGR if required
    image_np = np.expand_dims(image_np, axis=0)  # Add batch dimension
    return image_np

def draw_boxes(image, boxes, classes, scores, threshold=0.5):
    """
    Draw bounding boxes and labels on the image.

    Args:
        image (PIL.Image): The original image.
        boxes (np.array): Array of bounding boxes.
        classes (np.array): Array of class IDs.
        scores (np.array): Array of confidence scores.
        threshold (float): Confidence threshold to filter detections.

    Returns:
        PIL.Image: Annotated image.
    """
    draw = ImageDraw.Draw(image)
    try:
        font = ImageFont.truetype("arial.ttf", 15)
    except IOError:
        font = ImageFont.load_default()
    for box, cls, score in zip(boxes[0], classes[0], scores[0]):
        if score < threshold:
            continue
        # Convert box coordinates from normalized to absolute
        ymin, xmin, ymax, xmax = box
        left = xmin * image.width
        right = xmax * image.width
        top = ymin * image.height
        bottom = ymax * image.height
        # Draw rectangle
        draw.rectangle([(left, top), (right, bottom)], outline="red", width=2)
        # Prepare label
        label = f"{labels[int(cls) - 1]}: {score:.2f}"
        # Calculate text size using textbbox
        text_bbox = draw.textbbox((0, 0), label, font=font)
        text_width = text_bbox[2] - text_bbox[0]
        text_height = text_bbox[3] - text_bbox[1]
        # Draw label background
        draw.rectangle([(left, top - text_height - 4), (left + text_width + 4, top)], fill="red")
        # Draw text
        draw.text((left + 2, top - text_height - 2), label, fill="white", font=font)
    return image

def predict(image):
    """
    Perform inference on the input image and return the annotated image.

    Args:
        image (PIL.Image): Uploaded image.

    Returns:
        PIL.Image: Annotated image with bounding boxes and labels.
    """
    try:
        # Preprocess the image
        input_array = preprocess_image(image)
        # Run inference
        boxes, classes, scores = sess.run(
            [detected_boxes, detected_classes, detected_scores],
            feed_dict={input_tensor: input_array}
        )
        # Annotate the image with bounding boxes and labels
        annotated_image = draw_boxes(image.copy(), boxes, classes, scores, threshold=0.5)
        return annotated_image
    except Exception as e:
        # Log the error and return a placeholder image indicating the failure
        print(f"Error during prediction: {e}")
        error_image = Image.new('RGB', (500, 500), color=(255, 0, 0))
        draw = ImageDraw.Draw(error_image)
        try:
            font = ImageFont.truetype("arial.ttf", 20)
        except IOError:
            font = ImageFont.load_default()
        error_text = "Error during prediction."
        # Calculate text size using textbbox
        text_bbox = draw.textbbox((0, 0), error_text, font=font)
        text_width = text_bbox[2] - text_bbox[0]
        text_height = text_bbox[3] - text_bbox[1]
        # Center the text
        draw.rectangle(
            [
                ((500 - text_width) / 2 - 10, (500 - text_height) / 2 - 10),
                ((500 + text_width) / 2 + 10, (500 + text_height) / 2 + 10)
            ],
            fill="black"
        )
        draw.text(
            ((500 - text_width) / 2, (500 - text_height) / 2),
            error_text,
            fill="white",
            font=font
        )
        return error_image

# Define Gradio interface using the new API
title = "JunkWaxHero 🦸♂️ - Baseball Card Set Identifier"
description = "Upload an image of a baseball card, and JunkWaxHero will identify the set it belongs to with high accuracy."

# Verify that example images exist to prevent FileNotFoundError
example_images = ["examples/card1.jpg", "examples/card2.jpg", "examples/card3.jpg"]
valid_examples = [img for img in example_images if os.path.exists(img)]
if not valid_examples:
    valid_examples = None  # Remove examples if none exist

iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title=title,
    description=description,
    examples=valid_examples,
    flagging_mode="never"  # Updated parameter
)

if __name__ == "__main__":
    iface.launch()
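
# Minimal local smoke test (a hypothetical sketch, not part of the Space runtime).
# It assumes a card photo exists at the path below; adjust to any image on disk:
#
#   from PIL import Image
#   test_image = Image.open("examples/card1.jpg")
#   annotated = predict(test_image)
#   annotated.save("annotated_card1.jpg")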