# Spaces: Sleeping
# (Page-header text captured alongside the source; not part of the program.)
# app.py
import os

# Suppress TensorFlow logging for cleaner logs.
# This is set before TensorFlow is imported so that the setting is already
# in the environment when the native library starts up and emits its logs.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image
# Load labels: one class name per line. `predict` indexes this list with the
# model's arg-max output, so line order must match the model's class order.
with open('tensorflow/labels.txt', 'r', encoding='utf-8') as f:
    labels = f.read().splitlines()
# Function to load the frozen TensorFlow graph
def load_frozen_graph(pb_file_path):
    """Load a serialized ``GraphDef`` file into a new ``tf.Graph``.

    Args:
        pb_file_path: Path to the binary ``.pb`` file containing the
            serialized graph definition.

    Returns:
        A new ``tf.Graph`` with the graph definition imported under an empty
        name scope, so tensor names are not prefixed.
    """
    graph_def = tf.compat.v1.GraphDef()
    # Keep the file handle open only for as long as it takes to read the bytes.
    with tf.io.gfile.GFile(pb_file_path, 'rb') as f:
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        # name='' prevents the default "import/" prefix on node names.
        tf.import_graph_def(graph_def, name='')
    return graph
# Load the TensorFlow model
MODEL_DIR = 'tensorflow/'
graph = load_frozen_graph(os.path.join(MODEL_DIR, 'model.pb'))
# A single session bound to the loaded graph is created once at import time
# and reused by every call to `predict`.
sess = tf.compat.v1.Session(graph=graph)

# Identify input and output tensor names.
# You may need to adjust these based on your model's actual tensor names.
input_tensor = graph.get_tensor_by_name('input:0')  # Replace 'input:0' with your actual input tensor name
output_tensor = graph.get_tensor_by_name('output:0')  # Replace 'output:0' with your actual output tensor name
def preprocess_image(image, target_size=(320, 320)):
    """Resize an image and scale its pixel values by 1/255.

    Args:
        image: An object with a PIL-style ``resize(size)`` method whose
            result can be converted with ``np.array``.
        target_size: ``(width, height)`` passed to ``image.resize``. The
            default is a placeholder; replace it with your model's expected
            input size.

    Returns:
        A NumPy array holding the resized image data divided by 255.0.
    """
    # NOTE(review): no colour-mode conversion is done here; this assumes the
    # caller already supplies the channel layout the model expects - confirm.
    resized = image.resize(target_size)
    pixels = np.array(resized)
    return pixels / 255.0  # Normalize if required
def predict(image):
    """Classify one image and return the top label with its confidence.

    Args:
        image: The uploaded image, passed through ``preprocess_image``.

    Returns:
        A single-entry dict mapping the predicted label to its score as a
        plain Python float, taken directly from the model output (rounded to
        four decimals). The score is left on the model's own scale rather
        than multiplied by 100, because the Gradio label output formats
        confidences as percentages itself.

    Raises:
        ValueError: If ``image`` is ``None`` (nothing was uploaded).
    """
    if image is None:
        raise ValueError("No image provided; please upload an image.")

    batch = np.expand_dims(preprocess_image(image), axis=0)  # Add batch dimension
    predictions = sess.run(output_tensor, feed_dict={input_tensor: batch})

    # assumes the model output is (batch, num_classes) - TODO confirm
    predicted_index = int(np.argmax(predictions, axis=1)[0])
    predicted_label = labels[predicted_index]

    # Convert the NumPy scalar to a built-in float so the result is
    # JSON-serializable when Gradio sends it to the browser.
    confidence = float(predictions[0][predicted_index])
    return {predicted_label: round(confidence, 4)}
# Define Gradio interface
title = "JunkWaxHero 🦸♂️ - Baseball Card Set Identifier"
description = "Upload an image of a baseball card, and JunkWaxHero will identify the set it belongs to with high accuracy."

# Newer Gradio releases expose components at the top level (`gr.Image`,
# `gr.Label`) and no longer ship the legacy `gr.inputs` / `gr.outputs`
# namespaces. Prefer the top-level components and fall back to the legacy
# ones only when they are the only option available.
_image_component = getattr(gr, "Image", None) or gr.inputs.Image
_label_component = getattr(gr, "Label", None) or gr.outputs.Label

iface = gr.Interface(
    fn=predict,
    inputs=_image_component(type="pil"),
    outputs=_label_component(num_top_classes=1),
    title=title,
    description=description,
    examples=[
        ["examples/card1.jpg"],
        ["examples/card2.jpg"],
        ["examples/card3.jpg"],
    ],
    allow_flagging="never",
)
# Start the web UI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    iface.launch()