File size: 2,015 Bytes
fb1f781
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
import torch.nn as nn
from torchvision import transforms, models
from huggingface_hub import hf_hub_download
from PIL import Image

# Prefer the GPU when CUDA is available; otherwise run inference on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Logit counts for the two prediction heads; together they size the final
# Linear layer, and the school logits come first in the model output.
num_classes_school = 26  # one logit per entry in school_labels
num_classes_type = 10  # one logit per entry in type_labels

# Download the fine-tuned checkpoint from the Hugging Face Hub.
model_path = hf_hub_download(
    repo_id="Irina1402/mobilnetv3-painting-classification",
    filename="model.pth",
)

# Build the bare MobileNetV3-Large architecture. No pretrained weights are
# requested: the strict load_state_dict below overwrites every parameter and
# buffer, so fetching the ImageNet weights first would only waste I/O.
model = models.mobilenet_v3_large(weights=None)
num_features = model.classifier[0].in_features

# Replace the stock classifier with a single head that emits
# num_classes_school + num_classes_type logits in one vector; classify_image
# splits that vector at index num_classes_school.
model.classifier = nn.Sequential(
    nn.Linear(num_features, 512),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(512, num_classes_school + num_classes_type),
)

# The checkpoint comes from the network, and torch.load is pickle-based.
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state_dict needs, so arbitrary pickled code cannot run.
model.load_state_dict(
    torch.load(model_path, map_location=device, weights_only=True)
)
model = model.to(device)
# Inference mode for the module: disables Dropout and freezes BatchNorm stats.
model.eval()

# Per-channel RGB mean/std: the standard ImageNet statistics used with
# torchvision's pretrained weights.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every input image:
#   1. resize to exactly 224x224 (aspect ratio is not preserved),
#   2. convert the PIL image to a float tensor,
#   3. normalize each of the three channels.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

# Human-readable names for each head's logits. Order matters: classify_image
# indexes these lists with the argmax position within each head. Every label
# is a single whitespace-free token, so split() rebuilds the ordered list.
school_labels = """
    American Austrian Belgian Bohemian Catalan Danish Dutch English Finnish
    Flemish French German Greek Hungarian Irish Italian Netherlandish Norwegian
    Other Polish Portuguese Russian Scottish Spanish Swedish Swiss
""".split()

type_labels = """
    genre historical interior landscape mythological other
    portrait religious still-life study
""".split()

def classify_image(image: Image.Image) -> dict:
    """Predict the school and type of a painting from a single PIL image.

    Args:
        image: The painting to classify. Any PIL mode is accepted; the image
            is converted to RGB first because the preprocessing normalizes
            exactly three channels.

    Returns:
        A dict with keys ``"school"`` and ``"type"``, each mapped to the
        highest-scoring label of the corresponding head.
    """
    # RGBA, greyscale ("L") and palette ("P") images would give ToTensor 4 or
    # 1 channels and break the 3-channel Normalize step, so convert to RGB.
    rgb_image = image.convert("RGB")
    # unsqueeze(0) adds a batch dimension of size 1.
    input_tensor = transform(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_tensor)

    # Both heads share one output vector: the first num_classes_school logits
    # are the school head, the remaining ones are the type head.
    school_output = output[:, :num_classes_school]
    type_output = output[:, num_classes_school:]

    # Per-sample argmax over the class dimension. The batch holds exactly one
    # image, so each result has a single element and .item() is valid.
    school_prediction = school_output.argmax(dim=1).item()
    type_prediction = type_output.argmax(dim=1).item()

    return {
        "school": school_labels[school_prediction],
        "type": type_labels[type_prediction],
    }