File size: 4,691 Bytes
4c954ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
from contextlib import contextmanager
import logging
import os

from matplotlib.pyplot import figimage, margins
import numpy as np
import cv2 

try:
    import matplotlib.pyplot as plt  # pylint: disable=import-error
    
except ModuleNotFoundError as err:
    if err.name != 'matplotlib':
        raise err
    plt = None


LOG = logging.getLogger(__name__)

class Canvas:
    """Canvas for plotting.

    All methods expose Axes objects. To get Figure objects, you can ask the axis
    `ax.get_figure()`.

    Configuration lives in class attributes and is shared by every use of the
    class (see the comments on each attribute below).
    """

    # When set, figures without an explicit ``fig_file`` are written into this
    # directory as 0001.<ext>, 0002.<ext>, ... (see ``generic_name``).
    all_images_directory = None
    # Counter used to number auto-generated file names.
    all_images_count = 0
    # Call ``plt.show()`` after a figure is finished.
    show = False
    # Figure size (inches) for ``image``; width takes precedence over height.
    image_width = 7.0
    image_height = None
    # Default dpi for ``blank``.
    blank_dpi = 200
    # Default dpi for ``image``.
    image_dpi = 200
    # NOTE(review): not read by any method here. They parameterized a
    # width-based dpi formula that was disabled in favour of a fixed dpi
    # (now ``image_dpi``).
    image_dpi_factor = 1.0
    image_min_dpi = 50.0
    # Extension for auto-generated file names.
    out_file_extension = 'pdf'
    # Falsy disables; otherwise used as the alpha of a white overlay in ``image``.
    white_overlay = False

    @classmethod
    def generic_name(cls):
        """Return the next auto-numbered output path, or None.

        Returns None when ``all_images_directory`` is unset. Otherwise creates
        the directory if needed, increments ``all_images_count`` and returns
        ``<dir>/<count:04>.<out_file_extension>``.
        """
        if cls.all_images_directory is None:
            return None
        os.makedirs(cls.all_images_directory, exist_ok=True)

        cls.all_images_count += 1
        return os.path.join(cls.all_images_directory,
                            '{:04}.{}'.format(cls.all_images_count, cls.out_file_extension))

    @staticmethod
    def _nomargin_gridspec(gridspec_kw=None):
        """Return a copy of ``gridspec_kw`` with all spacing and outer margins zeroed.

        The caller's dict is never modified; the zeroing keys override any
        values the caller supplied for them.
        """
        return {
            **(gridspec_kw or {}),
            'wspace': 0,
            'hspace': 0,
            'left': 0.0,
            'right': 1.0,
            'top': 1.0,
            'bottom': 0.0,
        }

    @staticmethod
    def _normalize_margin(margin):
        """Return ``margin`` as a 4-element list ``[left, bottom, right, top]``.

        ``None`` means no margin and a single number applies to all four sides.
        Values are fractions of the figure size.

        Raises:
            ValueError: if a sequence without exactly four entries is given.
        """
        if margin is None:
            return [0.0, 0.0, 0.0, 0.0]
        if isinstance(margin, (int, float)):
            return [float(margin)] * 4
        margin = list(margin)
        if len(margin) != 4:
            raise ValueError(
                'margin must be a number or a sequence of 4 values, got {!r}'.format(margin))
        return margin

    @staticmethod
    def _axes_rect(margin):
        """Convert ``[left, bottom, right, top]`` margins to an Axes rectangle.

        Matplotlib expects ``[left, bottom, width, height]`` in figure
        fractions, so the width and height are what remain after subtracting
        both opposite margins.
        """
        left, bottom, right, top = margin
        return [left, bottom, 1.0 - left - right, 1.0 - bottom - top]

    @classmethod
    def _image_figsize(cls, image_shape, margin):
        """Figure size (inches) that preserves the image aspect ratio.

        The ratio is applied to the drawable area left after margins. Uses
        ``image_width`` if set, else ``image_height`` if truthy, else returns
        None.
        """
        image_ratio = image_shape[0] / image_shape[1]
        image_area_ratio = (1.0 - margin[1] - margin[3]) / (1.0 - margin[0] - margin[2])
        if cls.image_width is not None:
            return (
                cls.image_width,
                cls.image_width * image_ratio / image_area_ratio,
            )
        if cls.image_height:
            return (
                cls.image_height * image_area_ratio / image_ratio,
                cls.image_height,
            )
        return None

    @classmethod
    @contextmanager
    def blank(cls, fig_file=None, *, dpi=None, nomargin=False, **kwargs):
        """Context manager yielding the axis (or axes) created by ``plt.subplots``.

        On normal exit the figure is saved to ``fig_file`` (if truthy) and
        shown when ``cls.show`` is set. The figure is always closed, including
        when the ``with`` body raises; in that case nothing is saved.

        Args:
            fig_file: output path; defaults to :meth:`generic_name`.
            dpi: figure dpi; defaults to ``cls.blank_dpi``.
            nomargin: remove all subplot spacing and outer margins and do not
                apply tight layout.
            **kwargs: forwarded to ``plt.subplots``; ``figsize`` defaults to
                ``(10, 6)``.

        Raises:
            Exception: when matplotlib is not installed.
        """
        if plt is None:
            raise Exception('please install matplotlib')
        if fig_file is None:
            fig_file = cls.generic_name()

        if dpi is None:
            dpi = cls.blank_dpi

        kwargs.setdefault('figsize', (10, 6))
        if nomargin:
            kwargs['gridspec_kw'] = cls._nomargin_gridspec(kwargs.get('gridspec_kw'))

        fig, ax = plt.subplots(dpi=dpi, **kwargs)
        try:
            yield ax

            # Tight layout would re-introduce padding, so only use it when
            # margins are wanted.
            fig.set_tight_layout(not nomargin)
            if fig_file:
                LOG.debug('writing image to %s', fig_file)
                fig.savefig(fig_file)

            if cls.show:
                plt.show()
        finally:
            # Always release the figure so pyplot does not accumulate open figures.
            plt.close(fig)

    @classmethod
    @contextmanager
    def image(cls, image, fig_file=None, *, margin=None, dpi=None, **kwargs):
        """Context manager yielding an axis that displays ``image``.

        The axis has its decorations turned off and its limits set to the
        image extent.

        Args:
            image: a file path (str or os.PathLike), read with ``cv2.imread``
                and channel-reversed from OpenCV's BGR to RGB. Otherwise,
                anything ``np.asarray`` turns into an array ``imshow`` accepts.
            fig_file: output path; defaults to :meth:`generic_name`. Nothing
                is written when it ends up falsy.
            margin: figure-fraction margins. Accepts ``None`` (no margins), a
                single number for all sides, or ``[left, bottom, right, top]``.
            dpi: figure dpi; defaults to ``cls.image_dpi``.
            **kwargs: forwarded to ``plt.figure``. When ``figsize`` is absent,
                it is derived from the image aspect ratio and
                ``image_width`` / ``image_height``.

        Raises:
            Exception: when matplotlib is not installed.
            OSError: when ``image`` is a path that OpenCV cannot read.
            ValueError: when ``margin`` is a sequence without four entries.
        """
        if plt is None:
            raise Exception('please install matplotlib')
        if fig_file is None:
            fig_file = cls.generic_name()

        if isinstance(image, (str, os.PathLike)):
            image_path = os.fspath(image)
            bgr = cv2.imread(image_path)
            if bgr is None:  # cv2.imread signals failure by returning None
                raise OSError('could not read image: {}'.format(image_path))
            image = bgr[..., ::-1]  # OpenCV loads BGR; matplotlib expects RGB
        else:
            image = np.asarray(image)

        margin = cls._normalize_margin(margin)

        if 'figsize' not in kwargs:
            figsize = cls._image_figsize(image.shape, margin)
            if figsize is not None:
                kwargs['figsize'] = figsize

        if dpi is None:
            dpi = cls.image_dpi
        fig = plt.figure(dpi=dpi, **kwargs)
        try:
            ax = plt.Axes(fig, cls._axes_rect(margin))
            ax.set_axis_off()
            ax.set_xlim(-0.5, image.shape[1] - 0.5)  # imshow uses center-pixel-coordinates
            ax.set_ylim(image.shape[0] - 0.5, -0.5)
            fig.add_axes(ax)
            ax.imshow(image)
            if cls.white_overlay:
                white_screen(ax, cls.white_overlay)
            yield ax

            if fig_file:
                LOG.debug('writing image to %s', fig_file)
                fig.savefig(fig_file)
            if cls.show:
                plt.show()
        finally:
            # Always release the figure so pyplot does not accumulate open figures.
            plt.close(fig)

def white_screen(ax, alpha=0.9):
    """Draw a white rectangle with the given ``alpha`` over the whole of ``ax``.

    The rectangle spans (0, 0) to (1, 1) in axes coordinates
    (``ax.transAxes``), so it covers the axes area regardless of data limits.
    """
    overlay = plt.Rectangle(
        (0, 0), 1, 1,
        transform=ax.transAxes,
        alpha=alpha,
        facecolor='white',
    )
    ax.add_patch(overlay)

# Module-level shortcuts so callers can write ``with canvas() as ax:`` or
# ``with image_canvas(img) as ax:`` without naming the class.
canvas, image_canvas = Canvas.blank, Canvas.image