chansung's picture
upload 1688113391 model
3c8d862
raw
history blame
2.15 kB
import tarfile
import wandb
import gradio as gr
import numpy as np
from PIL import Image
import tensorflow as tf
from transformers import ViTFeatureExtractor
# Hugging Face checkpoint whose feature extractor supplies the image
# mean/std used as defaults by normalize_img.
PRETRAIN_CHECKPOINT = "google/vit-base-patch16-224-in21k"
feature_extractor = ViTFeatureExtractor.from_pretrained(PRETRAIN_CHECKPOINT)

# Lazily populated on the first call to get_predictions.
MODEL = None
# Side length (pixels) images are resized to before inference.
# NOTE: the name is a misspelling of RESOLUTION; it is kept because other
# code in this file refers to it by this spelling.
RESOLTUION = 224

# One class label per line. Only the line terminator is stripped, so a final
# line without a trailing newline keeps its last character (slicing with
# [:-1] would silently truncate it).
with open(r"labels.txt", "r") as fp:
    labels = [line.rstrip("\n") for line in fp]
def normalize_img(
    img, mean=feature_extractor.image_mean, std=feature_extractor.image_std
):
    """Rescale pixel values by 1/255 and standardize them with mean/std.

    Args:
        img: Image tensor to normalize. Presumably channels-last, so that the
            per-channel statistics broadcast over the last axis
            (preprocess_input calls this before its channel transpose).
        mean: Values subtracted after rescaling; defaults to the feature
            extractor's ``image_mean``.
        std: Values the centered image is divided by; defaults to the
            feature extractor's ``image_std``.

    Returns:
        The tensor ``(img / 255 - mean) / std``.
    """
    scaled = img / 255
    mean_t = tf.constant(mean)
    std_t = tf.constant(std)
    centered = scaled - mean_t
    return centered / std_t
def preprocess_input(image):
    """Turn a single input image into the dict the model's predict() receives.

    The image is converted to a tensor, resized to RESOLTUION x RESOLTUION,
    normalized with normalize_img, transposed to channel-first order and
    given a leading batch axis.

    Args:
        image: Array-like image; presumably height x width x channels, since
            the transpose below moves axis 2 to the front.

    Returns:
        A dict with the single key ``"pixel_values"`` holding the batched
        tensor.
    """
    pixels = tf.convert_to_tensor(np.array(image))
    resized = tf.image.resize(pixels, (RESOLTUION, RESOLTUION))
    normalized = normalize_img(resized)

    # Since HF models are channel-first.
    channels_first = tf.transpose(normalized, (2, 0, 1))

    batched = tf.expand_dims(channels_first, 0)
    return {"pixel_values": batched}
def _load_model(wb_token):
    """Download the model artifact from Weights & Biases and load it with Keras."""
    wandb.login(key=wb_token)
    wandb.init(project="tfx-vit-pipeline", id="gvtyqdgn", resume=True)
    path = wandb.use_artifact(
        'tfx-vit-pipeline/final_model:1688113391', type='model'
    ).download()

    # The context manager closes the archive even if extraction fails.
    # NOTE(review): extractall() trusts member paths inside the archive; this
    # is acceptable only as long as the artifact source stays trusted.
    with tarfile.open(f"{path}/model.tar.gz") as tar:
        tar.extractall(path=".")

    return tf.keras.models.load_model("./model")


def get_predictions(wb_token, image):
    """Classify an image and return a label -> confidence mapping.

    On the first call the model is fetched with the given Weights & Biases
    API key and cached in the module-level MODEL; later calls reuse the
    cached model and do not use ``wb_token``.

    Args:
        wb_token: Weights & Biases API key used to download the model.
        image: Input image passed to preprocess_input.

    Returns:
        A dict mapping each label to the softmax probability (as a Python
        float) of the corresponding logit of the first batch item.
    """
    global MODEL

    if MODEL is None:
        MODEL = _load_model(wb_token)

    preprocessed_image = preprocess_input(image)
    prediction = MODEL.predict(preprocessed_image)
    probs = tf.nn.softmax(prediction['logits'], axis=1)

    # Pair labels with probabilities instead of assuming exactly three
    # classes; zip stops at the shorter sequence, so a label/logit count
    # mismatch can no longer raise IndexError.
    return {
        label: float(prob) for label, prob in zip(labels, probs[0].numpy())
    }
# Gradio UI: an API-key field, an image input next to a label output, and a
# button that runs get_predictions on the current inputs.
with gr.Blocks() as demo:
    gr.Markdown("## Simple demo for a Image Classification of the Beans Dataset with HF ViT model")

    # The API key is a secret, so mask it instead of echoing it in plain text.
    wb_token_if = gr.Textbox(
        interactive=True,
        label="Your Weight & Biases API Key",
        type="password",
    )

    with gr.Row():
        image_if = gr.Image()
        label_if = gr.Label(num_top_classes=3)

    classify_if = gr.Button()

    # Inputs are passed positionally to get_predictions(wb_token, image); the
    # returned dict is rendered by the Label component.
    classify_if.click(
        get_predictions,
        [wb_token_if, image_if],
        label_if
    )

demo.launch(debug=True)