vardaan123 commited on
Commit
93d2b0f
·
1 Parent(s): 306f681

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -147
app.py CHANGED
@@ -1,170 +1,56 @@
1
  import os
2
- import time
3
- import argparse
4
- import numpy as np
5
- import random
6
  import pandas as pd
7
  import torch
8
- import torch.nn as nn
9
- import torch.nn.functional as F
10
- import torchvision
11
- import sys
12
- import json
13
- from collections import defaultdict
14
- import math
15
  import gradio as gr
16
-
17
  from model import DistMult
18
-
19
- from tqdm import tqdm
20
- from utils import collate_list, detach_and_clone, move_to
21
  from PIL import Image
22
  from torchvision import transforms
 
 
23
 
 
24
  _DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN = [0.485, 0.456, 0.406]
25
  _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD = [0.229, 0.224, 0.225]
26
 
27
- def evaluate(img):
28
- num_ent_id = len(id2entity)
29
- model = DistMult(args, num_ent_id, target_list, args.device)
30
- model.to(args.device)
31
- model.eval()
32
-
33
- torch.set_grad_enabled(False)
34
 
35
- overall_id_to_name = json.load(open('overall_id_to_name.json'))
 
 
36
 
37
- img = Image.open(args.img_path).convert('RGB')
38
-
39
- transform_steps = transforms.Compose([transforms.Resize((448, 448)), transforms.ToTensor(), transforms.Normalize(_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN, _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD)])
 
 
 
 
40
  h = transform_steps(img)
41
  r = torch.tensor([3])
42
 
43
- h = move_to(h, args.device).unsqueeze(0)
44
- r = move_to(r, args.device).unsqueeze(0)
 
45
 
46
  outputs = model.forward_ce(h, r, triple_type=('image', 'id'))
47
-
48
- y_pred = detach_and_clone(outputs.cpu())
49
- y_pred = y_pred.argmax(-1)
50
-
51
  pred_label = target_list[y_pred].item()
52
  species_label = overall_id_to_name[str(id2entity[pred_label])]
53
- print('species label = {}'.format(species_label))
54
-
55
- # predict multi-level classification
56
-
57
- # def get_classification(img):
58
-
59
- # image_tensor = transform_image(img)
60
- # ort_inputs = {input_name: to_numpy(image_tensor)}
61
- # x = ort_session.run(None, ort_inputs)
62
- # predictions = torch.topk(torch.from_numpy(x[0]), k=5).indices.squeeze(0).tolist()
63
-
64
- # result = {}
65
- # for i in predictions:
66
- # label = label_map[str(i)]
67
- # prob = x[0][0, i].item()
68
- # result[label] = prob
69
- # return result
70
 
71
-
72
-
73
- # iface.launch()
74
-
75
-
76
  return species_label
77
 
78
- def _get_id(dict, key):
79
- id = dict.get(key, None)
80
- if id is None:
81
- id = len(dict)
82
- dict[key] = id
83
- return id
84
-
85
- def generate_target_list(data, entity2id):
86
- sub = data.loc[(data["datatype_h"] == "image") & (data["datatype_t"] == "id"), ['t']]
87
- sub = list(sub['t'])
88
- categories = []
89
- for item in tqdm(sub):
90
- if entity2id[str(int(float(item)))] not in categories:
91
- categories.append(entity2id[str(int(float(item)))])
92
- # print('categories = {}'.format(categories))
93
- # print("No. of target categories = {}".format(len(categories)))
94
- return torch.tensor(categories, dtype=torch.long).unsqueeze(-1)
95
-
96
-
97
-
98
- if __name__=='__main__':
99
- parser = argparse.ArgumentParser()
100
- # parser.add_argument('--data-dir', type=str, default='data/iwildcam_v2.0/')
101
- # parser.add_argument('--img-path', type=str, required=True, help='path to species image to be classified')
102
- parser.add_argument('--seed', type=int, default=813765)
103
- parser.add_argument('--ckpt-path', type=str, default=None, help='path to ckpt for restarting expt')
104
- parser.add_argument('--debug', action='store_true')
105
- parser.add_argument('--no-cuda', action='store_true')
106
- parser.add_argument('--batch_size', type=int, default=16)
107
-
108
- parser.add_argument('--embedding-dim', type=int, default=512)
109
- parser.add_argument('--location_input_dim', type=int, default=2)
110
- parser.add_argument('--time_input_dim', type=int, default=1)
111
- parser.add_argument('--mlp_location_numlayer', type=int, default=3)
112
- parser.add_argument('--mlp_time_numlayer', type=int, default=3)
113
-
114
- parser.add_argument('--img-embed-model', choices=['resnet18', 'resnet50'], default='resnet50')
115
- parser.add_argument('--use-data-subset', action='store_true')
116
- parser.add_argument('--subset-size', type=int, default=10)
117
-
118
- args = parser.parse_args()
119
-
120
- print('args = {}'.format(args))
121
- args.device = torch.device('cuda') if not args.no_cuda and torch.cuda.is_available() else torch.device('cpu')
122
-
123
- # Set random seed
124
- torch.manual_seed(args.seed)
125
- np.random.seed(args.seed)
126
- random.seed(args.seed)
127
-
128
- datacsv = pd.read_csv('dataset_subtree.csv', low_memory=False)
129
-
130
- entity_id_file = 'entity2id_subtree.json'
131
-
132
- if not os.path.exists(entity_id_file):
133
- entity2id = {} # each of triple types have their own entity2id
134
-
135
- for i in tqdm(range(datacsv.shape[0])):
136
- if datacsv.iloc[i,1] == "id":
137
- _get_id(entity2id, str(int(float(datacsv.iloc[i,0]))))
138
-
139
- if datacsv.iloc[i,-2] == "id":
140
- _get_id(entity2id, str(int(float(datacsv.iloc[i,-3]))))
141
- json.dump(entity2id, open(entity_id_file, 'w'))
142
- else:
143
- entity2id = json.load(open(entity_id_file, 'r'))
144
-
145
- id2entity = {v:k for k,v in entity2id.items()}
146
-
147
- num_ent_id = len(entity2id)
148
-
149
- # print('len(entity2id) = {}'.format(len(entity2id)))
150
-
151
- target_list = generate_target_list(datacsv, entity2id)
152
-
153
- # restore from ckpt
154
- if args.ckpt_path:
155
- ckpt = torch.load(args.ckpt_path, map_location=args.device)
156
- model.load_state_dict(ckpt['model'], strict=False)
157
- print('ckpt loaded...')
158
-
159
- species_model = gr.Interface(
160
- evaluate,
161
- gr.inputs.Image(shape=(200, 200)),
162
- outputs="label",
163
- title = 'Species Classification',
164
- description = 'Species Classification',
165
- article = 'Species Classification',
166
- name = 'Species Classification',
167
- )
168
- species_model.launch(server_port=7897)
169
-
170
- # evaluate(model, id2entity, target_list, args)
 
1
  import os
 
 
 
 
2
  import pandas as pd
3
  import torch
 
 
 
 
 
 
 
4
  import gradio as gr
 
5
  from model import DistMult
 
 
 
6
  from PIL import Image
7
  from torchvision import transforms
8
+ import json
9
+ from tqdm import tqdm
10
 
11
+ # Default image tensor normalization
12
  _DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN = [0.485, 0.456, 0.406]
13
  _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD = [0.229, 0.224, 0.225]
14
 
15
+ # Load necessary data and initialize the model
16
+ entity2id = json.load(open('entity2id_subtree.json', 'r'))
17
+ id2entity = {v: k for k, v in entity2id.items()}
18
+ datacsv = pd.read_csv('dataset_subtree.csv', low_memory=False)
19
+ num_ent_id = len(entity2id)
20
+ target_list = generate_target_list(datacsv, entity2id) # Assuming this function is defined elsewhere
 
21
 
22
+ # Initialize your model here
23
+ model = DistMult(args, num_ent_id, target_list, torch.device('cpu')) # Update arguments as necessary
24
+ model.eval()
25
 
26
+ # Define your evaluation function
27
+ def evaluate(img):
28
+ transform_steps = transforms.Compose([
29
+ transforms.Resize((448, 448)),
30
+ transforms.ToTensor(),
31
+ transforms.Normalize(_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN, _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD)
32
+ ])
33
  h = transform_steps(img)
34
  r = torch.tensor([3])
35
 
36
+ # Assuming `move_to` is a function to move tensors to the desired device
37
+ h = h.unsqueeze(0)
38
+ r = r.unsqueeze(0)
39
 
40
  outputs = model.forward_ce(h, r, triple_type=('image', 'id'))
41
+ y_pred = outputs.argmax(-1).cpu()
 
 
 
42
  pred_label = target_list[y_pred].item()
43
  species_label = overall_id_to_name[str(id2entity[pred_label])]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
 
 
 
 
 
45
  return species_label
46
 
47
+ # Gradio interface
48
+ species_model = gr.Interface(
49
+ evaluate,
50
+ gr.inputs.Image(shape=(200, 200)),
51
+ outputs="label",
52
+ title='Species Classification',
53
+ description='Species Classification',
54
+ article='Species Classification'
55
+ )
56
+ species_model.launch()