import torch
from torch import nn
from torchvision import transforms, models

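class ActionClassifier(nn.Module):
    """Frozen ResNet-50 feature extractor followed by a small trainable MLP head."""

    def __init__(self, ntargets):
        super().__init__()
        resnet = models.resnet50(pretrained=True, progress=True)
        # Drop the final fully connected layer; keep everything up to global average pooling.
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        # Freeze the backbone so only the classification head receives gradients.
        for param in self.resnet.parameters():
            param.requires_grad = False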
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.BatchNorm1d(resnet.fc.in_features),
            nn.Dropout(0.2),
            nn.Linear(resnet.fc.in_features, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(0.2),
            nn.Linear(256, ntargets)
        )
    
    def forward(self, x):
        x = self.resnet(x)
        x = self.fc(x)
        return x



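def get_transform():
    """Return the preprocessing pipeline applied to each input image."""
    transform = transforms.Compose([
        # Target size is (height, width) = (224, 244), kept exactly as in the original code.
        transforms.Resize([224, 244]),
        # ToTensor() converts a PIL image to a float tensor scaled to [0, 1].
        transforms.ToTensor(),
        # The std values are multiplied by 255, so the normalized output has a
        # much smaller range than standard ImageNet normalization. This is kept
        # unchanged on the assumption that the saved weights were trained with it.
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229*255, 0.224*255, 0.225*255)),
    ])
    return transform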


def get_model():
    """Build the 15-class classifier and load its saved weights onto the CPU."""
    model = ActionClassifier(15)
    model.load_state_dict(torch.load('./classifier_weights.pth', map_location=torch.device('cpu')))
    return model


def get_class(index):
    """Map a predicted class index (0-14) to its action label."""
    ind2cat = [
        'calling',
        'clapping',
        'cycling',
        'dancing',
        'drinking',
        'eating',
        'fighting',
        'hugging',
        'laughing',
        'listening_to_music',
        'running',
        'sitting',
        'sleeping',
        'texting',
        'using_laptop'
    ]
    return ind2cat[index]




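if __name__ == '__main__':
    # Example: classify a single image when this file is run as a script.
    from PIL import Image

    transform = get_transform()
    model = get_model()

    img = Image.open('./inputs/Image_102.jpg').convert('RGB')
    img = transform(img).unsqueeze(dim=0)  # add batch dimension: (1, 3, H, W)
    print(img.shape)

    model.eval()  # use running BatchNorm statistics and disable Dropout
    with torch.no_grad():
        out = model(img)
        out = nn.Softmax(dim=1)(out).squeeze()
        print(out.shape)
        res = torch.argmax(out)

        print(get_class(res.item()))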