| | import json |
| | import os |
| | from functools import lru_cache |
| | from typing import Mapping |
| |
|
| | import gradio as gr |
| | from huggingface_hub import HfFileSystem, hf_hub_download |
| | from imgutils.data import ImageTyping, load_image |
| | from natsort import natsorted |
| |
|
| | from onnx_ import _open_onnx_model |
| | from preprocess import _img_encode |
| |
|
# Module-wide Hugging Face Hub filesystem client; Classification uses it to
# glob a repository for the available ``<model>/model.onnx`` files.
hfs = HfFileSystem()
| |
|
| |
|
@lru_cache()
def open_model_from_repo(repository, model):
    """Download and load one ONNX model plus its labels from a Hugging Face repo.

    Results are memoized per ``(repository, model)`` via ``lru_cache``, so a
    model is downloaded and loaded only once while it stays in the cache.

    :param repository: Hugging Face repository id.
    :param model: Sub-directory of the repository that contains
        ``model.onnx`` and ``meta.json``.
    :return: ``(runtime, labels)`` -- the object returned by
        ``_open_onnx_model`` for the downloaded ``model.onnx``, and the
        ``labels`` entry of ``meta.json``.
    """
    runtime = _open_onnx_model(hf_hub_download(repository, f'{model}/model.onnx'))
    # JSON is UTF-8 text; decode it explicitly so non-ASCII labels do not depend
    # on the platform's default locale encoding (e.g. cp1252 on Windows).
    with open(hf_hub_download(repository, f'{model}/meta.json'), 'r', encoding='utf-8') as f:
        labels = json.load(f)['labels']

    return runtime, labels
| |
|
| |
|
class Classification:
    """A Gradio tab that runs ONNX image classifiers hosted in a Hugging Face repo.

    Every ``<model>/model.onnx`` found one level below the repository root is
    offered as a selectable model; its labels are read from the sibling
    ``<model>/meta.json`` (see ``open_model_from_repo``).
    """

    def __init__(self, title: str, repository: str, default_model=None, imgsize: int = 384):
        """
        :param title: Label of the Gradio tab.
        :param repository: Hugging Face repository id to scan for models.
        :param default_model: Model preselected in the dropdown; when falsy, the
            first model in natural-sort order is used.
        :param imgsize: Default value of the inference-size slider.
        :raises ValueError: If the repository has no models and no
            ``default_model`` was given.
        """
        self.title = title
        self.repository = repository
        # Model names are the directory names relative to the repository root.
        self.models = natsorted([
            os.path.dirname(os.path.relpath(file, self.repository))
            for file in hfs.glob(f'{self.repository}/*/model.onnx')
        ])
        if not default_model and not self.models:
            raise ValueError(f'No model found in repository {self.repository!r}.')
        self.default_model = default_model or self.models[0]
        self.imgsize = imgsize

    def _open_onnx_model(self, model):
        """Return the cached ``(runtime, labels)`` pair for ``model`` in this repository."""
        return open_model_from_repo(self.repository, model)

    def _gr_classification(self, image: ImageTyping, model_name: str, size=384) -> Mapping[str, float]:
        """Classify a single image with the selected model.

        :param image: Input image (anything ``load_image`` accepts).
        :param model_name: Model sub-directory name within the repository.
        :param size: Side length of the square the image is resized to.
        :return: Mapping from each label to the model's output value for it.
        :raises gr.Error: If no image was provided.
        """
        if image is None:
            # Submitting with an empty image box passes None; show a readable
            # message in the UI instead of a stack trace from load_image.
            raise gr.Error('Please upload an image first.')
        image = load_image(image, mode='RGB')
        size = int(size)  # slider values may arrive as floats; resizing needs ints
        input_ = _img_encode(image, size=(size, size))[None, ...]  # add a leading (batch) axis
        model, labels = self._open_onnx_model(model_name)
        output, = model.run(['output'], {'input': input_})

        # Pair each label with the first (only) row of the model output.
        return {label: value.item() for label, value in zip(labels, output[0])}

    def create_gr(self):
        """Build this classifier's tab: image, model and size inputs plus a label output."""
        with gr.Tab(self.title):
            with gr.Row():
                with gr.Column():
                    gr_input_image = gr.Image(type='pil', label='Original Image')
                    gr_model = gr.Dropdown(self.models, value=self.default_model, label='Model')
                    # Default to the size configured in __init__ rather than a fixed 384.
                    gr_infer_size = gr.Slider(224, 640, value=self.imgsize, step=32, label='Infer Size')
                    gr_submit = gr.Button(value='Submit', variant='primary')

                with gr.Column():
                    gr_output = gr.Label(label='Classes')

            gr_submit.click(
                self._gr_classification,
                inputs=[gr_input_image, gr_model, gr_infer_size],
                outputs=[gr_output],
            )
| |
|