Spaces:
Sleeping
Sleeping
Commit
·
4155f8b
1
Parent(s):
2f9a815
Update app.py
Browse files
app.py
CHANGED
@@ -24,8 +24,12 @@ from torchvision import transforms
|
|
24 |
_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN = [0.485, 0.456, 0.406]
|
25 |
_DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD = [0.229, 0.224, 0.225]
|
26 |
|
27 |
-
def evaluate(img,
|
|
|
|
|
|
|
28 |
model.eval()
|
|
|
29 |
torch.set_grad_enabled(False)
|
30 |
|
31 |
overall_id_to_name = json.load(open('overall_id_to_name.json'))
|
@@ -146,10 +150,6 @@ if __name__=='__main__':
|
|
146 |
|
147 |
target_list = generate_target_list(datacsv, entity2id)
|
148 |
|
149 |
-
model = DistMult(args, num_ent_id, target_list, args.device)
|
150 |
-
|
151 |
-
model.to(args.device)
|
152 |
-
|
153 |
# restore from ckpt
|
154 |
if args.ckpt_path:
|
155 |
ckpt = torch.load(args.ckpt_path, map_location=args.device)
|
@@ -158,7 +158,7 @@ if __name__=='__main__':
|
|
158 |
|
159 |
species_model = gr.Interface(
|
160 |
evaluate,
|
161 |
-
[gr.inputs.Image(shape=(200, 200)),
|
162 |
outputs="label",
|
163 |
title = 'Species Classification',
|
164 |
)
|
|
|
24 |
_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN = [0.485, 0.456, 0.406]
|
25 |
_DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD = [0.229, 0.224, 0.225]
|
26 |
|
27 |
+
def evaluate(img, id2entity, target_list, args):
|
28 |
+
num_ent_id = len(id2entity)
|
29 |
+
model = DistMult(args, num_ent_id, target_list, args.device)
|
30 |
+
model.to(args.device)
|
31 |
model.eval()
|
32 |
+
|
33 |
torch.set_grad_enabled(False)
|
34 |
|
35 |
overall_id_to_name = json.load(open('overall_id_to_name.json'))
|
|
|
150 |
|
151 |
target_list = generate_target_list(datacsv, entity2id)
|
152 |
|
|
|
|
|
|
|
|
|
153 |
# restore from ckpt
|
154 |
if args.ckpt_path:
|
155 |
ckpt = torch.load(args.ckpt_path, map_location=args.device)
|
|
|
158 |
|
159 |
species_model = gr.Interface(
|
160 |
evaluate,
|
161 |
+
[gr.inputs.Image(shape=(200, 200)), id2entity, target_list, args],
|
162 |
outputs="label",
|
163 |
title = 'Species Classification',
|
164 |
)
|