markany-yhkwon commited on
Commit
496f961
·
1 Parent(s): 2f32b9e

_c bug fix

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. setup.py +220 -0
app.py CHANGED
@@ -15,7 +15,7 @@ from pathlib import Path
15
  import warnings
16
 
17
  import torch
18
-
19
  warnings.filterwarnings("ignore")
20
 
21
  from groundingdino.models import build_model
 
15
  import warnings
16
 
17
  import torch
18
+ os.system("python setup.py build develop --user")
19
  warnings.filterwarnings("ignore")
20
 
21
  from groundingdino.models import build_model
setup.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The IDEA Authors. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------------------------------
16
+ # Modified from
17
+ # https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/setup.py
18
+ # https://github.com/facebookresearch/detectron2/blob/main/setup.py
19
+ # https://github.com/open-mmlab/mmdetection/blob/master/setup.py
20
+ # https://github.com/Oneflow-Inc/libai/blob/main/setup.py
21
+ # ------------------------------------------------------------------------------------------------
22
+
23
+ import glob
24
+ import os
25
+ import subprocess
26
+
27
+ import subprocess
28
+ import sys
29
+
30
+ def install_torch():
31
+ try:
32
+ import torch
33
+ except ImportError:
34
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "torch"])
35
+
36
+ # Call the function to ensure torch is installed
37
+ install_torch()
38
+
39
+ import torch
40
+ from setuptools import find_packages, setup
41
+ from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
42
+
43
+ # groundingdino version info
44
+ version = "0.1.0"
45
+ package_name = "groundingdino"
46
+ cwd = os.path.dirname(os.path.abspath(__file__))
47
+
48
+
49
+ sha = "Unknown"
50
+ try:
51
+ sha = subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=cwd).decode("ascii").strip()
52
+ except Exception:
53
+ pass
54
+
55
+
56
+ def write_version_file():
57
+ version_path = os.path.join(cwd, "groundingdino", "version.py")
58
+ with open(version_path, "w") as f:
59
+ f.write(f"__version__ = '{version}'\n")
60
+ # f.write(f"git_version = {repr(sha)}\n")
61
+
62
+
63
+ requirements = ["torch", "torchvision"]
64
+
65
+ torch_ver = [int(x) for x in torch.__version__.split(".")[:2]]
66
+
67
+
68
+ def get_extensions():
69
+ this_dir = os.path.dirname(os.path.abspath(__file__))
70
+ extensions_dir = os.path.join(this_dir, "groundingdino", "models", "GroundingDINO", "csrc")
71
+
72
+ main_source = os.path.join(extensions_dir, "vision.cpp")
73
+ sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"))
74
+ source_cuda = glob.glob(os.path.join(extensions_dir, "**", "*.cu")) + glob.glob(
75
+ os.path.join(extensions_dir, "*.cu")
76
+ )
77
+
78
+ sources = [main_source] + sources
79
+
80
+ extension = CppExtension
81
+
82
+ extra_compile_args = {"cxx": []}
83
+ define_macros = []
84
+
85
+ if CUDA_HOME is not None and (torch.cuda.is_available() or "TORCH_CUDA_ARCH_LIST" in os.environ):
86
+ print("Compiling with CUDA")
87
+ extension = CUDAExtension
88
+ sources += source_cuda
89
+ define_macros += [("WITH_CUDA", None)]
90
+ extra_compile_args["nvcc"] = [
91
+ "-DCUDA_HAS_FP16=1",
92
+ "-D__CUDA_NO_HALF_OPERATORS__",
93
+ "-D__CUDA_NO_HALF_CONVERSIONS__",
94
+ "-D__CUDA_NO_HALF2_OPERATORS__",
95
+ ]
96
+ else:
97
+ print("Compiling without CUDA")
98
+ define_macros += [("WITH_HIP", None)]
99
+ extra_compile_args["nvcc"] = []
100
+ return None
101
+
102
+ sources = [os.path.join(extensions_dir, s) for s in sources]
103
+ include_dirs = [extensions_dir]
104
+
105
+ ext_modules = [
106
+ extension(
107
+ "groundingdino._C",
108
+ sources,
109
+ include_dirs=include_dirs,
110
+ define_macros=define_macros,
111
+ extra_compile_args=extra_compile_args,
112
+ )
113
+ ]
114
+
115
+ return ext_modules
116
+
117
+
118
+ def parse_requirements(fname="requirements.txt", with_version=True):
119
+ """Parse the package dependencies listed in a requirements file but strips
120
+ specific versioning information.
121
+
122
+ Args:
123
+ fname (str): path to requirements file
124
+ with_version (bool, default=False): if True include version specs
125
+
126
+ Returns:
127
+ List[str]: list of requirements items
128
+
129
+ CommandLine:
130
+ python -c "import setup; print(setup.parse_requirements())"
131
+ """
132
+ import re
133
+ import sys
134
+ from os.path import exists
135
+
136
+ require_fpath = fname
137
+
138
+ def parse_line(line):
139
+ """Parse information from a line in a requirements text file."""
140
+ if line.startswith("-r "):
141
+ # Allow specifying requirements in other files
142
+ target = line.split(" ")[1]
143
+ for info in parse_require_file(target):
144
+ yield info
145
+ else:
146
+ info = {"line": line}
147
+ if line.startswith("-e "):
148
+ info["package"] = line.split("#egg=")[1]
149
+ elif "@git+" in line:
150
+ info["package"] = line
151
+ else:
152
+ # Remove versioning from the package
153
+ pat = "(" + "|".join([">=", "==", ">"]) + ")"
154
+ parts = re.split(pat, line, maxsplit=1)
155
+ parts = [p.strip() for p in parts]
156
+
157
+ info["package"] = parts[0]
158
+ if len(parts) > 1:
159
+ op, rest = parts[1:]
160
+ if ";" in rest:
161
+ # Handle platform specific dependencies
162
+ # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies
163
+ version, platform_deps = map(str.strip, rest.split(";"))
164
+ info["platform_deps"] = platform_deps
165
+ else:
166
+ version = rest # NOQA
167
+ info["version"] = (op, version)
168
+ yield info
169
+
170
+ def parse_require_file(fpath):
171
+ with open(fpath, "r") as f:
172
+ for line in f.readlines():
173
+ line = line.strip()
174
+ if line and not line.startswith("#"):
175
+ for info in parse_line(line):
176
+ yield info
177
+
178
+ def gen_packages_items():
179
+ if exists(require_fpath):
180
+ for info in parse_require_file(require_fpath):
181
+ parts = [info["package"]]
182
+ if with_version and "version" in info:
183
+ parts.extend(info["version"])
184
+ if not sys.version.startswith("3.4"):
185
+ # apparently package_deps are broken in 3.4
186
+ platform_deps = info.get("platform_deps")
187
+ if platform_deps is not None:
188
+ parts.append(";" + platform_deps)
189
+ item = "".join(parts)
190
+ yield item
191
+
192
+ packages = list(gen_packages_items())
193
+ return packages
194
+
195
+
196
+ if __name__ == "__main__":
197
+ print(f"Building wheel {package_name}-{version}")
198
+
199
+ with open("LICENSE", "r", encoding="utf-8") as f:
200
+ license = f.read()
201
+
202
+ write_version_file()
203
+
204
+ setup(
205
+ name="groundingdino",
206
+ version="0.1.0",
207
+ author="International Digital Economy Academy, Shilong Liu",
208
+ url="https://github.com/IDEA-Research/GroundingDINO",
209
+ description="open-set object detector",
210
+ license=license,
211
+ install_requires=parse_requirements("requirements.txt"),
212
+ packages=find_packages(
213
+ exclude=(
214
+ "configs",
215
+ "tests",
216
+ )
217
+ ),
218
+ ext_modules=get_extensions(),
219
+ cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
220
+ )