KpLBaTMaN committed
Commit 90aefec · Parent(s): 8e8b282
Commit message: code

modeling_GOT.py (+66, -59)
CHANGED
@@ -15,6 +15,7 @@ import dataclasses
 import numpy as np
 import cv2
 from io import BytesIO
+import contextlib
 ###
 
 DEFAULT_IMAGE_TOKEN = "<image>"
@@ -501,15 +502,24 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
         setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
 
-    def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False):
-
+    def chat(
+        self,
+        tokenizer,
+        image_file,
+        ocr_type,
+        ocr_box='',
+        ocr_color='',
+        render=False,
+        save_render_file=None,
+        print_prompt=False,
+        gradio_input=False,
+        stream_flag=False,
+        device="cuda"  # new parameter to specify the device
+    ):
         self.disable_torch_init()
 
-
-        image_processor_high =  GOTImageEvalProcessor(image_size=1024)
-
+        image_processor_high = GOTImageEvalProcessor(image_size=1024)
         use_im_start_end = True
-
         image_token_len = 256
 
         if gradio_input:
@@ -518,7 +528,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
             image = self.load_image(image_file)
 
         w, h = image.size
-        
+
         if ocr_type == 'format':
             qs = 'OCR with format: '
         else:
@@ -527,13 +537,13 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         if ocr_box:
             bbox = eval(ocr_box)
             if len(bbox) == 2:
-                bbox[0] = int(bbox[0]/w*1000)
-                bbox[1] = int(bbox[1]/h*1000)
+                bbox[0] = int(bbox[0] / w * 1000)
+                bbox[1] = int(bbox[1] / h * 1000)
             if len(bbox) == 4:
-                bbox[0] = int(bbox[0]/w*1000)
-                bbox[1] = int(bbox[1]/h*1000)
-                bbox[2] = int(bbox[2]/w*1000)
-                bbox[3] = int(bbox[3]/h*1000)
+                bbox[0] = int(bbox[0] / w * 1000)
+                bbox[1] = int(bbox[1] / h * 1000)
+                bbox[2] = int(bbox[2] / w * 1000)
+                bbox[3] = int(bbox[3] / h * 1000)
             if ocr_type == 'format':
                 qs = str(bbox) + ' ' + 'OCR with format: '
             else:
@@ -546,15 +556,13 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
                 qs = '[' + ocr_color + ']' + ' ' + 'OCR: '
 
         if use_im_start_end:
-            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len + DEFAULT_IM_END_TOKEN + '\n' + qs
+            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN + '\n' + qs
         else:
             qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
 
-
         conv_mpt = Conversation(
             system="""<|im_start|>system
-        You should follow the instructions carefully and explain your answers in detail.""",
-            # system = None,
+You should follow the instructions carefully and explain your answers in detail.""",
             roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
             version="mpt",
             messages=(),
@@ -572,43 +580,47 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
             print(prompt)
 
         inputs = tokenizer([prompt])
-
-        image_tensor_1 = image_processor_high(image)
-
-        input_ids = torch.as_tensor(inputs.input_ids).cuda()
+        input_ids = torch.as_tensor(inputs.input_ids).to(device)
 
         stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
         keywords = [stop_str]
         stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
         streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
+        image_tensor_1 = image_processor_high(image)
+
+        # Use autocast only when on CUDA, otherwise use a null context for CPU
+        if device == "cuda":
+            autocast_context = torch.autocast("cuda", dtype=torch.bfloat16)
+        else:
+            autocast_context = contextlib.nullcontext()
+
         if stream_flag:
-            with torch.autocast("cuda", dtype=torch.bfloat16):
+            with autocast_context:
                 output_ids = self.generate(
                     input_ids,
-                    images=[image_tensor_1.unsqueeze(0).half().cuda()],
+                    images=[image_tensor_1.unsqueeze(0).half().to(device)],
                     do_sample=False,
-                    num_beams = 1,
-                    no_repeat_ngram_size = 20,
+                    num_beams=1,
+                    no_repeat_ngram_size=20,
                     streamer=streamer,
                     max_new_tokens=4096,
                     stopping_criteria=[stopping_criteria]
-                    )
+                )
         else:
-            with torch.autocast("cuda", dtype=torch.bfloat16):
+            with autocast_context:
                 output_ids = self.generate(
                     input_ids,
-                    images=[image_tensor_1.unsqueeze(0).half().cuda()],
+                    images=[image_tensor_1.unsqueeze(0).half().to(device)],
                     do_sample=False,
-                    num_beams = 1,
-                    no_repeat_ngram_size = 20,
-                    # streamer=streamer,
+                    num_beams=1,
+                    no_repeat_ngram_size=20,
                     max_new_tokens=4096,
                     stopping_criteria=[stopping_criteria]
-                    )
-
+                )
+
         outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
-
+
         if outputs.endswith(stop_str):
             outputs = outputs[:-len(stop_str)]
         outputs = outputs.strip()
@@ -622,46 +634,44 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
                 import verovio
                 tk = verovio.toolkit()
                 tk.loadData(outputs)
-                tk.setOptions({"pageWidth": 2100, "footer": 'none',
-               'barLineWidth': 0.5, 'beamMaxSlope': 15,
-               'staffLineWidth': 0.2, 'spacingStaff': 6})
+                tk.setOptions({
+                    "pageWidth": 2100,
+                    "footer": 'none',
+                    'barLineWidth': 0.5,
+                    'beamMaxSlope': 15,
+                    'staffLineWidth': 0.2,
+                    'spacingStaff': 6
+                })
                 tk.getPageCount()
                 svg = tk.renderToSVG()
                 svg = svg.replace("overflow=\"inherit\"", "overflow=\"visible\"")
-
                 svg_to_html(svg, save_render_file)
 
             if ocr_type == 'format' and '**kern' not in outputs:
-
-
-                if  '\\begin{tikzpicture}' not in outputs:
+                if '\\begin{tikzpicture}' not in outputs:
                     html_path_2 = save_render_file
                     right_num = outputs.count('\\right')
-                    left_num = outputs.count('\left')
-
+                    left_num = outputs.count('\\left')
                     if right_num != left_num:
-                        outputs = outputs.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.')
-
-
+                        outputs = outputs.replace('\left(', '(').replace('\\right)', ')')\
+                            .replace('\left[', '[').replace('\\right]', ']')\
+                            .replace('\left{', '{').replace('\\right}', '}')\
+                            .replace('\left|', '|').replace('\\right|', '|')\
+                            .replace('\left.', '.').replace('\\right.', '.')
                     outputs = outputs.replace('"', '``').replace('$', '')
-
                     outputs_list = outputs.split('\n')
-                    gt= ''
+                    gt = ''
                     for out in outputs_list:
-                        gt +=  '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
-
+                        gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
                     gt = gt[:-2]
-
-
                     lines = content_mmd_to_html
                     lines = lines.split("const text =")
-                    new_web = lines[0] + 'const text ='  + gt + lines[1]
-
+                    new_web = lines[0] + 'const text =' + gt + lines[1]
                 else:
                     html_path_2 = save_render_file
                     outputs = outputs.translate(translation_table)
                     outputs_list = outputs.split('\n')
-                    gt= ''
+                    gt = ''
                     for out in outputs_list:
                         if out:
                             if '\\begin{tikzpicture}' not in out and '\\end{tikzpicture}' not in out:
@@ -669,7 +679,6 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
                                     out = out[:-1]
                                     if out is None:
                                         break
-
                                 if out:
                                     if out[-1] != ';':
                                         gt += out[:-1] + ';\n'
@@ -677,14 +686,12 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
                                         gt += out + '\n'
                                 else:
                                     gt += out + '\n'
-
-
                     lines = tik_html
                     lines = lines.split("const text =")
                     new_web = lines[0] + gt + lines[1]
-
                 with open(html_path_2, 'w') as web_f_new:
                     web_f_new.write(new_web)
+
         return response_str
 
     def dynamic_preprocess(self, image, min_num=1, max_num=6, image_size=1024, use_thumbnail=True):
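For reference, a minimal usage sketch of the chat() call with the new device parameter. It assumes the checkpoint is loaded with trust_remote_code=True so that this modeling_GOT.py is used; the repo id, image path, and loading flags below are placeholders rather than part of this commit, and CPU inference still depends on the rest of the model tolerating the half-precision image tensor.

import torch
from transformers import AutoModel, AutoTokenizer

# Placeholder repo id and image path -- substitute your own checkpoint and file.
repo_id = "your-namespace/GOT-OCR2_0"
image_path = "sample.png"

tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True, low_cpu_mem_usage=True)

# The new device argument routes input_ids and the image tensor, and selects
# bfloat16 autocast on CUDA or contextlib.nullcontext() elsewhere.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.eval().to(device)

result = model.chat(
    tokenizer,
    image_path,
    ocr_type="ocr",  # or "format" for formatted output
    device=device,
)
print(result)

Routing the precision choice through a single autocast_context keeps the two generate() calls identical across devices; only the context manager differs.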