# File size: 9,493 Bytes
# 72f684c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251

import os
import re
from xml.sax.saxutils import escape

import matplotlib.colors as mcolors
import numpy as np
from bs4 import BeautifulSoup
from noise import pnoise1
from svgpathtools import (
    Path, Arc, CubicBezier, QuadraticBezier,
    svgstr2paths)

from starvector.data.util import rasterize_svg

class SVGTransforms:
    """Random geometric and colour augmentations applied to SVG markup.

    The configuration is a dict. Range-valued options (``rotate``,
    ``shift_re``, ``shift_im``, ``scale``, ``noise_std``, ``color_noise``)
    are mappings with ``'from'`` and ``'to'`` keys; a value is drawn
    uniformly from that range. An option that is missing (or falsy) disables
    the corresponding augmentation.

    Typical use is ``svg_out, image = SVGTransforms(config).augment(svg)``.
    The instance can be reused: the configuration is never overwritten by the
    sampled values.
    """

    def __init__(self, transformations):
        """Store the augmentation configuration.

        Args:
            transformations: dict of options. Besides the range-valued
                options described on the class, it may contain
                ``noise_type`` (``'gaussian'`` or ``'perlin'``),
                ``color_change`` (truthy to recolour fills from ``colors``),
                ``colors`` (list of colour strings) and ``p``.
        """
        self.transformations = transformations
        self.noise_std = self.transformations.get('noise_std', False)
        self.noise_type = self.transformations.get('noise_type', False)
        self.rotate = self.transformations.get('rotate', False)
        self.shift_re = self.transformations.get('shift_re', False)
        self.shift_im = self.transformations.get('shift_im', False)
        self.scale = self.transformations.get('scale', False)
        self.color_noise = self.transformations.get('color_noise', False)
        # NOTE(review): `p` is stored but not read anywhere in this class.
        self.p = self.transformations.get('p', 0.5)
        self.color_change = self.transformations.get('color_change', False)
        self.colors = self.transformations.get('colors', ['#ff0000', '#0000ff', '#000000'])

        # Per-sample values, refreshed by sample_transformations(). They are
        # kept apart from the configuration above so that sampling never
        # destroys the configured ranges.
        self.rotation_angle = 0.0
        self.shift_real = 0.0
        self.shift_imag = 0.0
        self.scale_factor = None
        self.color_noise_std = 0.0

    @staticmethod
    def _sample_range(bounds, default):
        """Draw uniformly from ``bounds['from']``..``bounds['to']``, or return ``default`` when ``bounds`` is falsy."""
        if not bounds:
            return default
        return np.random.uniform(bounds['from'], bounds['to'])

    @staticmethod
    def _escape_attr(value):
        """Return ``value`` as text that is safe inside a double-quoted XML attribute."""
        return escape(str(value), {'"': '&quot;'})

    @staticmethod
    def _parse_length(text):
        """Convert a unit-less or ``px`` length string to float; raises ValueError for anything else."""
        text = text.strip()
        if text.endswith('px'):
            text = text[:-2]
        return float(text)

    def sample_transformations(self):
        """Draw the random parameters used for the next augmented sample.

        Sets ``rotation_angle``, ``shift_real``, ``shift_imag``,
        ``scale_factor`` and ``color_noise_std`` from the configured ranges.
        Options that are not configured keep a neutral value (no shift, no
        scale factor).
        """
        if self.rotate:
            self.rotation_angle = np.random.uniform(self.rotate['from'], self.rotate['to'])

        # Each axis is sampled on its own so that configuring only one of
        # shift_re / shift_im works; the other axis stays at 0.
        self.shift_real = self._sample_range(self.shift_re, 0.0)
        self.shift_imag = self._sample_range(self.shift_im, 0.0)

        # The sampled factor goes into its own attribute: writing it back to
        # self.scale would replace the range dict and break the next call.
        self.scale_factor = self._sample_range(self.scale, None)

        if self.color_noise:
            self.color_noise_std = np.random.uniform(self.color_noise['from'], self.color_noise['to'])

    def paths2str(self, groupped_paths, svg_opening_tag='<svg xmlns="http://www.w3.org/2000/svg" version="1.1">'):
        """Serialise grouped paths back into an SVG string.

        Args:
            groupped_paths: dict mapping a group key to
                ``{'attrs': {...}, 'paths': [(path, attributes), ...]}``.
                Each ``path`` must provide ``d()``. Groups whose key contains
                ``'no_group'`` are written without a surrounding ``<g>``.
            svg_opening_tag: opening ``<svg ...>`` tag to start the document.

        Returns:
            The SVG markup as a string.
        """
        # These attributes are never copied onto the emitted <path>; `d` is
        # regenerated from the path object.
        # NOTE(review): cx/cy/rx/ry presumably come from shapes that were
        # converted to paths upstream -- confirm.
        keys_to_exclude = ('d', 'cx', 'cy', 'rx', 'ry')

        group_chunks = []
        for group, elements in groupped_paths.items():
            group_attributes = elements.get('attrs', {})
            paths_and_attributes = elements.get('paths', [])

            path_strings = []
            for path, attributes in paths_and_attributes:
                path_attr_str = ''.join(
                    f' {key}="{self._escape_attr(value)}"'
                    for key, value in attributes.items()
                    if key not in keys_to_exclude
                )
                path_strings.append(f'<path d="{path.d()}"{path_attr_str} />')
            path_str = "\n".join(path_strings)

            if 'no_group' in group:
                group_chunks.append(path_str)
            else:
                group_attr_str = ' '.join(
                    f'{key}="{self._escape_attr(value)}"'
                    for key, value in group_attributes.items()
                )
                group_chunks.append(f'<g {group_attr_str}>\n{path_str}\n</g>\n')

        all_groups_str = ''.join(group_chunks)
        return f'{svg_opening_tag}\n{all_groups_str}</svg>'

    def add_noise(self, seg):
        """Perturb the shape-defining parameters of one path segment in place.

        One complex noise sample is drawn per call and added to the control
        point(s) of Bezier segments or to the radius of an arc. Other segment
        types are returned untouched.

        Args:
            seg: a path segment.

        Returns:
            The same segment object, possibly modified.

        Raises:
            ValueError: if ``noise_type`` is set to an unsupported value.
        """
        noise_scale = np.random.uniform(self.noise_std['from'], self.noise_std['to'])
        if self.noise_type == 'perlin':
            noise_sample = complex(pnoise1(np.random.random(), octaves=2), pnoise1(np.random.random(), octaves=2))*noise_scale
        elif not self.noise_type or self.noise_type == 'gaussian':
            # An unset noise type falls back to gaussian instead of failing
            # later with an unbound local.
            noise_sample = np.random.normal(loc=0.0, scale=noise_scale) + \
                        1j * np.random.normal(loc=0.0, scale=noise_scale)
        else:
            raise ValueError(
                f"Unsupported noise_type {self.noise_type!r}; expected 'gaussian' or 'perlin'")

        if isinstance(seg, CubicBezier):
            seg.control1 = seg.control1 + noise_sample
            seg.control2 = seg.control2 + noise_sample
        elif isinstance(seg, QuadraticBezier):
            seg.control = seg.control + noise_sample
        elif isinstance(seg, Arc):
            seg.radius = seg.radius + noise_sample

        return seg

    def do_rotate(self, path, viewbox_width, viewbox_height):
        """Rotate ``path`` by the sampled angle around the point (width/2, height/2); no-op if rotation is disabled."""
        if not self.rotate:
            return path
        return path.rotated(self.rotation_angle, complex(viewbox_width/2, viewbox_height/2))

    def do_shift(self, path):
        """Translate ``path`` by the sampled (real, imaginary) offset; no-op if shifting is disabled."""
        if not (self.shift_re or self.shift_im):
            return path
        return path.translated(complex(self.shift_real, self.shift_imag))

    def do_scale(self, path):
        """Scale ``path`` by the sampled factor; no-op if scaling is disabled or no non-zero factor was sampled."""
        if not self.scale or not self.scale_factor:
            return path
        return path.scaled(self.scale_factor)

    def add_color_noise(self, source_color):
        """Return ``source_color`` with gaussian noise added to each RGB channel.

        Args:
            source_color: a ``#``-prefixed hex colour or a CSS4 colour name.
                Names that are not in matplotlib's CSS4 table are treated as
                white.

        Returns:
            The perturbed colour as a hex string, channels clipped to [0, 1].
        """
        # Convert color to RGB
        if source_color.startswith("#"):
            base_color = mcolors.hex2color(source_color)
        else:
            base_color = mcolors.hex2color(mcolors.CSS4_COLORS.get(source_color, '#FFFFFF'))

        # Add noise to each RGB component
        noise = np.random.normal(0, self.color_noise_std, 3)
        noisy_color = np.clip(np.array(base_color) + noise, 0, 1)

        # Convert the RGB color back to hex
        return mcolors.rgb2hex(noisy_color)

    def do_color_change(self, attr):
        """Recolour the ``fill`` entry of ``attr`` in place and return ``attr``.

        ``color_noise`` takes precedence over ``color_change``. A fill of
        ``'none'`` is never altered, and dicts without ``fill`` are returned
        unchanged.
        """
        if 'fill' not in attr or not (self.color_noise or self.color_change):
            return attr

        fill_value = attr['fill']
        if fill_value == 'none':
            return attr

        if self.color_noise:
            attr['fill'] = self.add_color_noise(fill_value)
        else:
            # np.random.choice yields a numpy string; store a plain str.
            attr['fill'] = str(np.random.choice(self.colors))
        return attr

    def clean_attributes(self, attr):
        """Normalise a path's attribute dict so the fill is available as ``fill``.

        If ``attr`` already has ``fill`` it is returned as is. Otherwise, if
        its ``style`` declares a fill, a new dict holding only that ``fill``
        is returned (every other attribute is dropped). In all remaining
        cases ``attr`` is returned unchanged.
        """
        if 'fill' in attr or 'style' not in attr:
            return attr

        fill_values = re.findall(r'fill:[^;]+', attr['style'])
        if not fill_values:
            return attr
        return {'fill': fill_values[0].replace('fill:', '').strip()}

    def get_viewbox_size(self, svg):
        """Return ``(width, height)`` of the drawing described by ``svg``.

        The size is taken from the first ``viewBox="..."`` attribute (values
        separated by whitespace and/or commas). If that is missing or
        malformed, ``width``/``height`` attributes given as plain numbers or
        in ``px`` are used. If neither yields a size, ``(256.0, 256.0)`` is
        returned.
        """
        match = re.search(r'viewBox="([^"]+)"', svg)
        if match:
            try:
                numbers = [float(x) for x in re.split(r'[\s,]+', match.group(1).strip())]
            except ValueError:
                numbers = []
            if len(numbers) >= 4:
                return numbers[2], numbers[3]

        # The lookbehind keeps attributes such as stroke-width from matching.
        width = re.search(r'(?<![\w-])width="([^"]+)"', svg)
        height = re.search(r'(?<![\w-])height="([^"]+)"', svg)
        if width and height:
            try:
                return self._parse_length(width.group(1)), self._parse_length(height.group(1))
            except ValueError:
                # Percentages and other units cannot be resolved here.
                pass

        return 256.0, 256.0  # Default if neither viewBox nor width/height are usable

    def augment(self, svg):
        """Apply one freshly sampled set of augmentations to an SVG.

        Args:
            svg: SVG markup, or the path of a file containing it.

        Returns:
            ``(svg_string, image)`` where ``image`` is the result of
            ``rasterize_svg`` on the returned markup. If any top-level element
            cannot be converted to paths, the original markup and its
            rasterisation are returned instead.

        Raises:
            ValueError: if the input contains no ``<svg ...>`` opening tag.
        """
        if os.path.isfile(svg):
            # open svg file
            with open(svg, 'r', encoding='utf-8') as f:
                svg = f.read()

        # Sample transformations for this sample
        self.sample_transformations()

        # Parse the SVG content
        soup = BeautifulSoup(svg, 'xml')

        # Get opening tag (attributes are optional)
        opening = re.search(r'<svg\b[^>]*>', svg)
        if opening is None:
            raise ValueError("augment() expects SVG markup containing an <svg> opening tag")
        svg_opening_tag = opening.group(0)

        viewbox_width, viewbox_height = self.get_viewbox_size(svg)

        # Only the root's direct children are visited. Each one is handed to
        # svgstr2paths together with its whole subtree, so walking every
        # descendant as well would emit nested paths more than once.
        root = soup.find('svg')
        top_level = root.find_all(recursive=False) if root is not None else []

        # Create the groups of paths based on their original <g> tag
        grouped_paths = {}
        for i, element in enumerate(top_level):
            if element.name in ('svg', 'metadata', 'defs'):
                continue

            # Keys are index-based so that groups sharing an id (or an id
            # containing 'no_group') cannot collide or be misclassified.
            if element.name == 'g':
                group_id = f'group_{i}'
                group_attrs = element.attrs
            else:
                group_id = f'no_group_{i}'
                group_attrs = {}

            group_svg_string = f'{svg_opening_tag}{str(element)}</svg>'
            try:
                paths, attributes = svgstr2paths(group_svg_string)
            except Exception:
                # Best effort: give back the untouched input when parsing fails.
                return svg, rasterize_svg(svg)
            if not paths:
                continue

            paths_and_attributes = []
            for path, attribute in zip(paths, attributes):
                attr = self.clean_attributes(attribute)

                # Rotation, shift, scale
                new_path = self.do_rotate(path, viewbox_width, viewbox_height)
                new_path = self.do_shift(new_path)
                new_path = self.do_scale(new_path)

                if self.noise_std:
                    # Add noise to path to deform svg
                    new_path = Path(*[self.add_noise(seg) for seg in new_path])

                # Color change
                attr = self.do_color_change(attr)
                paths_and_attributes.append((new_path, attr))

            grouped_paths[group_id] = {
                'paths': paths_and_attributes,
                'attrs': group_attrs,
            }

        svg = self.paths2str(grouped_paths, svg_opening_tag)
        image = rasterize_svg(svg)

        return svg, image