""" Easy interface for swapping out activation functions, especially those
that may have different weight initialization methods.
* weights       <- initialization depends on activation function <----
* normalization                                                      |
* activation    <-----------------------------------------------------
To create a new activation, do the following:
* Inherit from ActivationModule to register
* [optional] implement ``init_layer`` method
* [optional] implement ``init_first_layer`` method
Once this is done, your activation function will be available as:
Activation("myactivationlowercase", **kwargs)
"""
import torch
from torch import nn
import numpy as np
from packaging import version
from math import sqrt
from . import init as wi
# Parsed version of the installed torch; compared against minimum versions
# further down to decide whether newer activation classes get defined.
_torch_version = version.parse(torch.__version__)
# Registry mapping a lowercase class name to its ActivationModule subclass.
# Filled in by ActivationModule.__init_subclass__ and read by Activation().
_activation_lookup = {}
class ActivationModule(nn.Module):
    """Base class that auto-registers every subclass as an activation.

    Subclasses are stored in ``_activation_lookup`` under their lowercase
    class name, which is the key the ``Activation`` factory looks up.
    Subclasses may override ``init_layer`` and/or ``init_first_layer`` to
    provide weight initialization tailored to the activation.
    """

    def __init_subclass__(cls, **kwargs):
        """Register ``cls`` in ``_activation_lookup``.

        Raises
        ------
        ValueError
            If another class is already registered under the same
            lowercase name.
        """
        super().__init_subclass__(**kwargs)
        name = cls.__name__.lower()
        if name in _activation_lookup:
            raise ValueError(
                f'Activation function "{name}" already exists: {_activation_lookup[name]}'
            )
        _activation_lookup[name] = cls

    @torch.no_grad()
    def init_layer(self, m):
        """Initialize module ``m`` for use with this activation.

        The base implementation does nothing; override it in a child
        activation function.
        """
        pass

    @torch.no_grad()
    def init_first_layer(self, m):
        """Initialize ``m`` when it is the first layer of a network.

        Defaults to ``init_layer``; override it in a child activation
        function when the first layer needs different treatment.
        """
        return self.init_layer(m)
def Activation(name, init_layers=None, *, first=False, **kwargs):
    """Activation Factory Function

    Parameters
    ----------
    name : str
        Activation function type (case-insensitive); must match the
        lowercase class name of a registered ``ActivationModule`` subclass.
    init_layers : nn.Module or list/tuple of nn.Module, optional
        Modules whose weights need initialization based on the chosen
        activation function. The initializer is applied to each module
        and all of its submodules via ``nn.Module.apply``.
    first : bool
        When true, ``init_first_layer`` is used instead of ``init_layer``.
    kwargs : dict
        Passed along to activation function constructor.

    Returns
    -------
    ActivationModule
        The constructed activation object.

    Raises
    ------
    TypeError
        If ``init_layers`` is neither a module nor a list/tuple.
    KeyError
        If no activation is registered under ``name``.
    """
    name = name.lower()

    if init_layers is not None:
        if isinstance(init_layers, nn.Module):
            init_layers = [init_layers]
        elif isinstance(init_layers, (list, tuple)):
            init_layers = list(init_layers)
        else:
            # Explicit check instead of ``assert``: asserts vanish under -O.
            raise TypeError(
                "init_layers must be an nn.Module or a list/tuple of nn.Module, "
                f"got {type(init_layers).__name__}"
            )

    try:
        activation_cls = _activation_lookup[name]
    except KeyError:
        raise KeyError(
            f'Unknown activation function "{name}"; '
            f"available: {sorted(_activation_lookup)}"
        ) from None
    activation_obj = activation_cls(**kwargs)

    if init_layers:
        initializer = (
            activation_obj.init_first_layer if first else activation_obj.init_layer
        )
        for layer in init_layers:
            layer.apply(initializer)
    return activation_obj
class Noop(ActivationModule):
    """Identity activation: ``forward`` hands back its input unchanged."""

    def forward(self, input):
        """Return ``input`` as-is."""
        return input
class Sine(ActivationModule):
    """
    Implicit Neural Representations with Periodic Activation Functions
    https://arxiv.org/pdf/2006.09661.pdf

    Computes ``sin(frequency * input)``.
    """

    def __init__(self, frequency=30):
        super().__init__()
        # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30
        self.frequency = frequency

    def forward(self, input):
        """Return ``torch.sin(self.frequency * input)``."""
        return torch.sin(self.frequency * input)

    @torch.no_grad()
    def init_layer(self, m):
        """Initialize a non-first layer ``m`` in place.

        The weight (if present) is drawn uniformly from
        ``+/- sqrt(6 / num_input) / frequency``, where ``num_input`` is the
        size of the weight's last dimension. The bias (if present) is not
        re-sampled; its current values are divided by ``frequency / 2``.
        """
        if hasattr(m, "weight") and m.weight is not None:
            num_input = m.weight.size(-1)
            bound = sqrt(6 / num_input) / self.frequency
            m.weight.uniform_(-bound, bound)
        if hasattr(m, "bias") and m.bias is not None:
            m.bias /= self.frequency / 2

    @torch.no_grad()
    def init_first_layer(self, m):
        """Initialize the first layer ``m`` in place.

        The weight (if present) is drawn from a zero-mean normal with
        ``std = 1 / num_input``. The bias (if present) is rescaled exactly
        as in ``init_layer``.
        """
        if hasattr(m, "weight") and m.weight is not None:
            num_input = m.weight.size(-1)
            m.weight.normal_(std=1 / num_input)
        if hasattr(m, "bias") and m.bias is not None:
            m.bias /= self.frequency / 2
#################################
# torch.nn activation functions #
#################################
class ELU(nn.ELU, ActivationModule):
    """``torch.nn.ELU`` registered with the activation factory."""

    @torch.no_grad()
    def init_layer(self, m):
        """Draw ``m.weight`` from a zero-mean normal scaled by its fan-in.

        ``std = sqrt(1.5505188080679277) / sqrt(num_input)``, where
        ``num_input`` is the size of the weight's last dimension. Modules
        with no ``weight`` attribute, or whose ``weight`` is ``None``, are
        left untouched.
        """
        weight = getattr(m, "weight", None)
        if weight is None:
            return
        num_input = weight.size(-1)
        # Use the file-level ``sqrt`` import; the bare ``math`` name is not
        # imported by this module, so ``math.sqrt`` raised NameError.
        # NOTE(review): 1.5505188080679277 is presumably an ELU-specific
        # variance correction -- confirm its derivation.
        nn.init.normal_(weight, std=sqrt(1.5505188080679277) / sqrt(num_input))
class Hardshrink(nn.Hardshrink, ActivationModule):
    """``torch.nn.Hardshrink`` registered with the activation factory."""

    @torch.no_grad()
    def init_layer(self, m):
        """Initialize ``m`` by delegating to ``wi.xavier``."""
        wi.xavier(m)
# Only define this activation when the installed torch is at least 1.5.0.
if _torch_version >= version.parse("1.5.0"):

    class Hardsigmoid(nn.Hardsigmoid, ActivationModule):
        """``torch.nn.Hardsigmoid`` registered with the activation factory."""

        @torch.no_grad()
        def init_layer(self, m):
            """Initialize ``m`` by delegating to ``wi.xavier``."""
            wi.xavier(m)
class Hardtanh(nn.Hardtanh, ActivationModule):
    """``torch.nn.Hardtanh`` registered with the activation factory."""

    @torch.no_grad()
    def init_layer(self, m):
        """Initialize ``m`` by delegating to ``wi.xavier``."""
        wi.xavier(m)
# Only define this activation when the installed torch is at least 1.6.0.
if _torch_version >= version.parse("1.6.0"):

    class Hardswish(nn.Hardswish, ActivationModule):
        """``torch.nn.Hardswish`` registered with the activation factory."""

        @torch.no_grad()
        def init_layer(self, m):
            """Initialize ``m`` via ``wi.he`` with ``nonlinearity="relu"``."""
            wi.he(m, nonlinearity="relu")
class LeakyReLU(nn.LeakyReLU, ActivationModule):
    """``torch.nn.LeakyReLU`` registered with the activation factory."""

    @torch.no_grad()
    def init_layer(self, m):
        """Initialize ``m`` via ``wi.he`` with ``nonlinearity="leaky_relu"``."""
        wi.he(m, nonlinearity="leaky_relu")
class LogSigmoid(nn.LogSigmoid, ActivationModule):
    """``torch.nn.LogSigmoid`` registered with the activation factory."""

    @torch.no_grad()
    def init_layer(self, m):
        """Initialize ``m`` by delegating to ``wi.xavier``."""
        wi.xavier(m)
class MultiheadAttention(nn.MultiheadAttention, ActivationModule):
    """``torch.nn.MultiheadAttention`` registered with the activation factory.

    Defines no initializer of its own, so the no-op ``init_layer`` inherited
    from ``ActivationModule`` applies.
    """
class PReLU(nn.PReLU, ActivationModule):
    """``torch.nn.PReLU`` registered with the activation factory."""

    @torch.no_grad()
    def init_layer(self, m):
        """Initialize ``m`` via ``wi.he`` with ``nonlinearity="leaky_relu"``."""
        wi.he(m, nonlinearity="leaky_relu")
class ReLU(nn.ReLU, ActivationModule):
    """``torch.nn.ReLU`` registered with the activation factory."""

    @torch.no_grad()
    def init_layer(self, m):
        """Initialize ``m`` via ``wi.he`` with ``nonlinearity="relu"``."""
        wi.he(m, nonlinearity="relu")
class ReLU6(nn.ReLU6, ActivationModule):
    """``torch.nn.ReLU6`` registered with the activation factory."""

    @torch.no_grad()
    def init_layer(self, m):
        """Initialize ``m`` via ``wi.he`` with ``nonlinearity="relu"``."""
        wi.he(m, nonlinearity="relu")
class RReLU(nn.RReLU, ActivationModule):
    """``torch.nn.RReLU`` registered with the activation factory."""

    @torch.no_grad()
    def init_layer(self, m):
        """Initialize ``m`` via ``wi.he`` with ``nonlinearity="leaky_relu"``."""
        wi.he(m, nonlinearity="leaky_relu")
class SELU(nn.SELU, ActivationModule):
    """``torch.nn.SELU`` registered with the activation factory."""

    @torch.no_grad()
    def init_layer(self, m):
        """Draw ``m.weight`` from a normal with ``std = 1 / sqrt(num_input)``.

        ``num_input`` is the size of the weight's last dimension. Modules
        with no ``weight`` attribute, or whose ``weight`` is ``None``, are
        left untouched.
        """
        weight = getattr(m, "weight", None)
        if weight is None:
            return
        num_input = weight.size(-1)
        # Use the file-level ``sqrt`` import; the bare ``math`` name is not
        # imported by this module, so ``math.sqrt`` raised NameError.
        nn.init.normal_(weight, std=1 / sqrt(num_input))
class CELU(nn.CELU, ActivationModule):
    """``torch.nn.CELU`` registered with the activation factory."""

    @torch.no_grad()
    def init_layer(self, m):
        """Draw ``m.weight`` from a normal with ``std = 1 / sqrt(num_input)``.

        ``num_input`` is the size of the weight's last dimension. Modules
        with no ``weight`` attribute, or whose ``weight`` is ``None``, are
        left untouched.
        """
        weight = getattr(m, "weight", None)
        if weight is None:
            return
        num_input = weight.size(-1)
        # Use the file-level ``sqrt`` import; the bare ``math`` name is not
        # imported by this module, so ``math.sqrt`` raised NameError.
        nn.init.normal_(weight, std=1 / sqrt(num_input))
class GELU(nn.GELU, ActivationModule):
    """``torch.nn.GELU`` registered with the activation factory."""

    @torch.no_grad()
    def init_layer(self, m):
        """Initialize ``m`` via ``wi.he`` with ``nonlinearity="relu"``."""
        wi.he(m, nonlinearity="relu")
class Sigmoid(nn.Sigmoid, ActivationModule):
    """``torch.nn.Sigmoid`` registered with the activation factory."""

    @torch.no_grad()
    def init_layer(self, m):
        """Initialize ``m`` by delegating to ``wi.xavier``."""
        wi.xavier(m)
class Softplus(nn.Softplus, ActivationModule):
    """``torch.nn.Softplus`` registered with the activation factory."""

    @torch.no_grad()
    def init_layer(self, m):
        """Initialize ``m`` via ``wi.he`` with ``nonlinearity="relu"``."""
        wi.he(m, nonlinearity="relu")
class Softshrink(nn.Softshrink, ActivationModule):
    """``torch.nn.Softshrink`` registered with the activation factory."""

    @torch.no_grad()
    def init_layer(self, m):
        """Initialize ``m`` by delegating to ``wi.xavier``."""
        wi.xavier(m)
class Softsign(nn.Softsign, ActivationModule):
    """``torch.nn.Softsign`` registered with the activation factory."""

    @torch.no_grad()
    def init_layer(self, m):
        """Initialize ``m`` by delegating to ``wi.xavier``."""
        wi.xavier(m)
class Tanh(nn.Tanh, ActivationModule):
    """``torch.nn.Tanh`` registered with the activation factory."""

    @torch.no_grad()
    def init_layer(self, m):
        """Initialize ``m`` via ``wi.xavier`` using torch's gain for ``"tanh"``."""
        tanh_gain = nn.init.calculate_gain("tanh")
        wi.xavier(m, gain=tanh_gain)
class Tanhshrink(nn.Tanhshrink, ActivationModule):
    """``torch.nn.Tanhshrink`` registered with the activation factory."""

    @torch.no_grad()
    def init_layer(self, m):
        """Initialize ``m`` by delegating to ``wi.xavier``."""
        wi.xavier(m)
class Threshold(nn.Threshold, ActivationModule):
    """``torch.nn.Threshold`` registered with the activation factory.

    Defines no initializer of its own, so the no-op ``init_layer`` inherited
    from ``ActivationModule`` applies.
    """