DamLoan committed on
Commit
b79f44c
·
verified ·
1 Parent(s): b83ba6e

Upload llm_utils.py

Browse files
Files changed (1) hide show
  1. llm_utils.py +116 -0
llm_utils.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms as T
3
+ from PIL import Image
4
+ from transformers import AutoModel, AutoTokenizer
5
+ from torchvision.transforms.functional import InterpolationMode
6
+
7
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
8
+ IMAGENET_STD = (0.229, 0.224, 0.225)
9
+
10
+
11
def build_transform(input_size):
    """Build the preprocessing pipeline applied to each image tile.

    Args:
        input_size: Side length, in pixels, of the square every tile is
            resized to.

    Returns:
        A ``T.Compose`` that forces RGB mode, resizes to
        ``(input_size, input_size)`` with bicubic interpolation, converts
        to a tensor and normalizes with ``IMAGENET_MEAN`` / ``IMAGENET_STD``.
    """
    def _ensure_rgb(img):
        # Only convert when needed; RGB images pass through untouched.
        if img.mode == 'RGB':
            return img
        return img.convert('RGB')

    target_shape = (input_size, input_size)
    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize(target_shape, interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)
18
+
19
+
20
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the tile grid whose aspect ratio is closest to the image's.

    Args:
        aspect_ratio: Width / height of the source image.
        target_ratios: Iterable of candidate grids; element ``[0]`` is
            divided by element ``[1]`` to get the candidate's aspect ratio.
        width: Source image width in pixels.
        height: Source image height in pixels.
        image_size: Side length of one square tile in pixels.

    Returns:
        The best candidate from ``target_ratios``, or ``(1, 1)`` when no
        candidate is supplied. On an exact tie in aspect-ratio distance,
        the later candidate wins only if the source area exceeds half the
        area that candidate's grid would cover.
    """
    chosen = (1, 1)
    smallest_gap = float('inf')
    pixel_count = width * height
    half_tile_area = 0.5 * image_size * image_size

    for candidate in target_ratios:
        gap = abs(aspect_ratio - candidate[0] / candidate[1])

        if gap < smallest_gap:
            # Strictly better match: take it unconditionally.
            smallest_gap, chosen = gap, candidate
            continue

        # Equal distance: switch only when the image is large enough to
        # justify this candidate's grid.
        tied = gap == smallest_gap
        if tied and pixel_count > half_tile_area * candidate[0] * candidate[1]:
            chosen = candidate

    return chosen
34
+
35
+
36
def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    """Split an image into a grid of square tiles matching its aspect ratio.

    Args:
        image: Source image; must expose ``size``, ``resize`` and ``crop``
            (a PIL image in this file's usage).
        min_num: Minimum number of tiles a candidate grid may contain.
        max_num: Maximum number of tiles a candidate grid may contain.
        image_size: Side length of each square tile in pixels.
        use_thumbnail: When true and more than one tile was produced, a
            whole-image thumbnail of ``image_size`` x ``image_size`` is
            appended after the tiles.

    Returns:
        A list of tiles in row-major order, optionally followed by the
        thumbnail.
    """
    src_w, src_h = image.size
    src_ratio = src_w / src_h

    # Every (columns, rows) grid whose tile count lies in [min_num, max_num],
    # ordered by total tile count.
    candidate_grids = {
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if min_num <= i * j <= max_num
    }
    candidate_grids = sorted(candidate_grids, key=lambda grid: grid[0] * grid[1])

    grid = find_closest_aspect_ratio(
        src_ratio, candidate_grids, src_w, src_h, image_size)

    canvas_w = image_size * grid[0]
    canvas_h = image_size * grid[1]
    tile_count = grid[0] * grid[1]

    canvas = image.resize((canvas_w, canvas_h))
    tiles_per_row = canvas_w // image_size

    tiles = []
    for index in range(tile_count):
        row, col = divmod(index, tiles_per_row)
        left = col * image_size
        upper = row * image_size
        right = (col + 1) * image_size
        lower = (row + 1) * image_size
        tiles.append(canvas.crop((left, upper, right, lower)))

    # A single-tile result already covers the whole image, so the
    # thumbnail is only added for multi-tile grids.
    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))

    return tiles
71
+
72
+
73
def load_image(image_file, input_size=448, max_num=12):
    """Load an image and turn it into a stacked tensor of normalized tiles.

    Args:
        image_file: Anything ``Image.open`` accepts (path or file object).
        input_size: Side length, in pixels, of each square tile.
        max_num: Maximum number of tiles the image may be split into
            (the thumbnail added by ``dynamic_preprocess`` is extra).

    Returns:
        The result of ``torch.stack`` over the transformed tiles, i.e. one
        leading entry per tile (plus the thumbnail for multi-tile images).
    """
    # Open inside a context manager so the underlying file handle is
    # released deterministically; ``convert`` yields an in-memory image
    # that stays usable after the source is closed.
    with Image.open(image_file) as source:
        image = source.convert('RGB')

    transform = build_transform(input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(im) for im in images]
    return torch.stack(pixel_values)
79
+
80
+
81
def load_model():
    """Load the Vintern model and its tokenizer onto the GPU.

    The model is loaded in bfloat16 with ``trust_remote_code=True``, put in
    eval mode and moved to CUDA.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        Exception: Whatever the second ``from_pretrained`` attempt, the move
            to CUDA, or the tokenizer load raises.
    """
    model_name = "5CD-AI/Vintern-1B-v3_5"
    shared_kwargs = dict(
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )

    # Only the load call is retried. Keeping ``.eval().cuda()`` outside the
    # ``try`` means a CUDA failure no longer triggers a second full load.
    try:
        model = AutoModel.from_pretrained(
            model_name, **shared_kwargs, use_flash_attn=False)
    except Exception:
        # Best-effort fallback: presumably for model code that rejects the
        # ``use_flash_attn`` keyword — TODO confirm and narrow the exception.
        model = AutoModel.from_pretrained(model_name, **shared_kwargs)

    model = model.eval().cuda()

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
    return model, tokenizer
101
+
102
+
103
def extract_info_from_image(image_path, model, tokenizer, max_num_blocks=6):
    """Ask the model to extract the tax-debt table from an image as markdown.

    Args:
        image_path: Image location, forwarded to ``load_image``.
        model: Object exposing ``chat(tokenizer, pixel_values, question,
            generation_config)``.
        tokenizer: Tokenizer passed through to ``model.chat``.
        max_num_blocks: Maximum number of tiles the image is split into.

    Returns:
        Whatever ``model.chat`` returns for the fixed extraction prompt.
    """
    tiles = load_image(image_path, max_num=max_num_blocks)
    pixel_values = tiles.to(torch.bfloat16).cuda()

    # Fixed Vietnamese prompt; it is runtime text and must not be altered.
    question = "<image>\nTrích xuất dữ liệu các cột: STT, Mã số thuế, Tên người nộp thuế, Địa chỉ, Số tiền thuế nợ, Biện pháp cưỡng chế. Hãy cố gắng đọc rõ những con số hoặc chữ bị đóng dấu và trả về dạng markdown."

    # Deterministic beam search with a repetition penalty.
    generation_config = {
        "max_new_tokens": 2048,
        "do_sample": False,
        "num_beams": 3,
        "repetition_penalty": 2.5,
    }

    return model.chat(tokenizer, pixel_values, question, generation_config)