Commit 93d2b0f · 1 Parent(s): 306f681 · Update app.py
app.py CHANGED
@@ -1,170 +1,56 @@
 import os
-import time
-import argparse
-import numpy as np
-import random
 import pandas as pd
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torchvision
-import sys
-import json
-from collections import defaultdict
-import math
 import gradio as gr
-
 from model import DistMult
-
-from tqdm import tqdm
-from utils import collate_list, detach_and_clone, move_to
 from PIL import Image
 from torchvision import transforms
+import json
+from tqdm import tqdm
 
+# Default image tensor normalization
 _DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN = [0.485, 0.456, 0.406]
 _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD = [0.229, 0.224, 0.225]
 
-
-
-
-
-
-
-torch.set_grad_enabled(False)
+# Load necessary data and initialize the model
+entity2id = json.load(open('entity2id_subtree.json', 'r'))
+id2entity = {v: k for k, v in entity2id.items()}
+datacsv = pd.read_csv('dataset_subtree.csv', low_memory=False)
+num_ent_id = len(entity2id)
+target_list = generate_target_list(datacsv, entity2id)  # Assuming this function is defined elsewhere
 
-
+# Initialize your model here
+model = DistMult(args, num_ent_id, target_list, torch.device('cpu'))  # Update arguments as necessary
+model.eval()
 
-
-
-    transform_steps = transforms.Compose([
+# Define your evaluation function
+def evaluate(img):
+    transform_steps = transforms.Compose([
+        transforms.Resize((448, 448)),
+        transforms.ToTensor(),
+        transforms.Normalize(_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN, _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD)
+    ])
     h = transform_steps(img)
     r = torch.tensor([3])
 
-
-
+    # Assuming `move_to` is a function to move tensors to the desired device
+    h = h.unsqueeze(0)
+    r = r.unsqueeze(0)
 
     outputs = model.forward_ce(h, r, triple_type=('image', 'id'))
-
-    y_pred = detach_and_clone(outputs.cpu())
-    y_pred = y_pred.argmax(-1)
-
+    y_pred = outputs.argmax(-1).cpu()
     pred_label = target_list[y_pred].item()
     species_label = overall_id_to_name[str(id2entity[pred_label])]
-    print('species label = {}'.format(species_label))
-
-    # predict multi-level classification
-
-    # def get_classification(img):
-
-    #     image_tensor = transform_image(img)
-    #     ort_inputs = {input_name: to_numpy(image_tensor)}
-    #     x = ort_session.run(None, ort_inputs)
-    #     predictions = torch.topk(torch.from_numpy(x[0]), k=5).indices.squeeze(0).tolist()
-
-    #     result = {}
-    #     for i in predictions:
-    #         label = label_map[str(i)]
-    #         prob = x[0][0, i].item()
-    #         result[label] = prob
-    #     return result
 
-
-
-    # iface.launch()
-
-
     return species_label
 
-
-
-
-
-
-
-
-
-
-
-    categories = []
-    for item in tqdm(sub):
-        if entity2id[str(int(float(item)))] not in categories:
-            categories.append(entity2id[str(int(float(item)))])
-    # print('categories = {}'.format(categories))
-    # print("No. of target categories = {}".format(len(categories)))
-    return torch.tensor(categories, dtype=torch.long).unsqueeze(-1)
-
-
-
-if __name__=='__main__':
-    parser = argparse.ArgumentParser()
-    # parser.add_argument('--data-dir', type=str, default='data/iwildcam_v2.0/')
-    # parser.add_argument('--img-path', type=str, required=True, help='path to species image to be classified')
-    parser.add_argument('--seed', type=int, default=813765)
-    parser.add_argument('--ckpt-path', type=str, default=None, help='path to ckpt for restarting expt')
-    parser.add_argument('--debug', action='store_true')
-    parser.add_argument('--no-cuda', action='store_true')
-    parser.add_argument('--batch_size', type=int, default=16)
-
-    parser.add_argument('--embedding-dim', type=int, default=512)
-    parser.add_argument('--location_input_dim', type=int, default=2)
-    parser.add_argument('--time_input_dim', type=int, default=1)
-    parser.add_argument('--mlp_location_numlayer', type=int, default=3)
-    parser.add_argument('--mlp_time_numlayer', type=int, default=3)
-
-    parser.add_argument('--img-embed-model', choices=['resnet18', 'resnet50'], default='resnet50')
-    parser.add_argument('--use-data-subset', action='store_true')
-    parser.add_argument('--subset-size', type=int, default=10)
-
-    args = parser.parse_args()
-
-    print('args = {}'.format(args))
-    args.device = torch.device('cuda') if not args.no_cuda and torch.cuda.is_available() else torch.device('cpu')
-
-    # Set random seed
-    torch.manual_seed(args.seed)
-    np.random.seed(args.seed)
-    random.seed(args.seed)
-
-    datacsv = pd.read_csv('dataset_subtree.csv', low_memory=False)
-
-    entity_id_file = 'entity2id_subtree.json'
-
-    if not os.path.exists(entity_id_file):
-        entity2id = {} # each of triple types have their own entity2id
-
-        for i in tqdm(range(datacsv.shape[0])):
-            if datacsv.iloc[i,1] == "id":
-                _get_id(entity2id, str(int(float(datacsv.iloc[i,0]))))
-
-            if datacsv.iloc[i,-2] == "id":
-                _get_id(entity2id, str(int(float(datacsv.iloc[i,-3]))))
-        json.dump(entity2id, open(entity_id_file, 'w'))
-    else:
-        entity2id = json.load(open(entity_id_file, 'r'))
-
-    id2entity = {v:k for k,v in entity2id.items()}
-
-    num_ent_id = len(entity2id)
-
-    # print('len(entity2id) = {}'.format(len(entity2id)))
-
-    target_list = generate_target_list(datacsv, entity2id)
-
-    # restore from ckpt
-    if args.ckpt_path:
-        ckpt = torch.load(args.ckpt_path, map_location=args.device)
-        model.load_state_dict(ckpt['model'], strict=False)
-        print('ckpt loaded...')
-
-    species_model = gr.Interface(
-        evaluate,
-        gr.inputs.Image(shape=(200, 200)),
-        outputs="label",
-        title = 'Species Classification',
-        description = 'Species Classification',
-        article = 'Species Classification',
-        name = 'Species Classification',
-    )
-    species_model.launch(server_port=7897)
-
-    # evaluate(model, id2entity, target_list, args)
+
+# Gradio interface
+species_model = gr.Interface(
+    evaluate,
+    gr.inputs.Image(shape=(200, 200)),
+    outputs="label",
+    title='Species Classification',
+    description='Species Classification',
+    article='Species Classification'
+)
+species_model.launch()
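
Note: the new module calls generate_target_list(datacsv, entity2id) but no longer defines it, so the Space will fail with a NameError at import time. The function's body survives in the removed hunk; below is a sketch reassembled from those lines. The def line itself is not visible in the diff, so the signature and the source of sub are assumptions based on the call site.

import torch
from tqdm import tqdm

def generate_target_list(datacsv, entity2id):
    # Body recovered from the hunk this commit removed. `sub` appeared there
    # as a free variable, so deriving it from the dataframe is an assumption;
    # the original presumably selected the column of candidate entity ids.
    sub = datacsv.iloc[:, 0]  # hypothetical: first column holds entity ids
    categories = []
    for item in tqdm(sub):
        if entity2id[str(int(float(item)))] not in categories:
            categories.append(entity2id[str(int(float(item)))])
    return torch.tensor(categories, dtype=torch.long).unsqueeze(-1)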
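
Note: model = DistMult(args, ...) still references args, but this commit deletes the argparse block that created it, which is a second import-time NameError. A minimal stand-in sketch follows, with defaults copied from the removed parser.add_argument calls; whether DistMult reads every one of these fields is an assumption.

from argparse import Namespace

import torch

# Hypothetical replacement for the deleted CLI arguments; defaults taken from
# the removed argparse block.
args = Namespace(
    embedding_dim=512,
    location_input_dim=2,
    time_input_dim=1,
    mlp_location_numlayer=3,
    mlp_time_numlayer=3,
    img_embed_model='resnet50',
    batch_size=16,
    device=torch.device('cpu'),
)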
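
Note: the removed __main__ block restored weights from --ckpt-path, and the removed torch.set_grad_enabled(False) kept inference out of autograd; the new module calls model.eval() on a freshly initialized DistMult and never loads a checkpoint, so the Space would serve predictions from random weights. A sketch of the missing restore step, mirroring the removed code; the checkpoint filename is a placeholder.

import torch

torch.set_grad_enabled(False)  # as in the removed version: no autograd at serve time

CKPT_PATH = 'best_model.pt'  # placeholder; the real path was passed via --ckpt-path
ckpt = torch.load(CKPT_PATH, map_location=torch.device('cpu'))
model.load_state_dict(ckpt['model'], strict=False)  # `model` as defined above
model.eval()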
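
Note: overall_id_to_name is used inside evaluate but defined in neither the old nor the new version of the file, so every prediction would also raise a NameError. It evidently maps entity ids to human-readable species names; one plausible way to supply it, with a purely hypothetical filename:

import json

# Hypothetical: a JSON file mapping entity id (as a string) to species name.
overall_id_to_name = json.load(open('overall_id_to_name.json', 'r'))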
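
Note: gr.inputs.Image comes from the legacy pre-3.0 Gradio namespace, and by default it hands evaluate a NumPy array, while the torchvision transforms inside evaluate expect a PIL image. On current Gradio the interface would look roughly like the sketch below (not the author's code): type="pil" matches what the transforms expect, and the shape argument is dropped since evaluate already resizes to 448x448.

import gradio as gr

species_model = gr.Interface(
    fn=evaluate,
    inputs=gr.Image(type="pil"),  # deliver a PIL.Image to evaluate()
    outputs="label",
    title='Species Classification',
    description='Species Classification',
    article='Species Classification',
)
species_model.launch()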