KpLBaTMaN committed
Commit 90aefec · 1 Parent(s): 8e8b282
Files changed (1)
  1. modeling_GOT.py +66 -59
modeling_GOT.py CHANGED
@@ -15,6 +15,7 @@ import dataclasses
 import numpy as np
 import cv2
 from io import BytesIO
+import contextlib
 ###
 
 DEFAULT_IMAGE_TOKEN = "<image>"
@@ -501,15 +502,24 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
         setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
 
-    def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False):
-
+    def chat(
+        self,
+        tokenizer,
+        image_file,
+        ocr_type,
+        ocr_box='',
+        ocr_color='',
+        render=False,
+        save_render_file=None,
+        print_prompt=False,
+        gradio_input=False,
+        stream_flag=False,
+        device="cuda"  # new parameter to specify the device
+    ):
         self.disable_torch_init()
 
-
-        image_processor_high = GOTImageEvalProcessor(image_size=1024)
-
+        image_processor_high = GOTImageEvalProcessor(image_size=1024)
         use_im_start_end = True
-
         image_token_len = 256
 
         if gradio_input:
@@ -518,7 +528,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
             image = self.load_image(image_file)
 
         w, h = image.size
-
+
         if ocr_type == 'format':
             qs = 'OCR with format: '
         else:
@@ -527,13 +537,13 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         if ocr_box:
             bbox = eval(ocr_box)
             if len(bbox) == 2:
-                bbox[0] = int(bbox[0]/w*1000)
-                bbox[1] = int(bbox[1]/h*1000)
+                bbox[0] = int(bbox[0] / w * 1000)
+                bbox[1] = int(bbox[1] / h * 1000)
             if len(bbox) == 4:
-                bbox[0] = int(bbox[0]/w*1000)
-                bbox[1] = int(bbox[1]/h*1000)
-                bbox[2] = int(bbox[2]/w*1000)
-                bbox[3] = int(bbox[3]/h*1000)
+                bbox[0] = int(bbox[0] / w * 1000)
+                bbox[1] = int(bbox[1] / h * 1000)
+                bbox[2] = int(bbox[2] / w * 1000)
+                bbox[3] = int(bbox[3] / h * 1000)
             if ocr_type == 'format':
                 qs = str(bbox) + ' ' + 'OCR with format: '
             else:
@@ -546,15 +556,13 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
             qs = '[' + ocr_color + ']' + ' ' + 'OCR: '
 
         if use_im_start_end:
-            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len + DEFAULT_IM_END_TOKEN + '\n' + qs
+            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN + '\n' + qs
         else:
             qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
 
-
         conv_mpt = Conversation(
             system="""<|im_start|>system
-        You should follow the instructions carefully and explain your answers in detail.""",
-            # system = None,
+        You should follow the instructions carefully and explain your answers in detail.""",
             roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
             version="mpt",
             messages=(),
@@ -572,43 +580,47 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
             print(prompt)
 
         inputs = tokenizer([prompt])
-
-        image_tensor_1 = image_processor_high(image)
-
-        input_ids = torch.as_tensor(inputs.input_ids).cuda()
+        input_ids = torch.as_tensor(inputs.input_ids).to(device)
 
         stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
         keywords = [stop_str]
         stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
         streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
+        image_tensor_1 = image_processor_high(image)
+
+        # Use autocast only when on CUDA, otherwise use a null context for CPU
+        if device == "cuda":
+            autocast_context = torch.autocast("cuda", dtype=torch.bfloat16)
+        else:
+            autocast_context = contextlib.nullcontext()
+
         if stream_flag:
-            with torch.autocast("cuda", dtype=torch.bfloat16):
+            with autocast_context:
                 output_ids = self.generate(
                     input_ids,
-                    images=[image_tensor_1.unsqueeze(0).half().cuda()],
+                    images=[image_tensor_1.unsqueeze(0).half().to(device)],
                     do_sample=False,
-                    num_beams = 1,
-                    no_repeat_ngram_size = 20,
+                    num_beams=1,
+                    no_repeat_ngram_size=20,
                     streamer=streamer,
                     max_new_tokens=4096,
                     stopping_criteria=[stopping_criteria]
-                    )
+                )
         else:
-            with torch.autocast("cuda", dtype=torch.bfloat16):
+            with autocast_context:
                 output_ids = self.generate(
                     input_ids,
-                    images=[image_tensor_1.unsqueeze(0).half().cuda()],
+                    images=[image_tensor_1.unsqueeze(0).half().to(device)],
                     do_sample=False,
-                    num_beams = 1,
-                    no_repeat_ngram_size = 20,
-                    # streamer=streamer,
+                    num_beams=1,
+                    no_repeat_ngram_size=20,
                     max_new_tokens=4096,
                     stopping_criteria=[stopping_criteria]
-                    )
-
+                )
+
         outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
-
+
         if outputs.endswith(stop_str):
             outputs = outputs[:-len(stop_str)]
         outputs = outputs.strip()
@@ -622,46 +634,44 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
                 import verovio
                 tk = verovio.toolkit()
                 tk.loadData(outputs)
-                tk.setOptions({"pageWidth": 2100, "footer": 'none',
-                    'barLineWidth': 0.5, 'beamMaxSlope': 15,
-                    'staffLineWidth': 0.2, 'spacingStaff': 6})
+                tk.setOptions({
+                    "pageWidth": 2100,
+                    "footer": 'none',
+                    'barLineWidth': 0.5,
+                    'beamMaxSlope': 15,
+                    'staffLineWidth': 0.2,
+                    'spacingStaff': 6
+                })
                 tk.getPageCount()
                 svg = tk.renderToSVG()
                 svg = svg.replace("overflow=\"inherit\"", "overflow=\"visible\"")
-
                 svg_to_html(svg, save_render_file)
 
             if ocr_type == 'format' and '**kern' not in outputs:
-
-
-                if '\\begin{tikzpicture}' not in outputs:
+                if '\\begin{tikzpicture}' not in outputs:
                     html_path_2 = save_render_file
                     right_num = outputs.count('\\right')
-                    left_num = outputs.count('\left')
-
+                    left_num = outputs.count('\\left')
                     if right_num != left_num:
-                        outputs = outputs.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.')
-
-
+                        outputs = outputs.replace('\left(', '(').replace('\\right)', ')')\
+                            .replace('\left[', '[').replace('\\right]', ']')\
+                            .replace('\left{', '{').replace('\\right}', '}')\
+                            .replace('\left|', '|').replace('\\right|', '|')\
+                            .replace('\left.', '.').replace('\\right.', '.')
                     outputs = outputs.replace('"', '``').replace('$', '')
-
                     outputs_list = outputs.split('\n')
-                    gt= ''
+                    gt = ''
                     for out in outputs_list:
-                        gt +=  '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
-
+                        gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
                     gt = gt[:-2]
-
-
                     lines = content_mmd_to_html
                     lines = lines.split("const text =")
-                    new_web = lines[0] + 'const text =' + gt + lines[1]
-
+                    new_web = lines[0] + 'const text =' + gt + lines[1]
                 else:
                     html_path_2 = save_render_file
                     outputs = outputs.translate(translation_table)
                     outputs_list = outputs.split('\n')
-                    gt= ''
+                    gt = ''
                     for out in outputs_list:
                         if out:
                             if '\\begin{tikzpicture}' not in out and '\\end{tikzpicture}' not in out:
@@ -669,7 +679,6 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
                                     out = out[:-1]
                                     if out is None:
                                         break
-
                                 if out:
                                     if out[-1] != ';':
                                         gt += out[:-1] + ';\n'
@@ -677,14 +686,12 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
                                         gt += out + '\n'
                             else:
                                 gt += out + '\n'
-
-
                     lines = tik_html
                     lines = lines.split("const text =")
                     new_web = lines[0] + gt + lines[1]
-
                 with open(html_path_2, 'w') as web_f_new:
                     web_f_new.write(new_web)
+
         return response_str
 
     def dynamic_preprocess(self, image, min_num=1, max_num=6, image_size=1024, use_thumbnail=True):
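
Usage sketch (not part of the commit): one way the new device argument to chat() could be exercised. Only the chat(..., device=...) signature and the CUDA-only bfloat16 autocast behaviour come from the diff above; the repo id, image path, ocr_type value, and the AutoModel/AutoTokenizer trust_remote_code loading pattern are illustrative assumptions.

import torch
from transformers import AutoModel, AutoTokenizer

repo_id = "KpLBaTMaN/GOT-OCR2_0"  # hypothetical repo id, for illustration only
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)

# Pick a target device; the patched chat() wraps generation in bfloat16 autocast
# on CUDA and in contextlib.nullcontext() otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device).eval()

# 'format' selects the formatted-OCR prompt branch shown in the diff; anything else
# (here 'ocr') falls through to the plain OCR prompt.
text = model.chat(tokenizer, "page.png", ocr_type="ocr", device=device)
print(text)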