davanstrien's picture
davanstrien HF Staff
Update app.py
29dc431
raw
history blame
1.44 kB
import shutil
import tempfile
from pathlib import Path

import gradio as gr
import torch
from optimum.pipelines import pipeline
from transformers import AutoModelForImageClassification
# Pipeline device: 0 is the first CUDA GPU (the only one on a single-GPU
# host); otherwise run on the CPU.
device = 0 if torch.cuda.is_available() else "cpu"
chk_point = "davanstrien/autotrain-ia-useful-covers-3665397856"
try:
    # Prefer Optimum's pipeline with the BetterTransformer accelerator.
    pipe = pipeline(
        "image-classification",
        chk_point,
        accelerator="bettertransformer",
        device=device,
    )
except NotImplementedError:
    # BetterTransformer raised NotImplementedError (presumably the model
    # architecture is unsupported) - fall back to the plain transformers
    # pipeline.
    from transformers import pipeline

    pipe = pipeline(
        "image-classification",
        chk_point,
        device=device,
    )
# Reuse the model the pipeline already loaded instead of loading the
# checkpoint a second time; only its config (label names) is read elsewhere.
model = pipe.model
def make_label_folders():
    """Create one folder per model label in the current working directory.

    Label names are the keys of the module-level ``model.config.label2id``.
    Folders that already exist are left as they are.

    Returns:
        The label names (a view of ``model.config.label2id``'s keys).
    """
    folders = model.config.label2id.keys()
    for folder in folders:
        # exist_ok=True replaces the racy exists()/mkdir() pair: two
        # concurrent calls can no longer both see "missing" and then have
        # the second mkdir() raise FileExistsError.
        Path(folder).mkdir(exist_ok=True)
    return folders
def predictions_into_folders(files):
    """Sort uploaded images into per-label folders and return them as zips.

    Each call works in its own freshly created directory, so images from
    earlier or concurrent calls never end up in this call's archives (the
    previous shared label folders in the CWD accumulated every upload).

    Args:
        files: The uploaded files from the ``gr.File(file_count="directory")``
            input. Items may be objects exposing their path as ``.name``, or
            plain ``str``/``Path`` paths. ``None`` or empty means nothing was
            uploaded.

    Returns:
        A list of paths to ``<label>.zip`` archives, one per model label.
        Each archive contains a ``<label>/`` folder with the images whose top
        prediction was that label; labels with no images get an archive of an
        empty folder. Empty list when no files were given.
    """
    if not files:
        return []
    paths = [
        str(file) if isinstance(file, (str, Path)) else file.name
        for file in files
    ]
    labels = list(model.config.label2id.keys())
    # Unique per-call output directory, created under the CWD so the web app
    # can serve the resulting archives.
    out_dir = Path(tempfile.mkdtemp(prefix="predictions_", dir="."))
    for label in labels:
        (out_dir / label).mkdir()
    predictions = pipe(paths)
    for path, prediction in zip(paths, predictions):
        # prediction[0] is the first entry of the result list for this image
        # (its top-scoring label).
        label = prediction[0]["label"]
        shutil.copy(path, out_dir / label / Path(path).name)
    # root_dir/base_dir keep archive entries as "<label>/<file>".
    return [
        shutil.make_archive(str(out_dir / label), "zip", out_dir, label)
        for label in labels
    ]
# Web UI: upload a directory of images and get back one zip per predicted
# label. cache_examples was dropped - no examples are defined, so it was a
# misleading no-op.
demo = gr.Interface(
    fn=predictions_into_folders,
    inputs=gr.File(file_count="directory"),
    outputs=gr.Files(),
)
demo.launch()