vardaan123 commited on
Commit
4155f8b
·
1 Parent(s): 2f9a815

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -24,8 +24,12 @@ from torchvision import transforms
24
  _DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN = [0.485, 0.456, 0.406]
25
  _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD = [0.229, 0.224, 0.225]
26
 
27
- def evaluate(img, model, id2entity, target_list, args):
 
 
 
28
  model.eval()
 
29
  torch.set_grad_enabled(False)
30
 
31
  overall_id_to_name = json.load(open('overall_id_to_name.json'))
@@ -146,10 +150,6 @@ if __name__=='__main__':
146
 
147
  target_list = generate_target_list(datacsv, entity2id)
148
 
149
- model = DistMult(args, num_ent_id, target_list, args.device)
150
-
151
- model.to(args.device)
152
-
153
  # restore from ckpt
154
  if args.ckpt_path:
155
  ckpt = torch.load(args.ckpt_path, map_location=args.device)
@@ -158,7 +158,7 @@ if __name__=='__main__':
158
 
159
  species_model = gr.Interface(
160
  evaluate,
161
- [gr.inputs.Image(shape=(200, 200)), model, id2entity, target_list, args],
162
  outputs="label",
163
  title = 'Species Classification',
164
  )
 
24
  _DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN = [0.485, 0.456, 0.406]
25
  _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD = [0.229, 0.224, 0.225]
26
 
27
+ def evaluate(img, id2entity, target_list, args):
28
+ num_ent_id = len(id2entity)
29
+ model = DistMult(args, num_ent_id, target_list, args.device)
30
+ model.to(args.device)
31
  model.eval()
32
+
33
  torch.set_grad_enabled(False)
34
 
35
  overall_id_to_name = json.load(open('overall_id_to_name.json'))
 
150
 
151
  target_list = generate_target_list(datacsv, entity2id)
152
 
 
 
 
 
153
  # restore from ckpt
154
  if args.ckpt_path:
155
  ckpt = torch.load(args.ckpt_path, map_location=args.device)
 
158
 
159
  species_model = gr.Interface(
160
  evaluate,
161
+ [gr.inputs.Image(shape=(200, 200)), id2entity, target_list, args],
162
  outputs="label",
163
  title = 'Species Classification',
164
  )