Source code for pugh_torch.modules.activation

""" Easy interface for swapping out activation functions, especially those
that may have different weight initialization methods.

    * weights  <- initialization depends on activation function <----
    * normalization                                                 |
    * activation <---------------------------------------------------

To create a new activation, do the following:
    * Inherit from ActivationModule to register
    * [optional] implement ``init_layer`` method
    * [optional] implement ``init_first_layer`` method
Once this is done, your activation function will be available as:
    Activation("myactivationlowercase", **kwargs)

"""

import torch
from torch import nn
import numpy as np
from packaging import version
from math import sqrt
from . import init as wi

# Parsed version of the installed torch. It gates the registration of
# activations below that are only defined for newer torch releases
# (Hardsigmoid, Hardswish).
_torch_version = version.parse(torch.__version__)

# Registry mapping lowercased activation class names to their classes.
# Filled in automatically by ``ActivationModule.__init_subclass__`` and read
# by the ``Activation`` factory.
_activation_lookup = dict()


class ActivationModule(nn.Module):
    """Base class for activations; subclasses register themselves by name.

    Defining a subclass stores it in ``_activation_lookup`` under its
    lowercased class name, which is the key ``Activation(name, ...)`` looks
    up. Subclasses may override ``init_layer`` / ``init_first_layer`` to
    provide activation-specific weight initialization.
    """

    def __init_subclass__(cls, **kwargs):
        """Register the newly defined subclass in ``_activation_lookup``.

        Raises
        ------
        ValueError
            If an activation with the same lowercased name is already
            registered.
        """
        super().__init_subclass__(**kwargs)
        registry_key = cls.__name__.lower()
        if registry_key not in _activation_lookup:
            _activation_lookup[registry_key] = cls
            return
        raise ValueError(
            f'Activation function "{registry_key}" already exists: {_activation_lookup[registry_key]}'
        )

    @torch.no_grad()
    def init_layer(self, m):
        """Initialize the parameters of module ``m`` for this activation.

        The base implementation does nothing; override this in child
        activation functions.
        """

    @torch.no_grad()
    def init_first_layer(self, m):
        """Initialize module ``m`` when it is the first layer of a network.

        Falls back to ``init_layer``; override this in child activation
        functions that need a dedicated first-layer scheme.
        """
        return self.init_layer(m)
def Activation(name, init_layers=None, *, first=False, **kwargs):
    """Activation Factory Function

    Parameters
    ----------
    name : str
        Activation function type. Case-insensitive; must match the lowercased
        class name of a registered ``ActivationModule`` subclass.
    init_layers : nn.Module or list/tuple of nn.Module, optional
        Layers whose weights need initialization based on the chosen
        activation. The activation's ``init_layer`` (or ``init_first_layer``
        when ``first=True``) is applied to each layer and, through
        ``nn.Module.apply``, recursively to all of its submodules.
    first : bool, optional
        If ``True``, use ``init_first_layer`` instead of ``init_layer``.
    kwargs : dict
        Passed along to activation function constructor.

    Returns
    -------
    ActivationModule
        The constructed activation module.

    Raises
    ------
    TypeError
        If ``init_layers`` is not an ``nn.Module`` or a list/tuple of
        ``nn.Module`` instances.
    KeyError
        If no activation named ``name`` is registered.
    """
    key = name.lower()

    # Normalize ``init_layers`` to a sequence of modules. Explicit exceptions
    # replace the previous ``assert``, which is stripped under ``python -O``.
    if init_layers is None:
        layers = []
    elif isinstance(init_layers, nn.Module):
        layers = [init_layers]
    elif isinstance(init_layers, (list, tuple)):
        layers = list(init_layers)
        not_modules = [layer for layer in layers if not isinstance(layer, nn.Module)]
        if not_modules:
            raise TypeError(
                f"init_layers must only contain nn.Module instances, got: {not_modules}"
            )
    else:
        raise TypeError(
            "init_layers must be an nn.Module or a list/tuple of nn.Module, "
            f"got {type(init_layers).__name__}"
        )

    try:
        activation_cls = _activation_lookup[key]
    except KeyError:
        # Still a KeyError for existing callers, but with an actionable message.
        raise KeyError(
            f'Unknown activation "{name}"; available: {sorted(_activation_lookup)}'
        ) from None

    activation_obj = activation_cls(**kwargs)

    init_fn = activation_obj.init_first_layer if first else activation_obj.init_layer
    for layer in layers:
        # ``apply`` visits every submodule, so init functions must tolerate
        # modules without (suitable) weights.
        layer.apply(init_fn)

    return activation_obj
class Noop(ActivationModule):
    """Identity activation, registered as ``"noop"``.

    ``forward`` hands back its input object unchanged. Weight initialization
    is inherited from ``ActivationModule`` (a no-op).
    """

    def forward(self, input):
        """Return ``input`` itself; no copy is made."""
        output = input
        return output
class Sine(ActivationModule):
    """Periodic sine activation, ``sin(frequency * input)``.

    Implicit Neural Representations with Periodic Activation Functions
    https://arxiv.org/pdf/2006.09661.pdf

    Parameters
    ----------
    frequency : float
        Factor the input is multiplied by before ``torch.sin``. It also scales
        the weight bound and bias in ``init_layer`` and the bias in
        ``init_first_layer``.
    """

    def __init__(self, frequency=30):
        super().__init__()
        # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30
        self.frequency = frequency

    def forward(self, input):
        return torch.sin(self.frequency * input)

    @torch.no_grad()
    def init_layer(self, m):
        """Initialize a hidden layer that feeds into this activation.

        Weights are drawn from ``U(-b, b)`` with
        ``b = sqrt(6 / fan_in) / frequency``. An existing bias is divided in
        place by ``frequency / 2``.

        ``fan_in`` is the number of inputs per output unit. That is
        ``in_features`` for ``nn.Linear`` and
        ``in_channels / groups * prod(kernel_size)`` for convolutions; the old
        ``weight.size(-1)`` only gave the kernel width for convs. This method
        is applied to every submodule via ``nn.Module.apply``. Modules without
        a weight, with ``weight=None``, or with a 1-D weight (e.g.
        normalization layers) are therefore left unchanged.
        """
        weight = getattr(m, "weight", None)
        if weight is None or weight.dim() < 2:
            return
        fan_in = weight.shape[1:].numel()
        bound = sqrt(6 / fan_in) / self.frequency
        weight.uniform_(-bound, bound)
        bias = getattr(m, "bias", None)
        if bias is not None:
            bias /= self.frequency / 2

    @torch.no_grad()
    def init_first_layer(self, m):
        """Initialize the first layer of a sine network.

        Weights are drawn from ``N(0, (1 / fan_in)^2)``. An existing bias is
        divided in place by ``frequency / 2``. ``fan_in`` and the skipping of
        modules without a >= 2-D weight follow ``init_layer``.
        """
        weight = getattr(m, "weight", None)
        if weight is None or weight.dim() < 2:
            return
        fan_in = weight.shape[1:].numel()
        # NOTE(review): the SIREN reference code appears to draw first-layer
        # weights from U(-1/fan_in, 1/fan_in); a normal distribution with
        # std 1/fan_in is used here. Kept as-is -- confirm it is intentional.
        weight.normal_(std=1 / fan_in)
        bias = getattr(m, "bias", None)
        if bias is not None:
            bias /= self.frequency / 2
################################# # torch.nn activation functions # #################################
class ELU(nn.ELU, ActivationModule):
    """``torch.nn.ELU`` registered as the ``"elu"`` activation."""

    @torch.no_grad()
    def init_layer(self, m):
        """Normal init with ``std = sqrt(1.5505188080679277) / sqrt(fan_in)``.

        ``fan_in`` is the number of inputs per output unit (``in_features``
        for ``nn.Linear``, ``in_channels / groups * prod(kernel_size)`` for
        convolutions). Modules without a weight, with ``weight=None``, or with
        a 1-D weight (e.g. normalization layers) are left unchanged.
        """
        weight = getattr(m, "weight", None)
        if weight is None or weight.dim() < 2:
            return
        fan_in = weight.shape[1:].numel()
        # ``sqrt`` comes from the module-level ``from math import sqrt``; the
        # previous ``math.sqrt`` raised NameError because ``math`` is not
        # imported.
        nn.init.normal_(weight, std=sqrt(1.5505188080679277) / sqrt(fan_in))
class Hardshrink(nn.Hardshrink, ActivationModule):
    """``torch.nn.Hardshrink`` registered as the ``"hardshrink"`` activation."""

    @torch.no_grad()
    def init_layer(self, m):
        """Initialize ``m`` with the project's ``wi.xavier`` helper."""
        wi.xavier(m)
if _torch_version >= version.parse("1.5.0"):
    # Only registered on torch >= 1.5.0 (presumably the first release that
    # provides ``nn.Hardsigmoid``).

    class Hardsigmoid(nn.Hardsigmoid, ActivationModule):
        """``torch.nn.Hardsigmoid`` registered as the ``"hardsigmoid"`` activation."""

        @torch.no_grad()
        def init_layer(self, m):
            """Initialize ``m`` with the project's ``wi.xavier`` helper."""
            wi.xavier(m)
class Hardtanh(nn.Hardtanh, ActivationModule):
    """``torch.nn.Hardtanh`` registered as the ``"hardtanh"`` activation."""

    @torch.no_grad()
    def init_layer(self, m):
        """Initialize ``m`` with the project's ``wi.xavier`` helper."""
        wi.xavier(m)
if _torch_version >= version.parse("1.6.0"):
    # Only registered on torch >= 1.6.0 (presumably the first release that
    # provides ``nn.Hardswish``).

    class Hardswish(nn.Hardswish, ActivationModule):
        """``torch.nn.Hardswish`` registered as the ``"hardswish"`` activation."""

        @torch.no_grad()
        def init_layer(self, m):
            """He-initialize ``m`` via ``wi.he`` using the ReLU nonlinearity."""
            wi.he(m, nonlinearity="relu")
class LeakyReLU(nn.LeakyReLU, ActivationModule):
    """``torch.nn.LeakyReLU`` registered as the ``"leakyrelu"`` activation."""

    @torch.no_grad()
    def init_layer(self, m):
        """He-initialize ``m`` via ``wi.he`` using the leaky-ReLU nonlinearity."""
        # NOTE(review): only the nonlinearity name is passed, not this module's
        # negative slope. If ``wi.he`` accepts a slope argument, forwarding it
        # may be more accurate -- confirm.
        wi.he(m, nonlinearity="leaky_relu")
class LogSigmoid(nn.LogSigmoid, ActivationModule):
    """``torch.nn.LogSigmoid`` registered as the ``"logsigmoid"`` activation."""

    @torch.no_grad()
    def init_layer(self, m):
        """Initialize ``m`` with the project's ``wi.xavier`` helper."""
        wi.xavier(m)
class MultiheadAttention(nn.MultiheadAttention, ActivationModule):
    """``torch.nn.MultiheadAttention`` registered as ``"multiheadattention"``.

    Uses the inherited no-op ``init_layer`` / ``init_first_layer``.
    """
class PReLU(nn.PReLU, ActivationModule):
    """``torch.nn.PReLU`` registered as the ``"prelu"`` activation."""

    @torch.no_grad()
    def init_layer(self, m):
        """He-initialize ``m`` via ``wi.he`` using the leaky-ReLU nonlinearity."""
        wi.he(m, nonlinearity="leaky_relu")
class ReLU(nn.ReLU, ActivationModule):
    """``torch.nn.ReLU`` registered as the ``"relu"`` activation."""

    @torch.no_grad()
    def init_layer(self, m):
        """He-initialize ``m`` via ``wi.he`` using the ReLU nonlinearity."""
        wi.he(m, nonlinearity="relu")
class ReLU6(nn.ReLU6, ActivationModule):
    """``torch.nn.ReLU6`` registered as the ``"relu6"`` activation."""

    @torch.no_grad()
    def init_layer(self, m):
        """He-initialize ``m`` via ``wi.he`` using the ReLU nonlinearity."""
        wi.he(m, nonlinearity="relu")
class RReLU(nn.RReLU, ActivationModule):
    """``torch.nn.RReLU`` registered as the ``"rrelu"`` activation."""

    @torch.no_grad()
    def init_layer(self, m):
        """He-initialize ``m`` via ``wi.he`` using the leaky-ReLU nonlinearity."""
        wi.he(m, nonlinearity="leaky_relu")
class SELU(nn.SELU, ActivationModule):
    """``torch.nn.SELU`` registered as the ``"selu"`` activation."""

    @torch.no_grad()
    def init_layer(self, m):
        """Normal init with ``std = 1 / sqrt(fan_in)`` (LeCun normal).

        ``fan_in`` is the number of inputs per output unit (``in_features``
        for ``nn.Linear``, ``in_channels / groups * prod(kernel_size)`` for
        convolutions). Modules without a weight, with ``weight=None``, or with
        a 1-D weight (e.g. normalization layers) are left unchanged.
        """
        weight = getattr(m, "weight", None)
        if weight is None or weight.dim() < 2:
            return
        fan_in = weight.shape[1:].numel()
        # ``sqrt`` comes from the module-level ``from math import sqrt``; the
        # previous ``math.sqrt`` raised NameError because ``math`` is not
        # imported.
        nn.init.normal_(weight, std=1 / sqrt(fan_in))
class CELU(nn.CELU, ActivationModule):
    """``torch.nn.CELU`` registered as the ``"celu"`` activation."""

    @torch.no_grad()
    def init_layer(self, m):
        """Normal init with ``std = 1 / sqrt(fan_in)`` (LeCun normal).

        ``fan_in`` is the number of inputs per output unit (``in_features``
        for ``nn.Linear``, ``in_channels / groups * prod(kernel_size)`` for
        convolutions). Modules without a weight, with ``weight=None``, or with
        a 1-D weight (e.g. normalization layers) are left unchanged.
        """
        weight = getattr(m, "weight", None)
        if weight is None or weight.dim() < 2:
            return
        fan_in = weight.shape[1:].numel()
        # ``sqrt`` comes from the module-level ``from math import sqrt``; the
        # previous ``math.sqrt`` raised NameError because ``math`` is not
        # imported.
        nn.init.normal_(weight, std=1 / sqrt(fan_in))
class GELU(nn.GELU, ActivationModule):
    """``torch.nn.GELU`` registered as the ``"gelu"`` activation."""

    @torch.no_grad()
    def init_layer(self, m):
        """He-initialize ``m`` via ``wi.he``, treating GELU like ReLU."""
        wi.he(m, nonlinearity="relu")
class Sigmoid(nn.Sigmoid, ActivationModule):
    """``torch.nn.Sigmoid`` registered as the ``"sigmoid"`` activation."""

    @torch.no_grad()
    def init_layer(self, m):
        """Initialize ``m`` with the project's ``wi.xavier`` helper."""
        wi.xavier(m)
class Softplus(nn.Softplus, ActivationModule):
    """``torch.nn.Softplus`` registered as the ``"softplus"`` activation."""

    @torch.no_grad()
    def init_layer(self, m):
        """He-initialize ``m`` via ``wi.he``, treating Softplus like ReLU."""
        wi.he(m, nonlinearity="relu")
class Softshrink(nn.Softshrink, ActivationModule):
    """``torch.nn.Softshrink`` registered as the ``"softshrink"`` activation."""

    @torch.no_grad()
    def init_layer(self, m):
        """Initialize ``m`` with the project's ``wi.xavier`` helper."""
        wi.xavier(m)
class Softsign(nn.Softsign, ActivationModule):
    """``torch.nn.Softsign`` registered as the ``"softsign"`` activation."""

    @torch.no_grad()
    def init_layer(self, m):
        """Initialize ``m`` with the project's ``wi.xavier`` helper."""
        wi.xavier(m)
class Tanh(nn.Tanh, ActivationModule):
    """``torch.nn.Tanh`` registered as the ``"tanh"`` activation."""

    @torch.no_grad()
    def init_layer(self, m):
        """Xavier-initialize ``m`` using torch's recommended gain for tanh."""
        tanh_gain = nn.init.calculate_gain("tanh")
        wi.xavier(m, gain=tanh_gain)
class Tanhshrink(nn.Tanhshrink, ActivationModule):
    """``torch.nn.Tanhshrink`` registered as the ``"tanhshrink"`` activation."""

    @torch.no_grad()
    def init_layer(self, m):
        """Initialize ``m`` with the project's ``wi.xavier`` helper."""
        wi.xavier(m)
class Threshold(nn.Threshold, ActivationModule):
    """``torch.nn.Threshold`` registered as the ``"threshold"`` activation.

    Uses the inherited no-op ``init_layer`` / ``init_first_layer``.
    """