# RobustViT / app.py
# Commit e0d8c59 by Hila: "add initial app.py" (492 bytes).
import torch
import timm
import gradio as gr
"""
from ViT.ViT_new import vit_base_patch16_224 as vit
model = vit(pretrained=True).cuda()
model.eval()
model_finetuned = vit().cuda()
checkpoint = torch.load('ar_base.tar')
model_finetuned.load_state_dict(checkpoint['state_dict'])
model_finetuned.eval()
iface_orig = gr.Interface(
)
"""
def image_classifier(inp):
    """Placeholder prediction function for the Gradio demo.

    Args:
        inp: whatever Gradio supplies for an "image" input. This is presumably
            a numpy array under Gradio's default image settings (unverified).
            The argument is currently ignored.

    Returns:
        None, always. No classifier model is wired in yet.
        TODO: run the ViT model(s) and return a label -> confidence mapping.
    """
    # The image classifier model is meant to be defined and invoked here.
    return None
# Build the Interface at module level so `demo` stays importable: an image
# input, a "label" output, and `image_classifier` as the prediction function.
demo = gr.Interface(image_classifier, "image", "label")

# Start the (blocking) web server only when this file is run as a script.
# Merely importing the module should not launch a server.
if __name__ == "__main__":
    # share=True also requests a public share link for the app.
    demo.launch(share=True)