Huy0502 committed on
Commit 8aca528 · verified · Parent(s): 2ce2db1

Upload 3 files

Files changed (3)
  1. app.py +67 -0
  2. models.py +169 -0
  3. utils.py +179 -0
app.py ADDED
@@ -0,0 +1,67 @@
+ import gradio as gr
+ from models import Reciept_Analyzer
+ from utils import find_product, get_info
+ import os
+
+ model = Reciept_Analyzer()
+
+ sample_images = []
+ for img_file in os.listdir("samples/"):
+     sample_images.append(os.path.join("samples", img_file))
+
+ # Holds the dataframe for the most recently analyzed receipt
+ res = None
+
+
+ def predict(image):
+     results = model.forward(image)
+     return results
+
+
+ # Build the Gradio interface
+ def create_interface():
+     with gr.Blocks() as app:
+         gr.Markdown("# Ứng dụng phân tích hóa đơn siêu thị")
+
+         with gr.Row():
+             # Left column: upload area and sample images
+             with gr.Column():
+                 gr.Markdown("### Tải lên hóa đơn hoặc chọn ảnh mẫu")
+                 image_input = gr.Image(label="Ảnh hóa đơn", type="filepath")
+
+                 def on_image_selected(image_path):
+                     global res
+                     res = predict(image_path)
+                     return get_info(res)
+
+                 def handle_input(item_name):
+                     if res is None:
+                         return "Chưa có hóa đơn nào được phân tích."
+                     return find_product(item_name, res)
+
+                 gr.Markdown("### Ảnh mẫu")
+                 gr.Examples(
+                     inputs=image_input,
+                     examples=sample_images
+                 )
+
+             # Right column: analysis result and item search
+             with gr.Column():
+                 result_output = gr.Textbox(label="Kết quả phân tích")
+                 image_input.change(fn=on_image_selected, inputs=image_input, outputs=result_output)
+                 gr.Markdown("### Tìm kiếm thông tin item")
+                 item_input = gr.Textbox(label="Tên item cần tìm")
+                 output = gr.Textbox(label="Kết quả")
+
+                 search_button = gr.Button("Tìm kiếm")
+                 search_button.click(fn=handle_input, inputs=item_input, outputs=output)
+
+     return app
+
+
+ # Run the app
+ app = create_interface()
+ app.launch()
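
For a quick check outside the Gradio UI, the same flow can be exercised from a Python shell. This is a minimal sketch, assuming the `samples/` folder and the checkpoints under `models/` are in place; the file name and search string are placeholders:

    from models import Reciept_Analyzer
    from utils import find_product, get_info

    analyzer = Reciept_Analyzer()
    df = analyzer.forward("samples/receipt_01.jpg")  # placeholder path to any sample receipt
    print(get_info(df))                              # shop name, address, bill id, date, cashier
    print(find_product("coca", df))                  # placeholder product query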
models.py ADDED
@@ -0,0 +1,169 @@
+ import os
+ import torch
+ import numpy as np
+ from ultralytics import YOLO
+ from transformers import AutoProcessor
+ from transformers import AutoModelForTokenClassification
+ from utils import normalize_box, unnormalize_box, draw_output, create_df
+ from PIL import Image, ImageDraw
+ from vietocr.tool.predictor import Predictor
+ from vietocr.tool.config import Cfg
+
+
+ class Reciept_Analyzer:
+     def __init__(self,
+                  processor_pretrained='microsoft/layoutlmv3-base',
+                  layoutlm_pretrained=os.path.join('models', 'checkpoint'),
+                  yolo_pretrained=os.path.join('models', 'best.pt'),
+                  vietocr_pretrained=os.path.join('models', 'vietocr', 'vgg_seq2seq.pth')):
+
+         print("Initializing processor")
+         if torch.cuda.is_available():
+             print("Using GPU")
+         else:
+             print("No GPU detected, using CPU")
+
+         self.processor = AutoProcessor.from_pretrained(
+             processor_pretrained, apply_ocr=False)
+         print("Finished initializing processor")
+
+         print("Initializing LayoutLM model")
+         self.lalm_model = AutoModelForTokenClassification.from_pretrained(
+             layoutlm_pretrained)
+         print("Finished initializing LayoutLM model")
+
+         if yolo_pretrained is not None:
+             print("Initializing YOLO model")
+             self.yolo_model = YOLO(yolo_pretrained)
+             print("Finished initializing YOLO model")
+
+         print("Initializing VietOCR model")
+         config = Cfg.load_config_from_name('vgg_seq2seq')
+         config['weights'] = vietocr_pretrained
+         config['cnn']['pretrained'] = False
+         config['device'] = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+         self.vietocr = Predictor(config)
+         print("Finished initializing VietOCR model")
+
+     def forward(self, img, output_path="output", is_save_cropped_img=False):
+         input_image = Image.open(img)
+
+         # Text-region detection with YOLOv8
+         bboxes = self.yolov8_det(input_image)
+
+         # Sort boxes into reading order
+         sorted_bboxes = self.sort_bboxes(bboxes)
+
+         # Export an image with the detected boxes drawn on it
+         image_draw = input_image.copy()
+         self.draw_bbox(image_draw, sorted_bboxes, output_path)
+
+         # Crop each detected region
+         cropped_images, normalized_boxes = self.get_cropped_images(
+             input_image, sorted_bboxes, is_save_cropped_img, output_path)
+
+         # Text recognition with VietOCR
+         texts, mapping_bbox_texts = self.ocr(cropped_images, normalized_boxes)
+
+         # Key information extraction with LayoutLMv3
+         pred_texts, pred_label, boxes = self.kie(
+             input_image, texts, normalized_boxes, mapping_bbox_texts, output_path)
+
+         # Assemble the predictions into a dataframe
+         return create_df(pred_texts, pred_label)
+
+     def yolov8_det(self, img):
+         return self.yolo_model.predict(source=img, conf=0.3, iou=0.1)[0].boxes.xyxy.int()
+
+     def sort_bboxes(self, bboxes):
+         bbox_list = []
+         for box in bboxes:
+             tlx, tly, brx, bry = map(int, box)
+             bbox_list.append([tlx, tly, brx, bry])
+         bbox_list.sort(key=lambda x: (x[1], x[2]))
+         return bbox_list
+
+     def draw_bbox(self, image_draw, bboxes, output_path):
+         draw = ImageDraw.Draw(image_draw)
+         for box in bboxes:
+             draw.rectangle(box, outline='red', width=2)
+         os.makedirs(output_path, exist_ok=True)
+         image_draw.save(os.path.join(output_path, 'bbox.jpg'))
+         print(f"Exported image with bounding boxes to {os.path.join(output_path, 'bbox.jpg')}")
+
+     def get_cropped_images(self, input_image, bboxes, is_save_cropped=False, output_path="output"):
+         normalized_boxes = []
+         cropped_images = []
+
+         if is_save_cropped:
+             cropped_folder = os.path.join(output_path, "cropped")
+             os.makedirs(cropped_folder, exist_ok=True)
+
+         for i, box in enumerate(bboxes):
+             tlx, tly, brx, bry = map(int, box)
+             normalized_boxes.append(normalize_box(box, input_image.width, input_image.height))
+             cropped_ = input_image.crop((tlx, tly, brx, bry))
+             if is_save_cropped:
+                 cropped_.save(os.path.join(cropped_folder, f'cropped_{i}.jpg'))
+             cropped_images.append(cropped_)
+
+         return cropped_images, normalized_boxes
+
+     def ocr(self, cropped_images, normalized_boxes):
+         mapping_bbox_texts = {}
+         texts = []
+         for img, normalized_box in zip(cropped_images, normalized_boxes):
+             result = self.vietocr.predict(img)
+             text = result.strip().replace('\n', ' ')
+             texts.append(text)
+             mapping_bbox_texts[','.join(map(str, normalized_box))] = text
+
+         return texts, mapping_bbox_texts
+
+     def kie(self, img, texts, boxes, mapping_bbox_texts, output_path):
+         encoding = self.processor(img, texts,
+                                   boxes=boxes,
+                                   return_offsets_mapping=True,
+                                   return_tensors='pt',
+                                   max_length=512,
+                                   padding='max_length')
+         offset_mapping = encoding.pop('offset_mapping')
+
+         with torch.no_grad():
+             outputs = self.lalm_model(**encoding)
+
+         id2label = self.lalm_model.config.id2label
+         logits = outputs.logits
+         token_boxes = encoding.bbox.squeeze().tolist()
+         offset_mapping = offset_mapping.squeeze().tolist()
+
+         predictions = logits.argmax(-1).squeeze().tolist()
+         is_subword = np.array(offset_mapping)[:, 0] != 0
+
+         # Keep only the first token of each word and drop padding boxes
+         true_predictions = []
+         true_boxes = []
+         true_texts = []
+         for idx in range(len(predictions)):
+             if not is_subword[idx] and token_boxes[idx] != [0, 0, 0, 0]:
+                 true_predictions.append(id2label[predictions[idx]])
+                 true_boxes.append(unnormalize_box(
+                     token_boxes[idx], img.width, img.height))
+                 true_texts.append(mapping_bbox_texts.get(
+                     ','.join(map(str, token_boxes[idx])), ''))
+
+         if isinstance(output_path, str):
+             os.makedirs(output_path, exist_ok=True)
+             img_output = draw_output(
+                 image=img,
+                 true_predictions=true_predictions,
+                 true_boxes=true_boxes
+             )
+             img_output.save(os.path.join(output_path, 'result.jpg'))
+             print(f"Exported result to {os.path.join(output_path, 'result.jpg')}")
+         return true_texts, true_predictions, true_boxes
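
`forward` chains detection, reading-order sorting, OCR, and key-information extraction. A hedged sketch of running the stages one at a time, which can help when debugging a single receipt (the image path is a placeholder):

    from PIL import Image
    from models import Reciept_Analyzer

    analyzer = Reciept_Analyzer()
    image = Image.open("samples/receipt_01.jpg")                   # placeholder path
    boxes = analyzer.sort_bboxes(analyzer.yolov8_det(image))       # YOLOv8 detection + sort
    crops, norm_boxes = analyzer.get_cropped_images(image, boxes)  # crop each text region
    texts, mapping = analyzer.ocr(crops, norm_boxes)               # VietOCR recognition
    out_texts, out_labels, _ = analyzer.kie(image, texts, norm_boxes, mapping, "output")  # LayoutLMv3 labels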
utils.py ADDED
@@ -0,0 +1,179 @@
+ from PIL import ImageDraw, ImageFont
+ import pandas as pd
+
+
+ def unnormalize_box(bbox, width, height):
+     # Map a box from the 0-1000 LayoutLM coordinate space back to pixels
+     return [
+         width * (bbox[0] / 1000),
+         height * (bbox[1] / 1000),
+         width * (bbox[2] / 1000),
+         height * (bbox[3] / 1000)
+     ]
+
+
+ def normalize_box(bbox, width, height):
+     # Map a pixel box into the 0-1000 coordinate space expected by LayoutLM
+     return [
+         int((bbox[0] / width) * 1000),
+         int((bbox[1] / height) * 1000),
+         int((bbox[2] / width) * 1000),
+         int((bbox[3] / height) * 1000)
+     ]
+
+
+ def draw_output(image, true_predictions, true_boxes):
+     def iob_to_label(label):
+         if not label:
+             return 'other'
+         return label
+
+     draw = ImageDraw.Draw(image)
+     font = ImageFont.load_default()
+
+     for prediction, box in zip(true_predictions, true_boxes):
+         predicted_label = iob_to_label(prediction).lower()
+         draw.rectangle(box, outline='red')
+         draw.text((box[0] + 10, box[1] - 10),
+                   text=predicted_label, fill='red', font=font)
+
+     return image
+
+
+ def create_df(true_texts,
+               true_predictions,
+               chosen_labels=['SHOP_NAME', 'ADDR', 'TITLE', 'PHONE',
+                              'PRODUCT_NAME', 'AMOUNT', 'UNIT', 'UPRICE', 'SUB_TPRICE', 'UDISCOUNT',
+                              'TAMOUNT', 'TPRICE', 'FPRICE', 'TDISCOUNT',
+                              'RECEMONEY', 'REMAMONEY',
+                              'BILLID', 'DATETIME', 'CASHIER']):
+
+     data = {'text': [], 'class_label': [], 'product_id': []}
+     product_id = -1
+     for text, prediction in zip(true_texts, true_predictions):
+         if prediction not in chosen_labels:
+             continue
+
+         # Each PRODUCT_NAME starts a new line item
+         if prediction == 'PRODUCT_NAME':
+             product_id += 1
+
+         # Numeric fields: strip punctuation and cast to int where possible
+         if prediction in ['AMOUNT', 'UNIT', 'UDISCOUNT', 'UPRICE', 'SUB_TPRICE',
+                           'TAMOUNT', 'TPRICE', 'FPRICE', 'TDISCOUNT',
+                           'RECEMONEY', 'REMAMONEY']:
+             text = reformat(text)
+
+         # Fields tied to a specific line item share its product_id
+         if prediction in ['AMOUNT', 'SUB_TPRICE', 'UPRICE', 'PRODUCT_NAME']:
+             data['product_id'].append(product_id)
+         else:
+             data['product_id'].append('')
+
+         data['class_label'].append(prediction)
+         data['text'].append(text)
+
+     df = pd.DataFrame(data)
+
+     return df
+
+
+ def reformat(text: str):
+     # Strip punctuation from numeric strings and cast to int; fall back to the raw text
+     try:
+         text = text.replace('.', '').replace(',', '').replace(':', '').replace('/', '').replace('|', '').replace(
+             '\\', '').replace(')', '').replace('(', '').replace('-', '').replace(';', '').replace('_', '')
+         return int(text)
+     except ValueError:
+         return text
+
+
+ def find_product(product_name, df):
+     product_name = product_name.lower()
+     product_df = df[df['class_label'] == 'PRODUCT_NAME']
+     mask = product_df['text'].str.lower().str.contains(product_name, na=False)
+     if mask.any():
+         product_id = product_df.loc[mask, 'product_id'].iloc[0]
+         product_info = df[df['product_id'] == product_id]
+
+         prod_name = product_info.loc[product_info['class_label'] == 'PRODUCT_NAME', 'text'].iloc[0]
+
+         try:
+             amount = product_info.loc[product_info['class_label'] == 'AMOUNT', 'text'].iloc[0]
+         except IndexError:
+             print("Error: cannot find amount")
+             amount = ''
+
+         try:
+             uprice = product_info.loc[product_info['class_label'] == 'UPRICE', 'text'].iloc[0]
+         except IndexError:
+             print("Error: cannot find unit price")
+             uprice = ''
+
+         try:
+             sub_tprice = product_info.loc[product_info['class_label'] == 'SUB_TPRICE', 'text'].iloc[0]
+         except IndexError:
+             print("Error: cannot find sub total price")
+             sub_tprice = ''
+
+         return f"Sản phẩm: {prod_name}\n Số lượng: {amount}\n Đơn giá: {uprice}\n Thành tiền: {sub_tprice}"
+     else:
+         return "Không tìm thấy item nào phù hợp."
+
+
+ def get_info(df):
+     try:
+         shop_name = df.loc[df['class_label'] == 'SHOP_NAME', 'text'].iloc[0]
+     except IndexError:
+         print("Error: cannot find shop name")
+         shop_name = ''
+     print("Tên siêu thị: ", shop_name)
+
+     try:
+         addr = df.loc[df['class_label'] == 'ADDR', 'text'].iloc[0]
+     except IndexError:
+         print("Error: cannot find address")
+         addr = ''
+     print("Địa chỉ: ", addr)
+
+     try:
+         bill_id = df.loc[df['class_label'] == 'BILLID', 'text'].iloc[0]
+     except IndexError:
+         print("Error: cannot find bill id")
+         bill_id = ''
+     print("ID hóa đơn: ", bill_id)
+
+     try:
+         date_time = df.loc[df['class_label'] == 'DATETIME', 'text'].iloc[0]
+     except IndexError:
+         print("Error: cannot find date and time")
+         date_time = ''
+     print("Ngày: ", date_time)
+
+     try:
+         cashier = df.loc[df['class_label'] == 'CASHIER', 'text'].iloc[0]
+     except IndexError:
+         print("Error: cannot find cashier")
+         cashier = ''
+     print("Nhân viên: ", cashier)
+
+     return f"Tên siêu thị: {shop_name}\n Địa chỉ: {addr}\n ID hóa đơn: {bill_id}\n Ngày: {date_time}\n Nhân viên: {cashier}\n"