Junlinh commited on
Commit
a457df2
·
1 Parent(s): bbbc639

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -3
app.py CHANGED
@@ -1,12 +1,11 @@
1
  import gradio as gr
2
- from timm.models import create_model
3
  import torchvision.transforms as transforms
4
  from PIL import Image
5
  import torch
6
  def predict(input_img):
7
  input_img = Image.fromarray(np.uint8(input_img))
8
- model1 = models.__dict__['resnet50'](drop_rate=0.5,num_classes=1,)
9
- model2 = models.__dict__['resnet50'](drop_rate=0.5,num_classes=1,)
10
 
11
  loc = 'cuda:{}'.format(0)
12
  checkpoint1 = torch.load("./machine_full_best.tar", map_location=loc)
 
1
  import gradio as gr
 
2
  import torchvision.transforms as transforms
3
  from PIL import Image
4
  import torch
5
  def predict(input_img):
6
  input_img = Image.fromarray(np.uint8(input_img))
7
+ model1 = models.__dict__['resnet50'](num_classes=1)
8
+ model2 = models.__dict__['resnet50'](num_classes=1)
9
 
10
  loc = 'cuda:{}'.format(0)
11
  checkpoint1 = torch.load("./machine_full_best.tar", map_location=loc)