justinj92 committed
Commit c6a0eef · verified · 1 Parent(s): 6140e0e

Update utils/models.py

Files changed (1):
  utils/models.py (+70 -72)
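The only substantive change in the diff below: the non-fine-tuned checkpoints "microsoft/Florence-2-large" and "microsoft/Florence-2-base" are dropped from CHECKPOINTS, leaving the two fine-tuned (-ft) variants.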
utils/models.py CHANGED
@@ -1,73 +1,71 @@
-from typing import Tuple, Dict, Any, List
-from unittest.mock import patch
-
-import numpy as np
-import supervision as sv
-import torch
-from PIL import Image
-from transformers import AutoModelForCausalLM, AutoProcessor
-
-from utils.imports import fixed_get_imports
-
-CHECKPOINTS = [
-    "microsoft/Florence-2-large-ft",
-    "microsoft/Florence-2-large",
-    "microsoft/Florence-2-base-ft",
-    "microsoft/Florence-2-base",
-]
-
-
-def load_models(device: torch.device) -> Tuple[Dict[str, Any], Dict[str, Any]]:
-    with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
-        models = {}
-        processors = {}
-        for checkpoint in CHECKPOINTS:
-            models[checkpoint] = AutoModelForCausalLM.from_pretrained(
-                checkpoint, trust_remote_code=True).to(device).eval()
-            processors[checkpoint] = AutoProcessor.from_pretrained(
-                checkpoint, trust_remote_code=True)
-        return models, processors
-
-
-def run_inference(
-    model: Any,
-    processor: Any,
-    device: torch.device,
-    image: Image,
-    task: str,
-    text: str = ""
-) -> Tuple[str, Dict]:
-    prompt = task + text
-    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
-    generated_ids = model.generate(
-        input_ids=inputs["input_ids"],
-        pixel_values=inputs["pixel_values"],
-        max_new_tokens=1024,
-        num_beams=3
-    )
-    generated_text = processor.batch_decode(
-        generated_ids, skip_special_tokens=False)[0]
-    response = processor.post_process_generation(
-        generated_text, task=task, image_size=image.size)
-    return generated_text, response
-
-
-def pre_process_region_task_input(
-    prompt: List[float],
-    resolution_wh: Tuple[int, int]
-) -> str:
-    x1, y1, _, x2, y2, _ = prompt
-    w, h = resolution_wh
-    box = np.array([x1, y1, x2, y2])
-    box /= np.array([w, h, w, h])
-    box *= 1000
-    return "".join([f"<loc_{int(coordinate)}>" for coordinate in box])
-
-
-def post_process_region_output(
-    detections: sv.Detections,
-    resolution_wh: Tuple[int, int]
-) -> sv.Detections:
-    w, h = resolution_wh
-    detections.xyxy = (detections.xyxy / 1000 * np.array([w, h, w, h])).astype(np.int32)
+from typing import Tuple, Dict, Any, List
+from unittest.mock import patch
+
+import numpy as np
+import supervision as sv
+import torch
+from PIL import Image
+from transformers import AutoModelForCausalLM, AutoProcessor
+
+from utils.imports import fixed_get_imports
+
+CHECKPOINTS = [
+    "microsoft/Florence-2-large-ft",
+    "microsoft/Florence-2-base-ft",
+]
+
+
+def load_models(device: torch.device) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+    with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
+        models = {}
+        processors = {}
+        for checkpoint in CHECKPOINTS:
+            models[checkpoint] = AutoModelForCausalLM.from_pretrained(
+                checkpoint, trust_remote_code=True).to(device).eval()
+            processors[checkpoint] = AutoProcessor.from_pretrained(
+                checkpoint, trust_remote_code=True)
+        return models, processors
+
+
+def run_inference(
+    model: Any,
+    processor: Any,
+    device: torch.device,
+    image: Image,
+    task: str,
+    text: str = ""
+) -> Tuple[str, Dict]:
+    prompt = task + text
+    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
+    generated_ids = model.generate(
+        input_ids=inputs["input_ids"],
+        pixel_values=inputs["pixel_values"],
+        max_new_tokens=1024,
+        num_beams=3
+    )
+    generated_text = processor.batch_decode(
+        generated_ids, skip_special_tokens=False)[0]
+    response = processor.post_process_generation(
+        generated_text, task=task, image_size=image.size)
+    return generated_text, response
+
+
+def pre_process_region_task_input(
+    prompt: List[float],
+    resolution_wh: Tuple[int, int]
+) -> str:
+    x1, y1, _, x2, y2, _ = prompt
+    w, h = resolution_wh
+    box = np.array([x1, y1, x2, y2])
+    box /= np.array([w, h, w, h])
+    box *= 1000
+    return "".join([f"<loc_{int(coordinate)}>" for coordinate in box])
+
+
+def post_process_region_output(
+    detections: sv.Detections,
+    resolution_wh: Tuple[int, int]
+) -> sv.Detections:
+    w, h = resolution_wh
+    detections.xyxy = (detections.xyxy / 1000 * np.array([w, h, w, h])).astype(np.int32)
     return detections
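
For reference, a minimal usage sketch of the helpers this file defines. The app code that drives them is not part of this diff, so the device selection, image path, and "<OD>" task token below are illustrative assumptions, not the app's actual wiring:

import torch
from PIL import Image

from utils.models import CHECKPOINTS, load_models, run_inference

# Hypothetical driver code; the real app may choose the device and task differently.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
models, processors = load_models(device)

checkpoint = CHECKPOINTS[0]  # "microsoft/Florence-2-large-ft"
image = Image.open("example.jpg")  # placeholder image path
generated_text, response = run_inference(
    models[checkpoint], processors[checkpoint], device, image, task="<OD>")
print(response)  # dict keyed by task, e.g. {"<OD>": {"bboxes": [...], "labels": [...]}}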
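
The two region helpers convert between pixel coordinates and the 0-1000 location bins Florence-2 encodes in its <loc_*> tokens: pre_process_region_task_input normalizes a pixel box by the image resolution and scales it to that grid, while post_process_region_output inverts the mapping on detection boxes. A hypothetical worked example:

from utils.models import pre_process_region_task_input

# Six-element prompt in the layout the function unpacks: (x1, y1, _, x2, y2, _).
# For a 640x480 image, the box (64, 48, 320, 240) normalizes to
# (0.1, 0.1, 0.5, 0.5) and scales to the 0-1000 grid.
prompt = [64.0, 48.0, 0.0, 320.0, 240.0, 0.0]
print(pre_process_region_task_input(prompt, resolution_wh=(640, 480)))
# -> <loc_100><loc_100><loc_500><loc_500>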