import torch | |
import timm | |
import gradio as gr | |
# Reference setup for the original ViT-based classifier -- currently DISABLED.
# It is kept as comments (rather than a bare string literal, which is a no-op
# statement and not a docstring at this position) so readers and linters can
# see it is inactive code.
# It needs a project-local ``ViT.ViT_new`` module, a CUDA device (``.cuda()``)
# and a checkpoint file 'ar_base.tar' holding a 'state_dict' entry; nothing in
# it is used while ``image_classifier`` below remains a stub.
#
# from ViT.ViT_new import vit_base_patch16_224 as vit
# model = vit(pretrained=True).cuda()
# model.eval()
# model_finetuned = vit().cuda()
# # NOTE(review): torch.load unpickles the file -- only load trusted
# # checkpoints (or pass weights_only=True where supported).
# checkpoint = torch.load('ar_base.tar')
# model_finetuned.load_state_dict(checkpoint['state_dict'])
# model_finetuned.eval()
# # TODO: incomplete -- gr.Interface requires at least fn, inputs and outputs
# # before this block can be re-enabled.
# iface_orig = gr.Interface(
# )
def image_classifier(inp):
    """Placeholder prediction function for the Gradio demo.

    The real image classifier model is meant to be defined here; until then
    the input is ignored and nothing is predicted.

    Args:
        inp: The image passed in by the Gradio "image" input component.

    Returns:
        None, always (no prediction is produced yet).
    """
    del inp  # unused until a model is wired in
    return None
# Build the web UI: one image input feeding ``image_classifier``, whose
# result is rendered by a "label" output component.
demo = gr.Interface(
    fn=image_classifier,
    inputs="image",
    outputs="label",
)

# Start the server as soon as this module is executed; ``share=True`` asks
# Gradio for a public share link in addition to the local URL.
demo.launch(share=True)