Spaces:
Sleeping
Sleeping
File size: 2,272 Bytes
60451f1 0da183e 60451f1 0da183e 60451f1 0da183e 60451f1 3c8d862 0da183e 60451f1 0da183e 3df3cdb 0da183e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import os
import tarfile
import wandb
import gradio as gr
import numpy as np
from PIL import Image
import tensorflow as tf
from transformers import ViTFeatureExtractor
# Pretrained ViT checkpoint: supplies the feature-extractor config
# (normalization mean/std) matching the fine-tuned model used below.
PRETRAIN_CHECKPOINT = "google/vit-base-patch16-224-in21k"
feature_extractor = ViTFeatureExtractor.from_pretrained(PRETRAIN_CHECKPOINT)

# Weights & Biases API key; must be present in the environment
# (raises KeyError at import time otherwise — fail fast is intended here).
WB_KEY = os.environ['WB_KEY']

MODEL = None  # lazily downloaded/loaded on first prediction (see get_predictions)
RESOLTUION = 224  # input resolution the ViT model expects (name kept: referenced below)

# One class label per line. rstrip("\n") removes only the trailing newline;
# the previous line[:-1] would chop the last character of a label when the
# file does not end with a newline.
labels = []
with open("labels.txt", "r", encoding="utf-8") as fp:
    for line in fp:
        labels.append(line.rstrip("\n"))
def normalize_img(
    img, mean=feature_extractor.image_mean, std=feature_extractor.image_std
):
    """Scale pixels to [0, 1], then standardize with the extractor's mean/std.

    Defaults are captured from the feature extractor at import time, so the
    normalization always matches the pretrained checkpoint's statistics.
    """
    scaled = img / 255
    return (scaled - tf.constant(mean)) / tf.constant(std)
def preprocess_input(image):
    """Turn a PIL image/array into the batched, channel-first dict a HF ViT takes."""
    tensor = tf.convert_to_tensor(np.array(image))
    tensor = tf.image.resize(tensor, (RESOLTUION, RESOLTUION))
    tensor = normalize_img(tensor)
    # HF TF models are channel-first: (H, W, C) -> (C, H, W).
    tensor = tf.transpose(tensor, (2, 0, 1))
    # Add the batch dimension expected by model.predict.
    return {"pixel_values": tf.expand_dims(tensor, 0)}
def get_predictions(image):
    """Classify *image* and return {label: probability} for Gradio's Label widget.

    On the first call the model archive is fetched from the W&B artifact
    store, extracted, and loaded; the loaded model is cached in the
    module-level MODEL global so later calls skip the download.
    """
    global MODEL
    if MODEL is None:
        wandb.login(key=WB_KEY)
        wandb.init(project="tfx-vit-pipeline", id="gvtyqdgn", resume=True)
        path = wandb.use_artifact(
            'tfx-vit-pipeline/final_model:1688113391', type='model'
        ).download()
        # Close the archive deterministically instead of leaking the handle.
        # NOTE(review): extractall on a downloaded archive is a path-traversal
        # risk in general; the artifact is project-owned here, but consider
        # tar.extractall(path=".", filter="data") on Python 3.12+.
        with tarfile.open(f"{path}/model.tar.gz") as tar:
            tar.extractall(path=".")
        MODEL = tf.keras.models.load_model("./model")

    preprocessed_image = preprocess_input(image)
    prediction = MODEL.predict(preprocessed_image)
    probs = tf.nn.softmax(prediction['logits'], axis=1)
    # Derive the class count from the label file instead of hard-coding 3,
    # so a changed labels.txt does not silently drop (or over-index) classes.
    return {labels[i]: float(probs[0][i]) for i in range(len(labels))}
# Assemble the Gradio UI: input image + top-3 label output side by side,
# a classify button wired to get_predictions, and cached example images.
with gr.Blocks() as demo:
    gr.Markdown(
        "## Simple demo for a Image Classification of the Beans Dataset with HF ViT model"
    )
    with gr.Row():
        image_if = gr.Image()
        label_if = gr.Label(num_top_classes=3)
    classify_if = gr.Button()
    classify_if.click(fn=get_predictions, inputs=image_if, outputs=label_if)
    gr.Examples(
        examples=[["test_image1.jpeg"], ["test_image2.jpeg"], ["test_image3.jpeg"]],
        inputs=[image_if],
        outputs=[label_if],
        fn=get_predictions,
        cache_examples=True,
    )

demo.launch(debug=True)