# test_ebc / main.py
# piaspace's picture
# [first]
# bb3e610
import argparse
import os
from custom.clip_ebc_onnx import ClipEBCOnnx
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the CLIP-EBC ONNX crowd-counting script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) lets argparse read
            ``sys.argv[1:]``, so existing zero-argument calls are unaffected.

    Returns:
        A namespace with the attributes ``image``, ``model``, ``visualize``,
        ``save``, ``output_dir``, ``alpha``, ``dot_size``, ``sigma`` and
        ``percentile``.

    Raises:
        SystemExit: Raised by argparse when ``--image`` is missing or an
            option value cannot be parsed.
    """
    parser = argparse.ArgumentParser(description='CLIP-EBC Crowd Counting (ONNX)')

    parser.add_argument('--image', required=True, help='Path to input image')
    parser.add_argument('--model', default='assets/CLIP_EBC_nwpu_rmse_onnx.onnx',
                        help='Path to ONNX model')
    parser.add_argument('--visualize', choices=['density', 'dots', 'all', 'none'],
                        default='none', help='Visualization type')
    parser.add_argument('--save', action='store_true',
                        help='Save visualization results')
    parser.add_argument('--output-dir', default='results',
                        help='Directory to save results')

    # Visualization-related parameters.
    parser.add_argument('--alpha', type=float, default=0.5,
                        help='Alpha value for density map')
    parser.add_argument('--dot-size', type=int, default=20,
                        help='Dot size for dot visualization')
    parser.add_argument('--sigma', type=float, default=1,
                        help='Sigma value for Gaussian filter')
    parser.add_argument('--percentile', type=float, default=97,
                        help='Percentile threshold for dot visualization')

    # Forwarding argv makes the parser usable from tests and other callers.
    return parser.parse_args(argv)
def main():
    """Run one crowd-count prediction and optionally render visualizations.

    Reads the options via ``parse_args()``, builds a ``ClipEBCOnnx`` model
    from ``--model``, predicts on ``--image`` and prints the count. Depending
    on ``--visualize`` it then calls ``visualize_density_map`` and/or
    ``visualize_dots`` on the model, passing a file path under
    ``--output-dir`` when ``--save`` is set, and finally closes the figures
    those calls returned.
    """
    args = parse_args()

    # Model initialization - ONNX version.
    model = ClipEBCOnnx(onnx_model_path=args.model)

    # Create the output directory.
    if args.save:
        os.makedirs(args.output_dir, exist_ok=True)

    # Run the prediction.
    count = model.predict(args.image)
    # NOTE(review): message text kept verbatim; it looks mis-encoded
    # (possibly Korean read with the wrong codec) - confirm the file encoding.
    print(f"์˜ˆ์ธก๋œ ๊ตฐ์ค‘ ์ˆ˜: {count:.2f}")

    # Visualization.
    want_density = args.visualize in ('density', 'all')
    want_dots = args.visualize in ('dots', 'all')
    if not (want_density or want_dots):
        return

    # Imported lazily so the prediction-only path does not import pyplot here.
    import matplotlib.pyplot as plt

    figures = []
    try:
        if want_density:
            save_path = os.path.join(args.output_dir, 'density_map.png') if args.save else None
            fig, _ = model.visualize_density_map(
                alpha=args.alpha,
                save=args.save,
                save_path=save_path,
            )
            figures.append(fig)

        if want_dots:
            save_path = os.path.join(args.output_dir, 'dot_map.png') if args.save else None
            canvas, _ = model.visualize_dots(
                dot_size=args.dot_size,
                sigma=args.sigma,
                percentile=args.percentile,
                save=args.save,
                save_path=save_path,
            )
            # presumably a matplotlib canvas exposing its figure - verify
            # against ClipEBCOnnx.visualize_dots.
            figures.append(canvas.figure)
    finally:
        # Close matplotlib figures (prevents memory leaks), including when
        # a later visualization step raised.
        for figure in figures:
            plt.close(figure)
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == '__main__':
    main()