Spaces:
Runtime error
Runtime error
import argparse | |
import os | |
from custom.clip_ebc_onnx import ClipEBCOnnx | |
def parse_args(argv=None):
    """Build the command-line interface and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what the script does when run from the command line.

    Returns:
        argparse.Namespace with the attributes ``image``, ``model``,
        ``visualize``, ``save``, ``output_dir``, ``alpha``, ``dot_size``,
        ``sigma`` and ``percentile``.
    """
    parser = argparse.ArgumentParser(description='CLIP-EBC Crowd Counting (ONNX)')
    parser.add_argument('--image', required=True, help='Path to input image')
    parser.add_argument('--model', default='assets/CLIP_EBC_nwpu_rmse_onnx.onnx',
                        help='Path to ONNX model')
    parser.add_argument('--visualize', choices=['density', 'dots', 'all', 'none'],
                        default='none', help='Visualization type')
    parser.add_argument('--save', action='store_true',
                        help='Save visualization results')
    parser.add_argument('--output-dir', default='results',
                        help='Directory to save results')

    # Visualization-related parameters.
    # Float-typed options use float literal defaults: argparse only runs
    # ``type`` on string defaults, so ``default=1`` would leave the attribute
    # an int whenever the flag is omitted (and a float when it is given).
    parser.add_argument('--alpha', type=float, default=0.5,
                        help='Alpha value for density map')
    parser.add_argument('--dot-size', type=int, default=20,
                        help='Dot size for dot visualization')
    parser.add_argument('--sigma', type=float, default=1.0,
                        help='Sigma value for Gaussian filter')
    parser.add_argument('--percentile', type=float, default=97.0,
                        help='Percentile threshold for dot visualization')
    return parser.parse_args(argv)
def main():
    """Run CLIP-EBC crowd counting on one image and optionally visualize it.

    Parses the CLI arguments, loads the ONNX model, prints the predicted
    count and, depending on ``--visualize``, renders a density map and/or a
    dot map. With ``--save``, visualizations are written to ``--output-dir``.
    """
    args = parse_args()

    # Initialize the model (ONNX version).
    model = ClipEBCOnnx(
        onnx_model_path=args.model
    )

    # Create the output directory.
    if args.save:
        os.makedirs(args.output_dir, exist_ok=True)

    # Run prediction. The message means "Predicted crowd count"; the literal
    # was previously mis-encoded (mojibake) and is restored to Korean here.
    count = model.predict(args.image)
    print(f"예측된 군중 수: {count:.2f}")

    # Visualization.
    if args.visualize == 'none':
        return

    # Imported once, and only when a visualization was requested.
    import matplotlib.pyplot as plt

    if args.visualize in ('density', 'all'):
        save_path = os.path.join(args.output_dir, 'density_map.png') if args.save else None
        fig, _density_map = model.visualize_density_map(
            alpha=args.alpha,
            save=args.save,
            save_path=save_path
        )
        # Close the figure as soon as it is no longer needed (avoids leaking memory).
        plt.close(fig)

    if args.visualize in ('dots', 'all'):
        save_path = os.path.join(args.output_dir, 'dot_map.png') if args.save else None
        canvas, _dot_map = model.visualize_dots(
            dot_size=args.dot_size,
            sigma=args.sigma,
            percentile=args.percentile,
            save=args.save,
            save_path=save_path
        )
        # NOTE(review): assumes `canvas` exposes the owning matplotlib Figure
        # via `.figure` (e.g. a FigureCanvas) — confirm against ClipEBCOnnx.
        plt.close(canvas.figure)
if "__main__" == __name__:
    # Script entry point: run the CLI only when executed directly, not on import.
    main()