File size: 2,736 Bytes
bb3e610
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import argparse
import os
from custom.clip_ebc_onnx import ClipEBCOnnx

def parse_args(argv=None):
    """Build the CLI parser for ONNX CLIP-EBC crowd counting and parse arguments.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as the original
            zero-argument version did; passing a list allows programmatic use
            and testing.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description='CLIP-EBC Crowd Counting (ONNX)')
    parser.add_argument('--image', required=True, help='Path to input image')
    parser.add_argument('--model', default='assets/CLIP_EBC_nwpu_rmse_onnx.onnx', help='Path to ONNX model')
    parser.add_argument('--visualize', choices=['density', 'dots', 'all', 'none'],
                        default='none', help='Visualization type')
    parser.add_argument('--save', action='store_true',
                        help='Save visualization results')
    parser.add_argument('--output-dir', default='results',
                        help='Directory to save results')

    # Visualization-related parameters.
    # Float defaults are written as floats: argparse applies ``type`` only to
    # string defaults, so an int default would stay an int when not overridden.
    parser.add_argument('--alpha', type=float, default=0.5,
                        help='Alpha value for density map')
    parser.add_argument('--dot-size', type=int, default=20,
                        help='Dot size for dot visualization')
    parser.add_argument('--sigma', type=float, default=1.0,
                        help='Sigma value for Gaussian filter')
    parser.add_argument('--percentile', type=float, default=97.0,
                        help='Percentile threshold for dot visualization')

    return parser.parse_args(argv)

def main():
    """Count people in one image and optionally visualize the result.

    Steps:
    1. Parse the CLI arguments.
    2. Load the ONNX model.
    3. Print the predicted count.
    4. Draw the density map and/or dot map as selected by ``--visualize``.

    When ``--save`` is given, images are written under ``--output-dir``. Every
    matplotlib figure produced here is closed at the end.
    """
    args = parse_args()

    # Initialize the model (ONNX version).
    model = ClipEBCOnnx(
        onnx_model_path=args.model
    )

    # Create the output directory only when results will actually be written.
    if args.save:
        os.makedirs(args.output_dir, exist_ok=True)

    # Run inference.
    count = model.predict(args.image)
    # "Predicted crowd count" -- message restored from its mis-encoded form.
    print(f"예측된 군중 수: {count:.2f}")

    # Figures created below. They are closed together at the end so none leak.
    figures = []

    if args.visualize in ('density', 'all'):
        save_path = os.path.join(args.output_dir, 'density_map.png') if args.save else None
        fig, _density_map = model.visualize_density_map(
            alpha=args.alpha,
            save=args.save,
            save_path=save_path,
        )
        figures.append(fig)

    if args.visualize in ('dots', 'all'):
        save_path = os.path.join(args.output_dir, 'dot_map.png') if args.save else None
        canvas, _dot_map = model.visualize_dots(
            dot_size=args.dot_size,
            sigma=args.sigma,
            percentile=args.percentile,
            save=args.save,
            save_path=save_path,
        )
        # Assumes the returned canvas exposes ``.figure``, as the original
        # cleanup code did -- TODO confirm against ClipEBCOnnx.visualize_dots.
        figures.append(canvas.figure)

    # Cleanup lives outside the per-mode branches. It used to sit inside the
    # "dots" branch, so a density-only run never closed its figure.
    if figures:
        # Local import (as before) keeps matplotlib out of the module's
        # import-time dependencies.
        import matplotlib.pyplot as plt
        for figure in figures:
            plt.close(figure)

if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()