"""Name-based lookup of activation-layer factories.

Each built-in factory is a zero-configuration callable returning a fresh
``torch.nn`` activation module.  ``get`` resolves an identifier to such a
factory and ``register_activation`` adds user-defined ones.
"""

import torch
from torch import nn


def linear():
    """Return an ``nn.Identity`` module (no activation)."""
    return nn.Identity()


def relu():
    """Return an ``nn.ReLU`` module."""
    return nn.ReLU()


def prelu():
    """Return an ``nn.PReLU`` module."""
    return nn.PReLU()


def leaky_relu():
    """Return an ``nn.LeakyReLU`` module."""
    return nn.LeakyReLU()


def sigmoid():
    """Return an ``nn.Sigmoid`` module."""
    return nn.Sigmoid()


def softmax(dim=None):
    """Return an ``nn.Softmax`` module; ``dim`` is forwarded unchanged."""
    return nn.Softmax(dim=dim)


def tanh():
    """Return an ``nn.Tanh`` module."""
    return nn.Tanh()


def gelu():
    """Return an ``nn.GELU`` module."""
    return nn.GELU()


# Explicit registry of the names ``get`` may resolve.  Looking names up in
# ``globals()`` directly would also resolve unrelated module-level names
# such as ``torch``, ``nn``, ``__name__`` or ``get`` itself.
_ACTIVATIONS = {
    "linear": linear,
    "relu": relu,
    "prelu": prelu,
    "leaky_relu": leaky_relu,
    "sigmoid": sigmoid,
    "softmax": softmax,
    "tanh": tanh,
    "gelu": gelu,
}


def register_activation(custom_act):
    """Register a custom activation factory under its ``__name__``.

    Args:
        custom_act: Callable with a ``__name__`` attribute.

    Returns:
        ``custom_act`` itself, so the function can be used as a decorator.

    Raises:
        ValueError: If ``custom_act.__name__`` or its lower-cased form is
            already a module-level name here (built-in activation,
            previously registered activation, import, or helper).
    """
    name = custom_act.__name__
    module_names = globals()
    if name in module_names or name.lower() in module_names:
        raise ValueError(
            f"Activation {name} already exists. Choose another name."
        )
    # Keep the module attribute so ``<module>.<name>`` access keeps working,
    # and record the factory in the registry consulted by ``get``.
    module_names.update({name: custom_act})
    _ACTIVATIONS[name] = custom_act
    return custom_act


def get(identifier):
    """Resolve ``identifier`` to an activation factory.

    Args:
        identifier: ``None``, a callable, or the name of a built-in or
            registered activation.

    Returns:
        ``None`` if ``identifier`` is ``None``; ``identifier`` itself if it
        is callable; otherwise the factory registered under that name.

    Raises:
        ValueError: If ``identifier`` is a string that names no known
            activation, or is of an unsupported type.
    """
    if identifier is None:
        return None
    if callable(identifier):
        return identifier
    if isinstance(identifier, str):
        factory = _ACTIVATIONS.get(identifier)
        if factory is not None:
            return factory
    raise ValueError(
        "Could not interpret activation identifier: " + str(identifier)
    )


if __name__ == "__main__":
    print(globals().keys())
    print(globals().get("tanh"))