Shivdutta commited on
Commit
1f515b4
·
verified ·
1 Parent(s): 7bb1029

Delete utils_inference.py

Browse files
Files changed (1) hide show
  1. utils_inference.py +0 -119
utils_inference.py DELETED
@@ -1,119 +0,0 @@
1
- import numpy as np
2
- import gradio as gr
3
- from PIL import ImageDraw
4
-
5
- from tools_gradio import fast_process
6
- from tools import format_results, box_prompt, point_prompt, text_prompt
7
-
8
-
9
- def segment_everything(
10
- model,
11
- device,
12
- input,
13
- input_size=1024,
14
- iou_threshold=0.7,
15
- conf_threshold=0.25,
16
- better_quality=False,
17
- withContours=True,
18
- use_retina=True,
19
- text="",
20
- wider=False,
21
- mask_random_color=True,
22
- ):
23
- input_size = int(input_size)
24
- w, h = input.size
25
- scale = input_size / max(w, h)
26
- new_w = int(w * scale)
27
- new_h = int(h * scale)
28
- input = input.resize((new_w, new_h))
29
-
30
- results = model(input,
31
- device=device,
32
- retina_masks=True,
33
- iou=iou_threshold,
34
- conf=conf_threshold,
35
- imgsz=input_size, )
36
-
37
- if len(text) > 0:
38
- results = format_results(results[0], 0)
39
- annotations, _ = text_prompt(results, text, input, device=device, wider=wider)
40
- annotations = np.array([annotations])
41
- else:
42
- annotations = results[0].masks.data
43
-
44
- fig = fast_process(annotations=annotations,
45
- image=input,
46
- device=device,
47
- scale=(1024 // input_size),
48
- better_quality=better_quality,
49
- mask_random_color=mask_random_color,
50
- bbox=None,
51
- use_retina=use_retina,
52
- withContours=withContours, )
53
- return fig
54
-
55
-
56
- def segment_with_points(
57
- model,
58
- device,
59
- input,
60
- input_size=1024,
61
- iou_threshold=0.7,
62
- conf_threshold=0.25,
63
- better_quality=False,
64
- withContours=True,
65
- use_retina=True,
66
- mask_random_color=True,
67
- ):
68
- global global_points
69
- global global_point_label
70
-
71
- input_size = int(input_size)
72
- w, h = input.size
73
- scale = input_size / max(w, h)
74
- new_w = int(w * scale)
75
- new_h = int(h * scale)
76
- input = input.resize((new_w, new_h))
77
-
78
- scaled_points = [[int(x * scale) for x in point] for point in global_points]
79
-
80
- results = model(input,
81
- device=device,
82
- retina_masks=True,
83
- iou=iou_threshold,
84
- conf=conf_threshold,
85
- imgsz=input_size, )
86
-
87
- results = format_results(results[0], 0)
88
- annotations, _ = point_prompt(results, scaled_points, global_point_label, new_h, new_w)
89
- annotations = np.array([annotations])
90
-
91
- fig = fast_process(annotations=annotations,
92
- image=input,
93
- device=device,
94
- scale=(1024 // input_size),
95
- better_quality=better_quality,
96
- mask_random_color=mask_random_color,
97
- bbox=None,
98
- use_retina=use_retina,
99
- withContours=withContours, )
100
-
101
- global_points = []
102
- global_point_label = []
103
- return fig, None
104
-
105
-
106
- def get_points_with_draw(image, label, evt: gr.SelectData):
107
- global global_points
108
- global global_point_label
109
-
110
- x, y = evt.index[0], evt.index[1]
111
- point_radius, point_color = 15, (255, 255, 0) if label == 'Add Mask' else (255, 0, 255)
112
- global_points.append([x, y])
113
- global_point_label.append(1 if label == 'Add Mask' else 0)
114
-
115
- print(x, y, label == 'Add Mask')
116
-
117
- draw = ImageDraw.Draw(image)
118
- draw.ellipse([(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)], fill=point_color)
119
- return image