# Hugging Face Space status at time of capture: Runtime error
import inspect
import os

import gradio as gr
import torch
from datasets import load_dataset
from transformers import AutoImageProcessor, AutoModelForImageClassification, Trainer, TrainingArguments
# Load the images from the class sub-folders (label = folder name) and make a
# shuffled, stratified 80/20 train/test split. Slicing "train[:80%]" and
# "train[80%:]" without shuffling takes contiguous runs of files, which are
# listed in path order (i.e. grouped by class folder), so the test split would
# contain almost only the last class.
dataset = load_dataset("imagefolder", data_dir=".", split="train").train_test_split(
    test_size=0.2,
    seed=42,  # fixed seed so the split is reproducible across restarts
    stratify_by_column="label",  # keep class proportions equal in both splits
)
# Choose the base model. The previous id "google/vit-tiny-patch16-224" does not
# appear to be published on the Hugging Face Hub; vit-base-patch16-224 is an
# official ViT checkpoint.
checkpoint = "google/vit-base-patch16-224"
processor = AutoImageProcessor.from_pretrained(checkpoint)

# Take the label mapping from the dataset instead of hard-coding it.
# imagefolder numbers the class folders in sorted order (with folders named
# rock/paper/scissors that is paper=0, rock=1, scissors=2), so a hand-written
# {0: "rock", 1: "paper", 2: "scissors"} would mislabel every prediction.
label_names = dataset["train"].features["label"].names
id2label = dict(enumerate(label_names))
label2id = {name: idx for idx, name in id2label.items()}

model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=len(label_names),
    id2label=id2label,
    label2id=label2id,
    # The checkpoint ships an ImageNet classification head of a different size;
    # replace it with a freshly initialised head instead of failing to load.
    ignore_mismatched_sizes=True,
)
# Image preprocessing applied to every dataset row before training.
def preprocess(examples):
    """Turn a batch of dataset rows into model inputs.

    Intended for ``Dataset.map(..., batched=True)``, so each value in
    ``examples`` is a list with one entry per row.

    Args:
        examples: Batch dict with an ``"image"`` column (objects exposing
            ``.convert``, presumably PIL images) and a ``"label"`` column.

    Returns:
        The image processor's output (e.g. ``pixel_values``) with an extra
        ``"labels"`` entry copied from ``examples["label"]``.
    """
    rgb_batch = []
    for picture in examples["image"]:
        # Force three channels so grayscale/RGBA files match the processor's input.
        rgb_batch.append(picture.convert("RGB"))
    encoded = processor(images=rgb_batch, return_tensors="pt")
    encoded["labels"] = examples["label"]
    return encoded
# Run `preprocess` over every split in batches; the original columns are kept
# and the processor outputs plus "labels" are added next to them.
dataset = dataset.map(function=preprocess, batched=True)
# Training configuration.
# transformers renamed `evaluation_strategy` to `eval_strategy` (v4.41) and later
# removed the old name, so passing `evaluation_strategy` raises a TypeError on
# current releases. Use whichever keyword the installed version accepts.
_eval_strategy_kwarg = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments).parameters
    else "evaluation_strategy"
)
training_args = TrainingArguments(
    output_dir="./results",
    save_strategy="epoch",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=5,
    # Evaluation and checkpointing both run once per epoch; the best checkpoint
    # is reloaded when training finishes.
    load_best_model_at_end=True,
    logging_dir="./logs",
    logging_steps=5,
    **{_eval_strategy_kwarg: "epoch"},
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
)

# Fine-tune. This runs at module import, i.e. every time the app starts.
trainer.train()
# Inference on a single new image.
def predict(image):
    """Classify one image with the fine-tuned model.

    Args:
        image: The image supplied by the Gradio input. Presumably a NumPy array
            (Gradio's default for an "image" input) or a PIL image; ``None`` when
            the user submits without uploading anything.

    Returns:
        The predicted label name, or an empty string when no image was given.
    """
    if image is None:
        return ""
    if hasattr(image, "convert"):
        # PIL input: apply the same RGB conversion used during training.
        image = image.convert("RGB")

    # After training the model may live on a GPU; keep inputs on the same device.
    inputs = processor(images=image, return_tensors="pt").to(model.device)
    model.eval()  # disable dropout etc. regardless of the mode training left it in
    with torch.no_grad():  # inference only: skip building the autograd graph
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]
# Build the Gradio app (image in, predicted label as text out) and serve it.
demo = gr.Interface(
    fn=predict,
    inputs="image",
    outputs="text",
)
demo.launch()