Eric P. Nusbaum
Fix Dependencies
ecce323
raw
history blame
2.41 kB
# app.py
import os

# Suppress TensorFlow logging for cleaner logs.
# The native TensorFlow runtime picks this variable up while it is being loaded,
# so it has to be set *before* `import tensorflow`; setting it afterwards
# leaves the C++ INFO/WARNING output untouched.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import gradio as gr  # noqa: E402
import numpy as np  # noqa: E402
import tensorflow as tf  # noqa: E402
from PIL import Image  # noqa: E402
# Load labels: one class name per line; the list index of each name is used as
# the class index when decoding the model output.
# The encoding is pinned so the file decodes identically on every host instead
# of depending on the platform's default locale encoding.
with open('tensorflow/labels.txt', 'r', encoding='utf-8') as f:
    labels = f.read().splitlines()
def load_frozen_graph(pb_file_path):
    """Read a serialized GraphDef from disk and import it into a new graph.

    Args:
        pb_file_path: Path to the frozen ``.pb`` file.

    Returns:
        A ``tf.Graph`` holding the imported nodes. Nodes are imported with an
        empty name prefix, so tensor names match those stored in the file.
    """
    graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(pb_file_path, 'rb') as pb_file:
        serialized = pb_file.read()
    graph_def.ParseFromString(serialized)

    frozen_graph = tf.Graph()
    with frozen_graph.as_default():
        tf.import_graph_def(graph_def, name='')
    return frozen_graph
# Load the TensorFlow model
# The frozen graph is loaded once when the module is imported, and a single
# TF1-style session bound to that graph is created here; predict() runs every
# request through this same session.
MODEL_DIR = 'tensorflow/'
graph = load_frozen_graph(os.path.join(MODEL_DIR, 'model.pb'))
sess = tf.compat.v1.Session(graph=graph)
# Identify input and output tensor names
# You may need to adjust these based on your model's actual tensor names
# NOTE(review): 'input:0' and 'output:0' look like template placeholders, not
# names read from model.pb — confirm them against the exported graph, since
# both lookups run at import time.
input_tensor = graph.get_tensor_by_name('input:0') # Replace 'input:0' with your actual input tensor name
output_tensor = graph.get_tensor_by_name('output:0') # Replace 'output:0' with your actual output tensor name
def preprocess_image(image, target_size=(320, 320)):
    """Convert a PIL image into a normalized array for the model.

    Args:
        image: PIL image to prepare.
        target_size: ``(width, height)`` passed to ``Image.resize``. The
            default of 320x320 is the size this app was written with;
            replace it with your model's expected input size.

    Returns:
        A NumPy array of the resized pixels scaled from 0-255 to 0.0-1.0.
        No batch dimension is added here.
    """
    # Normalise the channel layout first: RGBA, palette ("P") and grayscale
    # images would otherwise produce arrays with 4, 1 or no channel axis.
    # NOTE(review): assumes the model takes 3-channel RGB input — confirm.
    rgb_image = image.convert('RGB')
    resized = rgb_image.resize(target_size)
    pixels = np.array(resized)
    return pixels / 255.0  # Normalize if required
def predict(image):
    """Classify a card image and return the top label with its score.

    Args:
        image: PIL image handed over by the Gradio input component.

    Returns:
        A one-entry dict ``{label: score}`` where ``score`` is the model's
        output value for the winning class as a builtin ``float``.
    """
    batch = np.expand_dims(preprocess_image(image), axis=0)  # Add batch dimension
    predictions = sess.run(output_tensor, feed_dict={input_tensor: batch})
    # Only one image is fed, so row 0 holds the per-class scores.
    scores = predictions[0]
    predicted_index = int(np.argmax(scores))
    # The score is returned unscaled: Gradio's Label component expects
    # confidences as fractions in [0, 1] and does the percent formatting
    # itself, so multiplying by 100 here would display e.g. "9500%".
    # float() turns the NumPy scalar into a plain Python float so the result
    # serializes cleanly.
    return {labels[predicted_index]: float(scores[predicted_index])}
# Define Gradio interface
title = "JunkWaxHero 🦸‍♂️ - Baseball Card Set Identifier"
description = "Upload an image of a baseball card, and JunkWaxHero will identify the set it belongs to with high accuracy."

# Components come from the top-level gradio namespace: the `gr.inputs` /
# `gr.outputs` modules were deprecated in Gradio 3 and removed in Gradio 4,
# while `gr.Image` / `gr.Label` are available from Gradio 3 onwards.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),  # hand predict() a PIL image
    outputs=gr.Label(num_top_classes=1),  # show only the best-scoring class
    title=title,
    description=description,
    examples=[
        ["examples/card1.jpg"],
        ["examples/card2.jpg"],
        ["examples/card3.jpg"],
    ],
    allow_flagging="never",
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()