# Scraped from a Hugging Face Space page (Space status: "Runtime error";
# file size: 2,736 bytes; commit bb3e610).
import argparse
import os
from custom.clip_ebc_onnx import ClipEBCOnnx
def parse_args(argv=None):
    """Parse command-line options for CLIP-EBC crowd counting (ONNX).

    Args:
        argv: Optional list of argument strings to parse instead of the
            real command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, matching the original behavior.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description='CLIP-EBC Crowd Counting (ONNX)')
    parser.add_argument('--image', required=True, help='Path to input image')
    parser.add_argument('--model', default='assets/CLIP_EBC_nwpu_rmse_onnx.onnx',
                        help='Path to ONNX model')
    parser.add_argument('--visualize', choices=['density', 'dots', 'all', 'none'],
                        default='none', help='Visualization type')
    parser.add_argument('--save', action='store_true',
                        help='Save visualization results')
    parser.add_argument('--output-dir', default='results',
                        help='Directory to save results')

    # Visualization-related parameters.
    parser.add_argument('--alpha', type=float, default=0.5,
                        help='Alpha value for density map')
    parser.add_argument('--dot-size', type=int, default=20,
                        help='Dot size for dot visualization')
    # argparse applies `type=` only to string defaults, so the defaults are
    # written as floats to keep the attribute type consistent whether or not
    # the flag is passed.
    parser.add_argument('--sigma', type=float, default=1.0,
                        help='Sigma value for Gaussian filter')
    parser.add_argument('--percentile', type=float, default=97.0,
                        help='Percentile threshold for dot visualization')
    return parser.parse_args(argv)
def main():
    """Run CLIP-EBC crowd counting on one image and optionally visualize it.

    Parses the command line, builds a ``ClipEBCOnnx`` model from ``--model``,
    prints the predicted count for ``--image``, and renders the requested
    visualizations (optionally saving them under ``--output-dir``).
    Matplotlib figures returned by the visualizers are closed at the end to
    release memory, even if a later visualization step raises.
    """
    args = parse_args()

    # Initialize the model (ONNX version).
    model = ClipEBCOnnx(
        onnx_model_path=args.model
    )

    # Create the output directory only when results will be written to disk.
    if args.save:
        os.makedirs(args.output_dir, exist_ok=True)

    # Run prediction.
    count = model.predict(args.image)
    print(f"예측된 군중 수: {count:.2f}")

    # Visualization. Figures are collected and closed in `finally` so earlier
    # figures are released even if a later visualizer raises.
    figures = []
    try:
        if args.visualize in ('density', 'all'):
            save_path = os.path.join(args.output_dir, 'density_map.png') if args.save else None
            fig, _density_map = model.visualize_density_map(
                alpha=args.alpha,
                save=args.save,
                save_path=save_path
            )
            figures.append(fig)
        if args.visualize in ('dots', 'all'):
            save_path = os.path.join(args.output_dir, 'dot_map.png') if args.save else None
            canvas, _dot_map = model.visualize_dots(
                dot_size=args.dot_size,
                sigma=args.sigma,
                percentile=args.percentile,
                save=args.save,
                save_path=save_path
            )
            # NOTE(review): assumes `canvas` exposes `.figure` (e.g. a
            # matplotlib FigureCanvas) — confirm against ClipEBCOnnx.
            figures.append(canvas.figure)
    finally:
        if figures:
            # Imported lazily so matplotlib.pyplot is only loaded here when a
            # visualization actually produced a figure.
            import matplotlib.pyplot as plt
            for figure in figures:
                plt.close(figure)
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()