Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	File size: 2,487 Bytes
			
			| e21f690 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 | import inspect
import platform
from .registry import PLUGIN_LAYERS
if platform.system() == 'Windows':
    import regex as re
else:
    import re
def infer_abbr(class_type):
    """Infer abbreviation from the class name.

    This method will infer the abbreviation to map class types to
    abbreviations.

    Rule 1: If the class has the attribute ``_abbr_``, return it.
    Rule 2: Otherwise, the abbreviation falls back to snake case of class
    name, e.g. the abbreviation of ``FancyBlock`` will be ``fancy_block``.

    Args:
        class_type (type): The plugin layer type.

    Returns:
        str: The inferred abbreviation.

    Raises:
        TypeError: If ``class_type`` is not a class.
    """

    def camel2snake(word):
        """Convert camel case word into snake case.

        Modified from `inflection lib
        <https://inflection.readthedocs.io/en/latest/#inflection.underscore>`_.

        Example::

            >>> camel2snake("FancyBlock")
            'fancy_block'
        """
        # Separate a run of capitals from a following capitalised word,
        # e.g. "HTTPServer" -> "HTTP_Server".
        word = re.sub(r'([A-Z]+)([A-Z][a-z])', r'\1_\2', word)
        # Separate a lower-case letter or digit from a following capital,
        # e.g. "FancyBlock" -> "Fancy_Block".
        word = re.sub(r'([a-z\d])([A-Z])', r'\1_\2', word)
        word = word.replace('-', '_')
        return word.lower()

    if not inspect.isclass(class_type):
        raise TypeError(
            f'class_type must be a type, but got {type(class_type)}')
    # An explicit ``_abbr_`` attribute overrides the name-derived default.
    if hasattr(class_type, '_abbr_'):
        return class_type._abbr_
    return camel2snake(class_type.__name__)
def build_plugin_layer(cfg, postfix='', **kwargs):
    """Build plugin layer.

    Args:
        cfg (dict): cfg should contain:

            - type (str): identify plugin layer type.
            - layer args: args needed to instantiate a plugin layer.
        postfix (int, str): appended to the plugin abbreviation to
            create named layer. Default: ''.
        **kwargs: Extra keyword arguments passed to the plugin layer
            constructor alongside the remaining ``cfg`` entries.

    Returns:
        tuple[str, nn.Module]:
            name (str): abbreviation + postfix
            layer (nn.Module): created plugin layer

    Raises:
        TypeError: If ``cfg`` is not a dict, if ``postfix`` is neither an
            int nor a str, or if a key appears both in ``kwargs`` and in
            ``cfg`` (duplicate keyword argument to the constructor).
        KeyError: If ``cfg`` has no ``"type"`` key or the type is not
            registered in ``PLUGIN_LAYERS``.
    """
    if not isinstance(cfg, dict):
        raise TypeError('cfg must be a dict')
    if 'type' not in cfg:
        raise KeyError('the cfg dict must contain the key "type"')
    # Work on a copy so popping 'type' does not mutate the caller's config.
    cfg_ = cfg.copy()
    layer_type = cfg_.pop('type')
    if layer_type not in PLUGIN_LAYERS:
        raise KeyError(f'Unrecognized plugin type {layer_type}')
    plugin_layer = PLUGIN_LAYERS.get(layer_type)
    abbr = infer_abbr(plugin_layer)
    # Explicit check instead of ``assert``, which is stripped under ``-O``.
    if not isinstance(postfix, (int, str)):
        raise TypeError(
            f'postfix must be an int or str, but got {type(postfix)}')
    name = abbr + str(postfix)
    layer = plugin_layer(**kwargs, **cfg_)
    return name, layer
 | 
