Junlinh committed on
Commit 86c62f2
1 Parent(s): 03afbc2

Create app.py

Files changed (1)
  1. app.py +46 -0
app.py ADDED
@@ -0,0 +1,46 @@
+ import gradio as gr
+ import numpy as np
+ import torch
+ import torchvision.models as models
+ import torchvision.transforms as transforms
+ from PIL import Image
+
+
+ def predict(input_img):
+     # Convert the numpy array supplied by Gradio into a PIL image
+     input_img = Image.fromarray(np.uint8(input_img))
+
+     # Two ResNet-50 regressors with a single output each:
+     # one for the MachineMem score, one for the HumanMem score
+     model1 = models.__dict__['resnet50'](num_classes=1)
+     model2 = models.__dict__['resnet50'](num_classes=1)
+
+     # Load the checkpoints; fall back to CPU when no GPU is available
+     loc = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+     checkpoint1 = torch.load("./machine_full_best.tar", map_location=loc)
+     model1.load_state_dict(checkpoint1['state_dict'])
+     checkpoint2 = torch.load("./human_full_best.tar", map_location=loc)
+     model2.load_state_dict(checkpoint2['state_dict'])
+
+     # Standard ImageNet preprocessing: crop to 224x224, then normalize
+     my_transform = transforms.Compose([
+         transforms.RandomResizedCrop(224, (1, 1)),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                              std=[0.229, 0.224, 0.225]),
+     ])
+
+     input_img = my_transform(input_img).view(1, 3, 224, 224)
+
+     # Run both models in evaluation mode without tracking gradients
+     model1.eval()
+     model2.eval()
+     with torch.no_grad():
+         result1 = round(model1(input_img).item(), 3)
+         result2 = round(model2(input_img).item(), 3)
+
+     return 'MachineMem score = ' + str(result1) + ', HumanMem score = ' + str(result2) + '.'
+
+
+ demo = gr.Interface(predict, gr.Image(), "text")
+ demo.launch()