Spaces:
Sleeping
Sleeping
File size: 6,195 Bytes
ecce323 34f5d81 7a9ce47 34f5d81 ecce323 2d23a8b 7a9ce47 ecce323 7a9ce47 2d23a8b 7a9ce47 2d23a8b 7a9ce47 2d23a8b 7a9ce47 2d23a8b 7a9ce47 123d436 7a9ce47 2d23a8b 987957a 2d23a8b 987957a 34f5d81 987957a 7a9ce47 34f5d81 2d23a8b 34f5d81 7a9ce47 987957a 34f5d81 2d23a8b 34f5d81 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
# app.py
import logging
import os

import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image, ImageDraw, ImageFont
# Suppress TensorFlow's native INFO and WARNING messages for cleaner logs.
# NOTE(review): these assignments run after `import tensorflow`. TensorFlow appears to
# read TF_CPP_MIN_LOG_LEVEL once, at its first native log line, and that may already
# have happened during import. For a guaranteed effect, move both assignments above
# the TensorFlow import.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# Hide every CUDA device so TensorFlow runs on the CPU only.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

# Load the class labels, one per line. draw_boxes looks labels up by class ID, so line
# order matters. The encoding is explicit so that decoding does not depend on the
# host's locale.
with open('tensorflow/labels.txt', 'r', encoding='utf-8') as f:
    labels = f.read().splitlines()
def load_frozen_graph(pb_file_path):
    """Deserialize a frozen GraphDef file and import it into a brand-new graph.

    Args:
        pb_file_path (str): Path to the serialized ``GraphDef`` (``.pb``) file.

    Returns:
        tf.Graph: A fresh graph holding the imported nodes under their original
        names (the import uses an empty name prefix).
    """
    graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(pb_file_path, 'rb') as pb_file:
        serialized = pb_file.read()
    graph_def.ParseFromString(serialized)

    frozen_graph = tf.Graph()
    # Import into our own graph rather than the process-wide default graph.
    with frozen_graph.as_default():
        tf.import_graph_def(graph_def, name='')
    return frozen_graph
# Load the frozen graph once at import time and keep one long-lived session
# that every request reuses.
MODEL_DIR = 'tensorflow/'
MODEL_PATH = os.path.join(MODEL_DIR, 'model.pb')
graph = load_frozen_graph(MODEL_PATH)
sess = tf.compat.v1.Session(graph=graph)

# Look up, by name, the input tensor fed at inference time and the three
# detection outputs fetched from the session.
input_tensor, detected_boxes, detected_classes, detected_scores = (
    graph.get_tensor_by_name(tensor_name)
    for tensor_name in (
        'image_tensor:0',
        'detected_boxes:0',
        'detected_classes:0',
        'detected_scores:0',
    )
)

# Every input image is resized to this width and height before inference.
TARGET_WIDTH, TARGET_HEIGHT = 320, 320
def preprocess_image(image):
    """
    Turn a PIL image into the batched float32 array that is fed to the model.

    Steps, in order:
    - Convert to RGB if the image is in another mode (grayscale, palette, RGBA, ...)
    - Resize to TARGET_WIDTH x TARGET_HEIGHT (aspect ratio is not preserved)
    - Convert to a float32 numpy array and scale pixel values by 1/255
    - Reorder the channels from RGB to BGR (always applied)
    - Add a leading batch dimension

    Args:
        image (PIL.Image): Input image, in any PIL mode.
    Returns:
        np.ndarray: float32 array of shape (1, TARGET_HEIGHT, TARGET_WIDTH, 3).
    """
    # Without this conversion, grayscale and palette images become 2-D arrays and the
    # channel reorder below raises IndexError.
    if image.mode != "RGB":
        image = image.convert("RGB")
    image = image.resize((TARGET_WIDTH, TARGET_HEIGHT))
    image_np = np.array(image).astype(np.float32)
    # NOTE(review): the [0, 1] scaling and the BGR channel order are applied
    # unconditionally. Confirm that both match what the exported model expects.
    image_np = image_np / 255.0
    image_np = image_np[:, :, (2, 1, 0)]  # RGB -> BGR
    return np.expand_dims(image_np, axis=0)  # add batch dimension
def draw_boxes(image, boxes, classes, scores, threshold=0.5, label_names=None):
    """
    Draw bounding boxes and "<label>: <score>" captions onto the image, in place.

    Args:
        image (PIL.Image): Image to annotate. It is modified in place and also returned.
        boxes (np.array): Box output. Only ``boxes[0]`` (the first batch entry) is read.
            Each box is unpacked as (ymin, xmin, ymax, xmax) and scaled by the image
            size, i.e. treated as normalized coordinates.
            NOTE(review): confirm this ordering against the exported model.
        classes (np.array): Class IDs. Only ``classes[0]`` is read. IDs are treated as
            1-based indices into the label list.
        scores (np.array): Confidence scores. Only ``scores[0]`` is read.
        threshold (float): Detections scoring below this value are skipped.
        label_names (list[str] | None): Label list to look IDs up in. Defaults to the
            module-level ``labels``.
    Returns:
        PIL.Image: The same image object, annotated.
    """
    names = labels if label_names is None else label_names
    draw = ImageDraw.Draw(image)
    try:
        font = ImageFont.truetype("arial.ttf", 15)
    except IOError:
        # Fall back to Pillow's built-in font when Arial is not installed.
        font = ImageFont.load_default()
    for box, cls, score in zip(boxes[0], classes[0], scores[0]):
        if score < threshold:
            continue
        # Convert box coordinates from normalized to absolute pixels.
        ymin, xmin, ymax, xmax = box
        left = xmin * image.width
        right = xmax * image.width
        top = ymin * image.height
        bottom = ymax * image.height
        draw.rectangle([(left, top), (right, bottom)], outline="red", width=2)

        # Guard the lookup. An ID of 0 would otherwise wrap to the last label through
        # negative indexing. An ID past the end would raise IndexError and discard
        # every other detection in the image.
        class_index = int(cls) - 1
        if 0 <= class_index < len(names):
            class_name = names[class_index]
        else:
            class_name = f"class {int(cls)}"
        label = f"{class_name}: {score:.2f}"

        # textbbox at the origin returns the ink extent, including the font's own offset.
        text_x0, text_y0, text_x1, text_y1 = draw.textbbox((0, 0), label, font=font)
        text_width = text_x1 - text_x0
        text_height = text_y1 - text_y0

        # Put the caption just above the box. If that would fall off the top of the
        # image, draw it inside the box instead so it stays visible.
        label_top = top - text_height - 4
        if label_top < 0:
            label_top = max(top, 0)
        draw.rectangle(
            [(left, label_top), (left + text_width + 4, label_top + text_height + 4)],
            fill="red",
        )
        # Subtract the bbox origin so the glyphs land inside the background rectangle.
        draw.text(
            (left + 2 - text_x0, label_top + 2 - text_y0),
            label,
            fill="white",
            font=font,
        )
    return image
def _render_error_image(message, size=500):
    """
    Build a square red image with *message* in white on a centered black panel.

    Args:
        message (str): Text to display.
        size (int): Width and height of the image, in pixels.
    Returns:
        PIL.Image: The rendered RGB image.
    """
    error_image = Image.new('RGB', (size, size), color=(255, 0, 0))
    draw = ImageDraw.Draw(error_image)
    try:
        font = ImageFont.truetype("arial.ttf", 20)
    except IOError:
        font = ImageFont.load_default()
    text_x0, text_y0, text_x1, text_y1 = draw.textbbox((0, 0), message, font=font)
    text_width = text_x1 - text_x0
    text_height = text_y1 - text_y0
    # Top-left corner that centers the text's ink box in the image.
    text_x = (size - text_width) / 2
    text_y = (size - text_height) / 2
    draw.rectangle(
        [
            (text_x - 10, text_y - 10),
            (text_x + text_width + 10, text_y + text_height + 10),
        ],
        fill="black",
    )
    # Subtract the bbox origin so the glyphs sit inside the black panel.
    draw.text((text_x - text_x0, text_y - text_y0), message, fill="white", font=font)
    return error_image


def predict(image):
    """
    Run the detector on an uploaded image and return an annotated copy.

    Detections scoring below 0.5 are not drawn. If any step raises (preprocessing,
    running the session, or drawing), the traceback is logged and a generic
    500x500 red error image is returned instead.

    Args:
        image (PIL.Image): Uploaded image. It is not modified; boxes are drawn on a copy.
    Returns:
        PIL.Image: The annotated image, or the error image if prediction failed.
    """
    try:
        input_array = preprocess_image(image)
        boxes, classes, scores = sess.run(
            [detected_boxes, detected_classes, detected_scores],
            feed_dict={input_tensor: input_array},
        )
        return draw_boxes(image.copy(), boxes, classes, scores, threshold=0.5)
    except Exception:
        # This is the UI boundary. Keep the traceback in the server logs instead of
        # discarding it, and show the user only a generic message.
        logging.getLogger(__name__).exception("Prediction failed")
        return _render_error_image("Error during prediction.")
# Text shown at the top of the Gradio page.
title = "JunkWaxHero 🦸♂️ - Baseball Card Set Identifier"
description = "Upload an image of a baseball card, and JunkWaxHero will identify the set it belongs to with high accuracy."

# Offer only the bundled example images that actually exist on disk, so a missing
# file cannot break interface construction. An empty selection becomes None,
# meaning "no examples".
example_images = ["examples/card1.jpg", "examples/card2.jpg", "examples/card3.jpg"]
valid_examples = list(filter(os.path.exists, example_images)) or None

# One PIL image in, one annotated PIL image out.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title=title,
    description=description,
    examples=valid_examples,
    flagging_mode="never",  # replaces the older `allow_flagging` argument (Gradio 5+)
)

if __name__ == "__main__":
    iface.launch()
|