Junlinh commited on
Commit
c4b9028
·
1 Parent(s): 12b9840

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -2
app.py CHANGED
@@ -2,10 +2,17 @@ import gradio as gr
2
  import torchvision.transforms as transforms
3
  from PIL import Image
4
  import torch
 
5
  def predict(input_img):
6
  input_img = Image.fromarray(np.uint8(input_img))
7
- model1 = models.__dict__['resnet50'](num_classes=1)
8
- model2 = models.__dict__['resnet50'](num_classes=1)
 
 
 
 
 
 
9
 
10
  loc = 'cuda:{}'.format(0)
11
  checkpoint1 = torch.load("./machine_full_best.tar", map_location=loc)
 
2
  import torchvision.transforms as transforms
3
  from PIL import Image
4
  import torch
5
+ from timm.models import create_model
6
  def predict(input_img):
7
  input_img = Image.fromarray(np.uint8(input_img))
8
+ model1 = create_model(
9
+ 'resnet50',
10
+ drop_rate=0.5,
11
+ num_classes=1,)
12
+ model2 = create_model(
13
+ 'resnet50',
14
+ drop_rate=0.5,
15
+ num_classes=1,)
16
 
17
  loc = 'cuda:{}'.format(0)
18
  checkpoint1 = torch.load("./machine_full_best.tar", map_location=loc)