engrharis committed on
Commit 7521288 · verified · 1 Parent(s): eea5ea0

Create app.py

Files changed (1)
  1. app.py +235 -0
app.py ADDED
@@ -0,0 +1,235 @@
"""
Paper: "UTRNet: High-Resolution Urdu Text Recognition In Printed Documents" presented at ICDAR 2023
Authors: Abdur Rahman, Arjun Ghosh, Chetan Arora
GitHub Repository: https://github.com/abdur75648/UTRNet-High-Resolution-Urdu-Text-Recognition
Project Website: https://abdur75648.github.io/UTRNet/
Copyright (c) 2023-present: This work is licensed under the Creative Commons Attribution-NonCommercial
4.0 International License (http://creativecommons.org/licenses/by-nc/4.0/)
"""
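
# Example invocation (a sketch; the paths below are placeholders, not files shipped with this
# commit - substitute your own evaluation LMDB directory and trained checkpoint):
#   python app.py --eval_data path/to/eval_lmdb --saved_model path/to/checkpoint.pth --visualize --threshold 50.0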

import os
import time
import argparse
import random
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
import pytz

import torch
import torch.utils.data
import torch.nn.functional as F
from tqdm import tqdm
from nltk.metrics.distance import edit_distance

from utils import CTCLabelConverter, AttnLabelConverter, Averager, Logger
from dataset import hierarchical_dataset, AlignCollate
from model import Model

def validation(model, criterion, evaluation_loader, converter, opt, device):
    """ Validation or evaluation """
    eval_arr = []
    sum_len_gt = 0
    n_correct = 0
    norm_ED = 0
    length_of_data = 0
    infer_time = 0
    valid_loss_avg = Averager()

    for i, (image_tensors, labels) in enumerate(tqdm(evaluation_loader)):
        batch_size = image_tensors.size(0)
        length_of_data = length_of_data + batch_size
        image = image_tensors.to(device)
        # For max length prediction
        length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
        text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

        text_for_loss, length_for_loss = converter.encode(labels, batch_max_length=opt.batch_max_length)

        start_time = time.time()
        if 'CTC' in opt.Prediction:
            preds = model(image)
            forward_time = time.time() - start_time
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            cost = criterion(preds.log_softmax(2).permute(1, 0, 2), text_for_loss, preds_size, length_for_loss)
            _, preds_index = preds.max(2)
            preds_str = converter.decode(preds_index.data, preds_size.data)
        else:
            preds = model(image, text=text_for_pred, is_train=False)
            forward_time = time.time() - start_time

            preds = preds[:, :text_for_loss.shape[1] - 1, :].to(device)
            target = text_for_loss[:, 1:].to(device)  # without [GO] symbol
            cost = criterion(preds.contiguous().view(-1, preds.shape[-1]), target.contiguous().view(-1))
            _, preds_index = preds.max(2)
            preds_str = converter.decode(preds_index, length_for_pred)
            labels = converter.decode(text_for_loss[:, 1:], length_for_loss)

        infer_time += forward_time
        valid_loss_avg.add(cost)

        # Calculate accuracy & confidence score
        preds_prob = F.softmax(preds, dim=2)
        preds_max_prob, _ = preds_prob.max(dim=2)
        confidence_score_list = []
        for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob):
            if 'Attn' in opt.Prediction:
                gt = gt[:gt.find('[s]')]
                pred_EOS = pred.find('[s]')
                pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                pred_max_prob = pred_max_prob[:pred_EOS]

            if pred == gt:
                n_correct += 1

            # ICDAR 2019 Normalized Edit Distance
            if len(gt) == 0 or len(pred) == 0:
                ED = 0
            elif len(gt) > len(pred):
                ED = 1 - edit_distance(pred, gt) / len(gt)
            else:
                ED = 1 - edit_distance(pred, gt) / len(pred)

            eval_arr.append([gt, pred, ED])

            sum_len_gt += len(gt)
            norm_ED += (ED * len(gt))

            # Calculate confidence score (= product of pred_max_prob)
            try:
                confidence_score = pred_max_prob.cumprod(dim=0)[-1]
            except IndexError:
                confidence_score = 0  # for the empty-pred case, after pruning at the "end of sentence" token ([s])
            confidence_score_list.append(confidence_score)
            # print(pred, gt, pred == gt, confidence_score)

    accuracy = n_correct / float(length_of_data) * 100
    norm_ED = norm_ED / float(sum_len_gt)

    return valid_loss_avg.val(), accuracy, norm_ED, eval_arr


def test(opt, device):
    opt.device = device
    os.makedirs("test_outputs", exist_ok=True)
    datetime_now = str(datetime.now(pytz.timezone('Asia/Kolkata')).strftime("%Y-%m-%d_%H-%M-%S"))
    logger = Logger(f'test_outputs/{datetime_now}.txt')
    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    logger.log('model input parameters', opt.imgH, opt.imgW, opt.input_channel, opt.output_channel,
               opt.hidden_size, opt.num_class, opt.batch_max_length, opt.FeatureExtraction,
               opt.SequenceModeling, opt.Prediction)
    model = model.to(device)

    # Load model
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))
    logger.log('Loaded pretrained model from %s' % opt.saved_model)
    # logger.log(model)

    """ setup loss """
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0

    """ evaluation """
    model.eval()
    with torch.no_grad():
        AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW)  # , keep_ratio_with_pad=opt.PAD
        eval_data, eval_data_log = hierarchical_dataset(root=opt.eval_data, opt=opt, rand_aug=False)
        logger.log(eval_data_log)
        evaluation_loader = torch.utils.data.DataLoader(
            eval_data, batch_size=opt.batch_size,
            shuffle=False,
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_evaluation, pin_memory=True)
        _, accuracy, norm_ED, eval_arr = validation(model, criterion, evaluation_loader, converter, opt, device)
        logger.log("=" * 20)
        logger.log(f'Accuracy : {accuracy:0.4f}\n')
        logger.log(f'Norm_ED : {norm_ED:0.4f}\n')
        logger.log("=" * 20)

    if opt.visualize:
        logger.log("Threshold - ", opt.threshold)
        logger.log("ED", "\t", "gt", "\t", "pred")
        arr = []
        for gt, pred, ED in eval_arr:
            ED = ED * 100.0
            arr.append(ED)
            if ED <= opt.threshold:
                logger.log(ED, "\t", gt, "\t", pred)
        plt.hist(arr, edgecolor="red")
        plt.savefig('test_outputs/' + str(datetime_now) + ".png")
        plt.close()

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--visualize', action='store_true', help='for visualization of bad samples')
    parser.add_argument('--threshold', type=float, default=50.0, help='save samples below this threshold in the txt file')
    parser.add_argument('--eval_data', required=True, help='path to evaluation dataset')
    parser.add_argument('--workers', type=int, default=4, help='number of data loading workers')
    parser.add_argument('--batch_size', type=int, default=32, help='input batch size')
    parser.add_argument('--saved_model', required=True, help='path to the saved model to evaluate')
    """ Data processing """
    parser.add_argument('--batch_max_length', type=int, default=100, help='maximum label length')
    parser.add_argument('--imgH', type=int, default=32, help='the height of the input image')
    parser.add_argument('--imgW', type=int, default=400, help='the width of the input image')
    parser.add_argument('--rgb', action='store_true', help='use rgb input')
    """ Model Architecture """
    parser.add_argument('--FeatureExtraction', type=str, default="HRNet",  # required=True,
                        help='FeatureExtraction stage: VGG|RCNN|ResNet|UNet|HRNet|Densenet|InceptionUnet|ResUnet|AttnUNet')
    parser.add_argument('--SequenceModeling', type=str, default="DBiLSTM",  # required=True,
                        help='SequenceModeling stage: LSTM|GRU|MDLSTM|BiLSTM|DBiLSTM')
    parser.add_argument('--Prediction', type=str, default="CTC",  # required=True,
                        help='Prediction stage: CTC|Attn')
    parser.add_argument('--input_channel', type=int, default=1, help='the number of input channels of the feature extractor')
    parser.add_argument('--output_channel', type=int, default=512, help='the number of output channels of the feature extractor')
    parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state')
    """ GPU Selection """
    parser.add_argument('--device_id', type=str, default=None, help='cuda device ID')

    opt = parser.parse_args()
    if opt.FeatureExtraction == "HRNet":
        opt.output_channel = 32

    # Fix random seeds for both numpy and pytorch
    seed = 1111
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    """ vocab / character number configuration """
    with open("UrduGlyphs.txt", "r", encoding="utf-8") as file:
        content = file.readlines()
    content = ''.join([str(elem).strip('\n') for elem in content])
    opt.character = content + " "

    cuda_str = 'cuda'
    if opt.device_id is not None:
        cuda_str = f'cuda:{opt.device_id}'
    device = torch.device(cuda_str if torch.cuda.is_available() else 'cpu')
    print("Device : ", device)

    # opt.eval_data = "/DATA/parseq/val/"
    # test(opt, device)

    # opt.eval_data = "/DATA/parseq/IIITH/lmdb_new/"
    # test(opt, device)

    # opt.eval_data = "/DATA/public_datasets/UPTI/valid/"
    # test(opt, device)

    test(opt, device)