chansung commited on
Commit
a84e904
·
1 Parent(s): 162393d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -2
app.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  import gradio as gr
2
  import numpy as np
3
  from PIL import Image
@@ -42,7 +45,16 @@ def preprocess_input(image):
42
  "pixel_values": tf.expand_dims(image, 0)
43
  }
44
 
45
- def get_predictions(image):
 
 
 
 
 
 
 
 
 
46
  preprocessed_image = preprocess_input(image)
47
  prediction = MODEL.predict(preprocessed_image)
48
  probs = tf.nn.softmax(prediction['logits'], axis=1)
@@ -57,10 +69,16 @@ with gr.Blocks() as demo:
57
 
58
  with gr.Row():
59
  image_if = gr.Image()
60
- label_if = gr.Label()
61
 
62
  classify_if = gr.Button()
63
 
 
 
 
 
 
 
64
  # demo = gr.Interface(
65
  # get_predictions,
66
  # gr.inputs.Image(),
 
1
+ import tarfile
2
+ import wandb
3
+
4
  import gradio as gr
5
  import numpy as np
6
  from PIL import Image
 
45
  "pixel_values": tf.expand_dims(image, 0)
46
  }
47
 
48
+ def get_predictions(wb_token, image):
49
+ if MODEL is None:
50
+ wandb.login(key=wb_token)
51
+ path = wandb.use_artifact('tfx-vit-pipeline/final_model:latest', type='model').download()
52
+
53
+ tar = tarfile.open(f"{path}/model.tar.gz")
54
+ tar.extractall(path=".")
55
+
56
+ MODEL = tf.keras.models.load_model("./model")
57
+
58
  preprocessed_image = preprocess_input(image)
59
  prediction = MODEL.predict(preprocessed_image)
60
  probs = tf.nn.softmax(prediction['logits'], axis=1)
 
69
 
70
  with gr.Row():
71
  image_if = gr.Image()
72
+ label_if = gr.Label(num_top_classes=3)
73
 
74
  classify_if = gr.Button()
75
 
76
+ classify_if.click(
77
+ get_predictions,
78
+ [wb_token_if, image_if],
79
+ label_if
80
+ )
81
+
82
  # demo = gr.Interface(
83
  # get_predictions,
84
  # gr.inputs.Image(),