"""Gradio demo: beans-leaf image classification with a HF ViT model.

The Keras model is produced by a TFX pipeline, versioned as a W&B artifact,
and downloaded lazily on the first prediction request.
"""

import os
import tarfile

import gradio as gr
import numpy as np
import tensorflow as tf
import wandb
from PIL import Image  # noqa: F401 — kept for parity with original imports.
from transformers import ViTFeatureExtractor

# Checkpoint whose preprocessing statistics (image_mean / image_std) we reuse
# so inference-time normalization matches how the model was trained.
PRETRAIN_CHECKPOINT = "google/vit-base-patch16-224-in21k"
feature_extractor = ViTFeatureExtractor.from_pretrained(PRETRAIN_CHECKPOINT)

WB_KEY = os.environ["WB_KEY"]  # Weights & Biases API key (must be set).

MODEL = None  # Loaded lazily on the first call to get_predictions().
RESOLUTION = 224  # ViT-base/16 expects 224x224 inputs. (Fixed typo: RESOLTUION.)

# One class name per line. rstrip("\n") instead of line[:-1]: the latter would
# clip the final character of the last label if the file has no trailing newline.
with open("labels.txt", "r") as fp:
    labels = [line.rstrip("\n") for line in fp]


def normalize_img(
    img, mean=feature_extractor.image_mean, std=feature_extractor.image_std
):
    """Scale pixels to [0, 1] then standardize with the checkpoint's stats.

    Args:
        img: float-compatible tensor/array with pixel values in [0, 255].
        mean: per-channel mean (defaults to the pretrained checkpoint's).
        std: per-channel std (defaults to the pretrained checkpoint's).

    Returns:
        The normalized tensor, same shape as `img`.
    """
    img = img / 255
    mean = tf.constant(mean)
    std = tf.constant(std)
    return (img - mean) / std


def preprocess_input(image):
    """Convert an input image to the batched, channel-first dict HF expects.

    Args:
        image: a PIL image (or anything np.array accepts) in HWC layout.

    Returns:
        {"pixel_values": tensor of shape (1, C, RESOLUTION, RESOLUTION)}.
    """
    image = np.array(image)
    image = tf.convert_to_tensor(image)
    image = tf.image.resize(image, (RESOLUTION, RESOLUTION))
    image = normalize_img(image)
    # HF vision models are channel-first: HWC -> CHW.
    image = tf.transpose(image, (2, 0, 1))
    return {"pixel_values": tf.expand_dims(image, 0)}


def get_predictions(image):
    """Classify `image`; on first use, fetch and load the model artifact.

    Returns:
        dict mapping every label name to its softmax probability.
    """
    global MODEL
    if MODEL is None:
        wandb.login(key=WB_KEY)
        wandb.init(project="tfx-vit-pipeline", id="gvtyqdgn", resume=True)
        path = wandb.use_artifact(
            'tfx-vit-pipeline/final_model:1688113391', type='model'
        ).download()
        # Context manager closes the archive (original leaked the file handle).
        with tarfile.open(f"{path}/model.tar.gz") as tar:
            tar.extractall(path=".")
        MODEL = tf.keras.models.load_model("./model")

    preprocessed_image = preprocess_input(image)
    prediction = MODEL.predict(preprocessed_image)
    probs = tf.nn.softmax(prediction['logits'], axis=1)
    # Cover all loaded labels instead of a hard-coded 3; behavior is identical
    # for the 3-class beans dataset but stays correct if labels.txt grows.
    confidences = {labels[i]: float(probs[0][i]) for i in range(len(labels))}
    return confidences


with gr.Blocks() as demo:
    gr.Markdown(
        "## Simple demo for a Image Classification of the Beans Dataset with HF ViT model"
    )
    with gr.Row():
        image_if = gr.Image()
        label_if = gr.Label(num_top_classes=3)
    classify_if = gr.Button()
    classify_if.click(get_predictions, image_if, label_if)
    gr.Examples(
        [["test_image1.jpeg"], ["test_image2.jpeg"], ["test_image3.jpeg"]],
        [image_if],
        [label_if],
        get_predictions,
        cache_examples=True,
    )

demo.launch(debug=True)