Spaces:
Paused
Paused
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd
"""Contains FieldMask class."""
| from google.protobuf.descriptor import FieldDescriptor | |
class FieldMask(object):
    """Helper methods for the FieldMask message type.

    These methods read and write ``self.paths`` and call ``self.Clear()``, so
    the host class must provide a repeated string field ``paths`` and a
    ``Clear()`` method (presumably the generated google.protobuf.FieldMask
    message this class is mixed into -- confirm at the mixin site).
    """

    __slots__ = ()

    def ToJsonString(self):
        """Returns the proto3 JSON form: camelCase paths joined by commas.

        Raises:
          ValueError: If a path cannot be converted to camelCase.
        """
        return ','.join(_SnakeCaseToCamelCase(p) for p in self.paths)

    def FromJsonString(self, value):
        """Replaces this mask's paths with the ones parsed from proto3 JSON.

        Args:
          value: Comma-separated camelCase paths. An empty string leaves the
            mask empty.

        Raises:
          ValueError: If ``value`` is not a str, or a path contains '_'.
        """
        if not isinstance(value, str):
            raise ValueError('FieldMask JSON value not a string: {!r}'.format(value))
        self.Clear()
        if not value:
            return
        for json_path in value.split(','):
            self.paths.append(_CamelCaseToSnakeCase(json_path))

    def IsValidForDescriptor(self, message_descriptor):
        """Returns True iff every path resolves to a field of message_descriptor."""
        return all(_IsValidPath(message_descriptor, p) for p in self.paths)

    def AllFieldsFromDescriptor(self, message_descriptor):
        """Resets the mask to list every top-level field of message_descriptor."""
        self.Clear()
        descriptor_fields = message_descriptor.fields
        for fd in descriptor_fields:
            self.paths.append(fd.name)

    def CanonicalFormFromMask(self, mask):
        """Stores the canonical form of ``mask`` in this FieldMask.

        A path is dropped when a shorter path already covers it (for example,
        "foo.bar" is dropped if "foo" is present). The surviving paths are
        emitted in sorted order.

        Args:
          mask: The FieldMask to canonicalize.
        """
        _FieldMaskTree(mask).ToFieldMask(self)

    def Union(self, mask1, mask2):
        """Stores the union of mask1 and mask2 in this FieldMask.

        Raises:
          ValueError: If either argument is not a FieldMask message.
        """
        for operand in (mask1, mask2):
            _CheckFieldMaskMessage(operand)
        merged = _FieldMaskTree(mask1)
        merged.MergeFromFieldMask(mask2)
        merged.ToFieldMask(self)

    def Intersect(self, mask1, mask2):
        """Stores the intersection of mask1 and mask2 in this FieldMask.

        Suppose one mask selects a whole field and the other selects only part
        of it. In that case, the narrower selection is kept.

        Raises:
          ValueError: If either argument is not a FieldMask message.
        """
        for operand in (mask1, mask2):
            _CheckFieldMaskMessage(operand)
        lhs = _FieldMaskTree(mask1)
        common = _FieldMaskTree()
        for path in mask2.paths:
            lhs.IntersectPath(path, common)
        common.ToFieldMask(self)

    def MergeMessage(
            self, source, destination,
            replace_message_field=False, replace_repeated_field=False):
        """Copies the fields named by this mask from source into destination.

        Args:
          source: Message to read the selected fields from.
          destination: Message to merge the selected fields into.
          replace_message_field: If True, a selected message field replaces
            the destination's value. If False, it is merged into it.
          replace_repeated_field: If True, a selected repeated field replaces
            the destination's elements. If False, it appends to them.
        """
        _FieldMaskTree(self).MergeMessage(
            source, destination, replace_message_field, replace_repeated_field)
def _IsValidPath(message_descriptor, path):
    """Reports whether a dotted field path resolves within message_descriptor.

    Every component except the last must name a singular (non-repeated)
    message-typed field. The last component only has to exist in the message
    reached by the walk.

    Args:
      message_descriptor: Descriptor of the message the path starts from.
      path: Dotted field path such as "foo.bar.baz".

    Returns:
      True if the path resolves, otherwise False.
    """
    *parents, leaf = path.split('.')
    current = message_descriptor
    for component in parents:
        fd = current.fields_by_name.get(component)
        if fd is None:
            return False
        if fd.label == FieldDescriptor.LABEL_REPEATED:
            return False
        if fd.type != FieldDescriptor.TYPE_MESSAGE:
            return False
        current = fd.message_type
    return leaf in current.fields_by_name
def _CheckFieldMaskMessage(message):
    """Raises ValueError unless ``message`` is a google.protobuf.FieldMask.

    The message counts as a FieldMask when its descriptor's short name is
    'FieldMask' and it was declared in 'google/protobuf/field_mask.proto'.
    """
    descriptor = message.DESCRIPTOR
    is_field_mask = (
        descriptor.name == 'FieldMask'
        and descriptor.file.name == 'google/protobuf/field_mask.proto')
    if not is_field_mask:
        raise ValueError('Message {0} is not a FieldMask.'.format(
            descriptor.full_name))
def _SnakeCaseToCamelCase(path_name):
    """Converts a snake_case path name to camelCase for JSON output.

    Each '_' is removed and the character after it is upper-cased.

    Args:
      path_name: A path such as "foo_bar.baz_qux".

    Returns:
      The camelCase form, e.g. "fooBar.bazQux".

    Raises:
      ValueError: If the name contains an uppercase letter, if a '_' is
        followed by anything other than a lowercase letter, or if the name
        ends with '_'.
    """
    pieces = []
    capitalize_next = False
    for ch in path_name:
        # The uppercase check runs first for every character, including the
        # one right after '_'.
        if ch.isupper():
            raise ValueError(
                'Fail to print FieldMask to Json string: Path name '
                '{0} must not contain uppercase letters.'.format(path_name))
        if capitalize_next:
            if not ch.islower():
                raise ValueError(
                    'Fail to print FieldMask to Json string: The '
                    'character after a "_" must be a lowercase letter '
                    'in path name {0}.'.format(path_name))
            pieces.append(ch.upper())
            capitalize_next = False
            continue
        if ch == '_':
            capitalize_next = True
        else:
            pieces.append(ch)
    if capitalize_next:
        raise ValueError('Fail to print FieldMask to Json string: Trailing "_" '
                         'in path name {0}.'.format(path_name))
    return ''.join(pieces)
def _CamelCaseToSnakeCase(path_name):
    """Converts a camelCase path name from JSON input to snake_case.

    Every uppercase character becomes '_' followed by its lowercase form.

    Args:
      path_name: A path such as "fooBar.bazQux".

    Returns:
      The snake_case form, e.g. "foo_bar.baz_qux".

    Raises:
      ValueError: If the name already contains '_'.
    """
    if '_' in path_name:
        raise ValueError('Fail to parse FieldMask: Path name '
                         '{0} must not contain "_"s.'.format(path_name))
    return ''.join('_' + ch.lower() if ch.isupper() else ch
                   for ch in path_name)
class _FieldMaskTree(object):
    """A FieldMask stored as nested dicts keyed by field name.

    For example, the FieldMask "foo.bar,foo.baz,bar.baz" is stored as:

      [_root] -+- foo -+- bar
               |       |
               |       +- baz
               |
               +- bar --- baz

    An empty dict (a leaf) marks the end of one selected field path.
    """

    __slots__ = ('_root',)

    def __init__(self, field_mask=None):
        """Creates the tree, optionally seeded with field_mask's paths."""
        self._root = {}
        if field_mask:
            self.MergeFromFieldMask(field_mask)

    def MergeFromFieldMask(self, field_mask):
        """Adds every entry of field_mask.paths to the tree."""
        for path in field_mask.paths:
            self.AddPath(path)

    def AddPath(self, path):
        """Inserts a dotted field path into the tree.

        The tree is only changed when the new path selects something not
        already selected:
          * If a prefix of ``path`` (or ``path`` itself) is already a leaf,
            the path is already covered and nothing changes.
          * If ``path`` ends at an existing inner node, that node's children
            are dropped so it becomes a leaf.
          * Otherwise, the missing nodes are created.

        Args:
          path: Dotted field path such as "foo.bar".
        """
        current = self._root
        for name in path.split('.'):
            if name in current:
                if not current[name]:
                    # An existing leaf already selects this whole subtree.
                    return
            else:
                current[name] = {}
            current = current[name]
        # The new path selects everything below this node.
        current.clear()

    def ToFieldMask(self, field_mask):
        """Clears field_mask and refills it with this tree's paths."""
        field_mask.Clear()
        _AddFieldPaths(self._root, '', field_mask)

    def IntersectPath(self, path, intersection):
        """Adds to ``intersection`` the part of ``path`` this tree also selects.

        Args:
          path: Dotted field path to intersect with this tree.
          intersection: Output _FieldMaskTree that receives the overlap.
        """
        current = self._root
        for name in path.split('.'):
            if name not in current:
                return
            subtree = current[name]
            if not subtree:
                # This tree selects a prefix of ``path``, so all of ``path``
                # is in the overlap.
                intersection.AddPath(path)
                return
            current = subtree
        # ``path`` stops at an inner node: the overlap is that node's leaves.
        intersection.AddLeafNodes(path, current)

    def AddLeafNodes(self, prefix, node):
        """Adds ``prefix`` joined with every leaf path under ``node``."""
        if not node:
            self.AddPath(prefix)
        for name, child in node.items():
            self.AddLeafNodes('{0}.{1}'.format(prefix, name), child)

    def MergeMessage(
            self, source, destination,
            replace_message, replace_repeated):
        """Merges every field this tree selects from source into destination."""
        _MergeMessage(
            self._root, source, destination, replace_message, replace_repeated)
def _StrConvert(value):
    """Returns ``value`` as a text ``str`` suitable for use as a field name.

    Message methods such as ``ClearField`` take the field name as ``str``.

    Args:
      value: A field name as ``str``, or as UTF-8 encoded bytes/bytearray.

    Returns:
      ``value`` unchanged if it is already a ``str``. Bytes-like input is
      decoded as UTF-8. Any other object goes through ``str()``.
    """
    if isinstance(value, str):
        return value
    # Decode rather than encode: Python 3 APIs want text. Calling
    # ``.encode()`` on bytes would raise AttributeError.
    if isinstance(value, (bytes, bytearray)):
        return value.decode('utf-8')
    return str(value)
def _MergeMessage(
        node, source, destination, replace_message, replace_repeated):
    """Merges the fields selected by a FieldMask sub-tree into destination.

    Args:
      node: A _FieldMaskTree sub-tree: a dict mapping field names to child
        sub-trees. An empty dict selects the whole field.
      source: Message whose selected field values are read.
      destination: Message the selected values are merged into.
      replace_message: If True, a selected singular message field is cleared
        on destination before source's value (if set) is merged in.
      replace_repeated: If True, a selected repeated field is cleared on
        destination before source's elements are merged in. If False, the
        elements are appended.

    Raises:
      ValueError: If node names a field that source's descriptor lacks, or
        gives sub-paths for a field that is not a singular message field.
    """
    source_descriptor = source.DESCRIPTOR
    for name, child in node.items():
        # Use .get() rather than indexing so an unknown name reaches the
        # ValueError below instead of escaping as a bare KeyError.
        field = source_descriptor.fields_by_name.get(name)
        if field is None:
            raise ValueError('Error: Can\'t find field {0} in message {1}.'.format(
                name, source_descriptor.full_name))
        if child:
            # Sub-paths are only allowed for singular message fields.
            if (field.label == FieldDescriptor.LABEL_REPEATED or
                    field.cpp_type != FieldDescriptor.CPPTYPE_MESSAGE):
                raise ValueError('Error: Field {0} in message {1} is not a singular '
                                 'message field and cannot have sub-fields.'.format(
                                     name, source_descriptor.full_name))
            # Recurse only when source has the sub-message set. Otherwise this
            # sub-tree leaves destination untouched.
            if source.HasField(name):
                _MergeMessage(
                    child, getattr(source, name), getattr(destination, name),
                    replace_message, replace_repeated)
            continue
        if field.label == FieldDescriptor.LABEL_REPEATED:
            if replace_repeated:
                destination.ClearField(_StrConvert(name))
            repeated_source = getattr(source, name)
            repeated_destination = getattr(destination, name)
            repeated_destination.MergeFrom(repeated_source)
        else:
            if field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE:
                if replace_message:
                    destination.ClearField(_StrConvert(name))
                if source.HasField(name):
                    getattr(destination, name).MergeFrom(getattr(source, name))
            else:
                # NOTE(review): the value is copied even when source has not
                # set this field, so destination presumably receives the
                # field's default value. Confirm this is intended for fields
                # with explicit presence.
                setattr(destination, name, getattr(source, name))
def _AddFieldPaths(node, prefix, field_mask):
    """Appends every leaf path at or below ``node`` to ``field_mask.paths``.

    Children are visited in sorted key order, so paths are appended sorted.

    Args:
      node: A _FieldMaskTree sub-tree (nested dicts; an empty dict is a leaf).
      prefix: Dotted path leading to ``node``, or '' for the tree root.
      field_mask: Object whose ``paths.append`` receives each leaf path.
    """
    is_leaf = not node
    if is_leaf and prefix:
        field_mask.paths.append(prefix)
        return
    for key in sorted(node):
        child_path = '.'.join((prefix, key)) if prefix else key
        _AddFieldPaths(node[key], child_path, field_mask)