File size: 2,574 Bytes
4c954ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import logging

import numpy as np
import torch 


try:
    import matplotlib
    import matplotlib.animation
    import matplotlib.collections
    import matplotlib.patches
except ImportError:
    matplotlib = None


LOG = logging.getLogger(__name__)


class HAWPainter:
    """Draw wireframe predictions (line segments and their endpoints) on an axes.

    The class attributes below are defaults that a subclass or instance may
    override:

    * ``line_width`` -- width of drawn segments; ``None`` resolves to 1.
    * ``marker_size`` -- size of endpoint markers; ``None`` resolves to
      ``max(1, int(line_width * 0.5))``.
    * ``confidence_threshold`` -- when a wireframe carries ``'lines_score'``,
      only segments whose score is strictly greater than this are drawn.
    """

    line_width = 2
    marker_size = 4

    confidence_threshold = 0.05

    def __init__(self):
        # Resolve ``None`` placeholders (e.g. set on a subclass) to concrete sizes.
        if self.line_width is None:
            self.line_width = 1

        if self.marker_size is None:
            self.marker_size = max(1, int(self.line_width * 0.5))

    def _line_segments(self, wireframe):
        """Return the segments of ``wireframe`` that should be drawn.

        Reads ``wireframe['lines_pred']`` and, if ``'lines_score'`` is present,
        keeps only rows whose score exceeds ``confidence_threshold``. Torch
        tensors are converted to NumPy arrays.

        NOTE(review): rows are assumed to be ``(x1, y1, x2, y2)`` -- inferred
        from the column indexing used by the drawing methods; confirm against
        the producer of ``lines_pred``.
        """
        line_segments = wireframe['lines_pred']
        if 'lines_score' in wireframe.keys():
            line_segments = line_segments[
                wireframe['lines_score'] > self.confidence_threshold]

        if isinstance(line_segments, torch.Tensor):
            # detach(): Tensor.numpy() refuses tensors that still require grad.
            line_segments = line_segments.detach().cpu().numpy()
        return line_segments

    def _draw_endpoints(self, ax, line_segments, vertex_color):
        """Plot both endpoints of every segment as dots of ``vertex_color``."""
        ax.plot(line_segments[:, 0], line_segments[:, 1], '.',
                color=vertex_color, markersize=self.marker_size)
        ax.plot(line_segments[:, 2], line_segments[:, 3], '.',
                color=vertex_color, markersize=self.marker_size)

    def draw_junctions(self, ax, wireframe, *,
                       edge_color=None, vertex_color=None):
        """Draw only the segment endpoints of ``wireframe`` on ``ax``.

        Args:
            ax: object exposing a matplotlib-style ``plot`` method.
            wireframe: mapping with ``'lines_pred'`` and optionally
                ``'lines_score'``; when ``None`` nothing is drawn.
            edge_color: accepted for signature parity with
                ``draw_wireframe``; not used because no edges are drawn.
            vertex_color: endpoint colour, ``'c'`` when ``None``.
        """
        if wireframe is None:
            return
        if vertex_color is None:
            vertex_color = 'c'

        line_segments = self._line_segments(wireframe)
        self._draw_endpoints(ax, line_segments, vertex_color)

    def draw_wireframe(self, ax, wireframe, *,
                       edge_color=None, vertex_color=None):
        """Draw every segment of ``wireframe`` plus dots at its endpoints.

        Args:
            ax: object exposing a matplotlib-style ``plot`` method.
            wireframe: mapping with ``'lines_pred'`` and optionally
                ``'lines_score'``; when ``None`` nothing is drawn.
            edge_color: segment colour, ``'b'`` when ``None``.
            vertex_color: endpoint colour, ``'c'`` when ``None``.
        """
        if wireframe is None:
            return
        if edge_color is None:
            edge_color = 'b'
        if vertex_color is None:
            vertex_color = 'c'

        line_segments = self._line_segments(wireframe)
        # x and y are given as two rows (start, end); matplotlib plots each
        # column as its own line, i.e. one line per segment.
        ax.plot([line_segments[:, 0], line_segments[:, 2]],
                [line_segments[:, 1], line_segments[:, 3]],
                '-', color=edge_color, linewidth=self.line_width)
        self._draw_endpoints(ax, line_segments, vertex_color)