# Source code for apax.layers.activation

import inspect
from typing import Callable

import jax


def variance_preserving_swish(x) -> jax.Array:
    """Swish (SiLU) activation scaled by a constant factor.

    Computes ``1.6765324703310907 * jax.nn.swish(x)`` elementwise.

    NOTE(review): the name suggests the factor is chosen so that the output
    has (approximately) unit second moment for ``x ~ N(0, 1)`` — TODO confirm
    against the original derivation.

    Args:
        x: Scalar or array input accepted by ``jax.nn.swish``.

    Returns:
        jax.Array: The scaled swish activation, same shape as ``x``.
    """
    # Fixed normalisation factor applied on top of the plain swish.
    scale = 1.6765324703310907
    return scale * jax.nn.swish(x)


def get_activation_fn(name: str) -> Callable[[float], float]:
    """Get the activation function from jax.nn.

    Also performs a bunch of checks to make sure that it is a valid
    activation function: a public, callable attribute of ``jax.nn`` that can
    be invoked with a single positional argument.

    The special name ``"variance_preserving_swish"`` is resolved to this
    module's ``variance_preserving_swish`` instead of being looked up in
    jax.nn, for backwards compatibility.

    Args:
        name (str): name of the activation function in jax.nn.

    Returns:
        activation_fn (Callable): Activation function that takes exactly one
        required positional argument.

    Raises:
        AttributeError: if ``jax.nn`` has no attribute ``name``, or ``name``
            refers to a private (underscore-prefixed) attribute.
        TypeError: if the attribute is not callable, does not take exactly one
            required positional argument, or has required keyword-only
            arguments.
    """
    if name == "variance_preserving_swish":
        # Keep backwards compatibility
        return variance_preserving_swish

    if not hasattr(jax.nn, name):
        raise AttributeError(
            f"jax.nn has no attribute {name}, see https://docs.jax.dev/en/latest/jax.nn.html for options."
        )
    if name.startswith("_"):
        # Private/dunder attributes (e.g. ``__class__``) can be callable with
        # a single argument, but they are not activation functions.
        raise AttributeError(
            f"jax.nn.{name} is a private attribute, see https://docs.jax.dev/en/latest/jax.nn.html for options."
        )

    activation_fn = getattr(jax.nn, name)
    if not callable(activation_fn):
        raise TypeError(f"jax.nn.{name} is not callable")

    parameters = list(inspect.signature(activation_fn).parameters.values())
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    required_positional = [
        p
        for p in parameters
        if p.kind in positional_kinds and p.default is inspect.Parameter.empty
    ]
    if len(required_positional) != 1:
        raise TypeError(
            f"jax.nn.{name} is not a valid readout activation: expected exactly one required positional argument, but needs {len(required_positional)}."
        )

    # A required keyword-only parameter would also make ``activation_fn(x)``
    # fail at call time, so reject it here rather than later.
    required_keyword_only = [
        p.name
        for p in parameters
        if p.kind is inspect.Parameter.KEYWORD_ONLY
        and p.default is inspect.Parameter.empty
    ]
    if required_keyword_only:
        raise TypeError(
            f"jax.nn.{name} is not a valid readout activation: it has required keyword-only arguments {required_keyword_only}."
        )

    return activation_fn