import gradio as gr
from datasets import load_dataset
from transformers import AutoImageProcessor, AutoModelForImageClassification, Trainer, TrainingArguments
import torch
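# Assumed dependencies (unpinned): pip install torch transformers datasets gradio pillow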

# Load the dataset from the local folders, then hold out 20% for evaluation
dataset = load_dataset("imagefolder", data_dir=".", split="train")
dataset = dataset.train_test_split(test_size=0.2)
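# "imagefolder" infers the label from each subdirectory name, so the expected
# layout here is ./rock/, ./paper/, ./scissors/ with the images inside.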

# Pick a small pretrained ViT as the base model. Google publishes no "tiny" ViT
# checkpoint on the Hub; this community port of ViT-tiny is the usual substitute.
checkpoint = "WinKawaks/vit-tiny-patch16-224"
processor = AutoImageProcessor.from_pretrained(checkpoint)
model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=3,
    id2label={0: "rock", 1: "paper", 2: "scissors"},
    label2id={"rock": 0, "paper": 1, "scissors": 2},
    ignore_mismatched_sizes=True,  # swap the 1000-class ImageNet head for a fresh 3-class head
)

# Turn each PIL image into the pixel_values tensor the model expects
def preprocess(examples):
    images = [x.convert("RGB") for x in examples["image"]]
    inputs = processor(images=images, return_tensors="pt")
    inputs["labels"] = examples["label"]
    return inputs

# Drop the raw image/label columns so the default data collator only sees tensors
dataset = dataset.map(preprocess, batched=True, remove_columns=["image", "label"])

# Training configuration: evaluate and checkpoint every epoch so the best model
# can be restored at the end
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",  # renamed to eval_strategy in newer transformers releases
    save_strategy="epoch",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=5,
    load_best_model_at_end=True,
    logging_dir='./logs',
    logging_steps=5,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
)
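# Optional: report accuracy during evaluation by passing compute_metrics to the
# Trainer above (a minimal sketch; the helper name is illustrative):
# import numpy as np
# def compute_metrics(eval_pred):
#     preds = np.argmax(eval_pred.predictions, axis=-1)
#     return {"accuracy": float((preds == eval_pred.label_ids).mean())}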

# Train the model
trainer.train()
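# Optionally persist the fine-tuned model and processor for reuse (path is illustrative):
# trainer.save_model("./rps-vit")
# processor.save_pretrained("./rps-vit")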

# Run prediction on a new image
def predict(image):
    inputs = processor(images=image, return_tensors="pt").to(model.device)
    with torch.no_grad():  # inference only, no gradients needed
        outputs = model(**inputs)
    logits = outputs.logits
    predicted_class_idx = logits.argmax(-1).item()
    label = model.config.id2label[predicted_class_idx]
    return label

# Build the Gradio app
demo = gr.Interface(fn=predict, inputs="image", outputs="text")
demo.launch()
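# launch() serves the UI locally (http://127.0.0.1:7860 by default);
# pass share=True for a temporary public link.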