Huy0502 commited on
Commit
82846d1
·
verified ·
1 Parent(s): a88cf03

Delete models.py

Browse files
Files changed (1) hide show
  1. models.py +0 -164
models.py DELETED
@@ -1,164 +0,0 @@
1
- import os
2
- import torch
3
- import numpy as np
4
- from ultralytics import YOLO
5
- from transformers import AutoProcessor
6
- from transformers import AutoModelForTokenClassification
7
- from utils import normalize_box, unnormalize_box, draw_output, create_df
8
- from PIL import Image, ImageDraw
9
- from vietocr.tool.predictor import Predictor
10
- from vietocr.tool.config import Cfg
11
-
12
- class Reciept_Analyzer:
13
- def __init__(self,
14
- processor_pretrained='microsoft/layoutlmv3-base',
15
- layoutlm_pretrained=os.path.join(
16
- 'models', 'checkpoint'),
17
- yolo_pretrained=os.path.join(
18
- 'models', 'best.pt'),
19
- vietocr_pretrained=os.path.join(
20
- 'models', 'vietocr', 'vgg_seq2seq.pth')
21
- ):
22
-
23
- print("Initializing processor")
24
- self.processor = AutoProcessor.from_pretrained(
25
- processor_pretrained, apply_ocr=False)
26
- print("Finished initializing processor")
27
-
28
- print("Initializing LayoutLM model")
29
- self.lalm_model = AutoModelForTokenClassification.from_pretrained(
30
- layoutlm_pretrained)
31
- print("Finished initializing LayoutLM model")
32
-
33
- if yolo_pretrained is not None:
34
- print("Initializing YOLO model")
35
- self.yolo_model = YOLO(yolo_pretrained)
36
- print("Finished initializing YOLO model")
37
-
38
- print("Initializing VietOCR model")
39
- config = Cfg.load_config_from_name('vgg_seq2seq')
40
- config['weights'] = vietocr_pretrained
41
- config['cnn']['pretrained']= False
42
- config['device'] = 'cuda:0' if torch.cuda.is_available() else 'cpu'
43
- self.vietocr = Predictor(config)
44
- print("Finished initializing VietOCR model")
45
-
46
- def forward(self, img, output_path="output", is_save_cropped_img=False):
47
- input_image = Image.open(img)
48
-
49
- # detection with YOLOv8
50
- bboxes = self.yolov8_det(input_image)
51
-
52
- # sort
53
- sorted_bboxes = self.sort_bboxes(bboxes)
54
-
55
- # draw bbox
56
- image_draw = input_image.copy()
57
- self.draw_bbox(image_draw, sorted_bboxes, output_path)
58
-
59
- # crop images
60
- cropped_images, normalized_boxes = self.get_cropped_images(input_image, sorted_bboxes, is_save_cropped_img, output_path)
61
-
62
- # recognition with VietOCR
63
- texts, mapping_bbox_texts = self.ocr(cropped_images, normalized_boxes)
64
-
65
- # KIE with LayoutLMv3
66
- pred_texts, pred_label, boxes = self.kie(input_image, texts, normalized_boxes, mapping_bbox_texts, output_path)
67
-
68
- # create dataframe
69
- return create_df(pred_texts, pred_label)
70
-
71
-
72
- def yolov8_det(self, img):
73
- return self.yolo_model.predict(source=img, conf=0.3, iou=0.1)[0].boxes.xyxy.int()
74
-
75
- def sort_bboxes(self, bboxes):
76
- bbox_list = []
77
- for box in bboxes:
78
- tlx, tly, brx, bry = map(int, box)
79
- bbox_list.append([tlx, tly, brx, bry])
80
- bbox_list.sort(key=lambda x: (x[1], x[2]))
81
- return bbox_list
82
-
83
- def draw_bbox(self, image_draw, bboxes, output_path):
84
- # draw bbox
85
- draw = ImageDraw.Draw(image_draw)
86
- for box in bboxes:
87
- draw.rectangle(box, outline='red', width=2)
88
- image_draw.save(os.path.join(output_path, 'bbox.jpg'))
89
- print(f"Exported image with bounding boxes to {os.path.join(output_path, 'bbox.jpg')}")
90
-
91
- def get_cropped_images(self, input_image, bboxes, is_save_cropped=False, output_path="output"):
92
- normalized_boxes = []
93
- cropped_images = []
94
-
95
- # OCR
96
- if is_save_cropped:
97
- cropped_folder = os.path.join(output_path, "cropped")
98
- if not os.path.exists(cropped_folder):
99
- os.makedirs(cropped_folder)
100
- i = 0
101
- for box in bboxes:
102
- tlx, tly, brx, bry = map(int, box)
103
- normalized_box = normalize_box(box, input_image.width, input_image.height)
104
- normalized_boxes.append(normalized_box)
105
- cropped_ = input_image.crop((tlx, tly, brx, bry))
106
- if is_save_cropped:
107
- cropped_.save(os.path.join(cropped_folder, f'cropped_{i}.jpg'))
108
- i += 1
109
- cropped_images.append(cropped_)
110
-
111
- return cropped_images, normalized_boxes
112
-
113
- def ocr(self, cropped_images, normalized_boxes):
114
- mapping_bbox_texts = {}
115
- texts = []
116
- for img, normalized_box in zip(cropped_images, normalized_boxes):
117
- result = self.vietocr.predict(img)
118
- text = result.strip().replace('\n', ' ')
119
- texts.append(text)
120
- mapping_bbox_texts[','.join(map(str, normalized_box))] = text
121
-
122
- return texts, mapping_bbox_texts
123
-
124
- def kie(self, img, texts, boxes, mapping_bbox_texts, output_path):
125
- encoding = self.processor(img, texts,
126
- boxes=boxes,
127
- return_offsets_mapping=True,
128
- return_tensors='pt',
129
- max_length=512,
130
- padding='max_length')
131
- offset_mapping = encoding.pop('offset_mapping')
132
-
133
- with torch.no_grad():
134
- outputs = self.lalm_model(**encoding)
135
-
136
- id2label = self.lalm_model.config.id2label
137
- logits = outputs.logits
138
- token_boxes = encoding.bbox.squeeze().tolist()
139
- offset_mapping = offset_mapping.squeeze().tolist()
140
-
141
- predictions = logits.argmax(-1).squeeze().tolist()
142
- is_subword = np.array(offset_mapping)[:, 0] != 0
143
-
144
- true_predictions = []
145
- true_boxes = []
146
- true_texts = []
147
- for idx in range(len(predictions)):
148
- if not is_subword[idx] and token_boxes[idx] != [0, 0, 0, 0]:
149
- true_predictions.append(id2label[predictions[idx]])
150
- true_boxes.append(unnormalize_box(
151
- token_boxes[idx], img.width, img.height))
152
- true_texts.append(mapping_bbox_texts.get(
153
- ','.join(map(str, token_boxes[idx])), ''))
154
-
155
- if isinstance(output_path, str):
156
- os.makedirs(output_path, exist_ok=True)
157
- img_output = draw_output(
158
- image=img,
159
- true_predictions=true_predictions,
160
- true_boxes=true_boxes
161
- )
162
- img_output.save(os.path.join(output_path, 'result.jpg'))
163
- print(f"Exported result to {os.path.join(output_path, 'result.jpg')}")
164
- return true_texts, true_predictions, true_boxes