# app.py import gradio as gr import tensorflow as tf import numpy as np from PIL import Image, ImageDraw, ImageFont import os # Suppress TensorFlow logging for cleaner logs os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Disable GPU usage explicitly to prevent TensorFlow from attempting to access GPU libraries os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # Load labels with open('tensorflow/labels.txt', 'r') as f: labels = f.read().splitlines() # Function to load the frozen TensorFlow graph def load_frozen_graph(pb_file_path): with tf.io.gfile.GFile(pb_file_path, 'rb') as f: graph_def = tf.compat.v1.GraphDef() graph_def.ParseFromString(f.read()) with tf.Graph().as_default() as graph: tf.import_graph_def(graph_def, name='') return graph # Load the TensorFlow model MODEL_DIR = 'tensorflow/' MODEL_PATH = os.path.join(MODEL_DIR, 'model.pb') graph = load_frozen_graph(MODEL_PATH) sess = tf.compat.v1.Session(graph=graph) # Define tensor names based on your model's outputs input_tensor = graph.get_tensor_by_name('image_tensor:0') detected_boxes = graph.get_tensor_by_name('detected_boxes:0') detected_classes = graph.get_tensor_by_name('detected_classes:0') detected_scores = graph.get_tensor_by_name('detected_scores:0') # Define the target size based on your model's expected input TARGET_WIDTH = 320 TARGET_HEIGHT = 320 def preprocess_image(image): """ Preprocess the input image: - Resize to target dimensions - Convert to numpy array - Normalize pixel values - Convert RGB to BGR if required by the model """ image = image.resize((TARGET_WIDTH, TARGET_HEIGHT)) image_np = np.array(image).astype(np.float32) image_np = image_np / 255.0 # Normalize to [0,1] image_np = image_np[:, :, (2, 1, 0)] # Convert RGB to BGR if required image_np = np.expand_dims(image_np, axis=0) # Add batch dimension return image_np def draw_boxes(image, boxes, classes, scores, threshold=0.5): """ Draw bounding boxes and labels on the image. Args: image (PIL.Image): The original image. boxes (np.array): Array of bounding boxes. classes (np.array): Array of class IDs. scores (np.array): Array of confidence scores. threshold (float): Confidence threshold to filter detections. Returns: PIL.Image: Annotated image. """ draw = ImageDraw.Draw(image) try: font = ImageFont.truetype("arial.ttf", 15) except IOError: font = ImageFont.load_default() for box, cls, score in zip(boxes[0], classes[0], scores[0]): if score < threshold: continue # Convert box coordinates from normalized to absolute ymin, xmin, ymax, xmax = box left = xmin * image.width right = xmax * image.width top = ymin * image.height bottom = ymax * image.height # Draw rectangle draw.rectangle([(left, top), (right, bottom)], outline="red", width=2) # Prepare label label = f"{labels[int(cls) - 1]}: {score:.2f}" # Calculate text size using textbbox text_bbox = draw.textbbox((0, 0), label, font=font) text_width = text_bbox[2] - text_bbox[0] text_height = text_bbox[3] - text_bbox[1] # Draw label background draw.rectangle([(left, top - text_height - 4), (left + text_width + 4, top)], fill="red") # Draw text draw.text((left + 2, top - text_height - 2), label, fill="white", font=font) return image def predict(image): """ Perform inference on the input image and return the annotated image. Args: image (PIL.Image): Uploaded image. Returns: PIL.Image: Annotated image with bounding boxes and labels. """ try: # Preprocess the image input_array = preprocess_image(image) # Run inference boxes, classes, scores = sess.run( [detected_boxes, detected_classes, detected_scores], feed_dict={input_tensor: input_array} ) # Annotate the image with bounding boxes and labels annotated_image = draw_boxes(image.copy(), boxes, classes, scores, threshold=0.5) return annotated_image except Exception as e: # Return an error image with the error message error_image = Image.new('RGB', (500, 500), color=(255, 0, 0)) draw = ImageDraw.Draw(error_image) try: font = ImageFont.truetype("arial.ttf", 20) except IOError: font = ImageFont.load_default() error_text = "Error during prediction." # Calculate text size using textbbox text_bbox = draw.textbbox((0, 0), error_text, font=font) text_width = text_bbox[2] - text_bbox[0] text_height = text_bbox[3] - text_bbox[1] # Center the text draw.rectangle( [ ((500 - text_width) / 2 - 10, (500 - text_height) / 2 - 10), ((500 + text_width) / 2 + 10, (500 + text_height) / 2 + 10) ], fill="black" ) draw.text( ((500 - text_width) / 2, (500 - text_height) / 2), error_text, fill="white", font=font ) return error_image # Define Gradio interface using the new API title = "JunkWaxHero 🦸‍♂️ - Baseball Card Set Identifier" description = "Upload an image of a baseball card, and JunkWaxHero will identify the set it belongs to with high accuracy." # Verify that example images exist to prevent FileNotFoundError example_images = ["examples/card1.jpg", "examples/card2.jpg", "examples/card3.jpg"] valid_examples = [img for img in example_images if os.path.exists(img)] if not valid_examples: valid_examples = None # Remove examples if none exist iface = gr.Interface( fn=predict, inputs=gr.Image(type="pil"), outputs=gr.Image(type="pil"), title=title, description=description, examples=valid_examples, flagging_mode="never" # Updated parameter ) if __name__ == "__main__": iface.launch()