Layers

apax.layers.activation.get_activation_fn(name: str) → Callable[[float], float]

Get the activation function with the given name from jax.nn. Also checks that the result is a valid activation function.

Parameters:

name (str) – name of the activation function in jax.nn.

Returns:

Activation function that takes a float and returns a float.

Return type:

activation_fn (Callable)
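As a rough illustration of this kind of lookup, the sketch below resolves a name with getattr on jax.nn and runs a simple scalar smoke test. The helper lookup_activation and its checks are hypothetical; the checks apax actually performs are not documented here.

```python
import jax
import jax.numpy as jnp


def lookup_activation(name: str):
    """Hypothetical stand-in for get_activation_fn; apax's real checks may differ."""
    fn = getattr(jax.nn, name, None)
    # Reject missing names and non-callables such as the jax.nn.initializers submodule.
    if fn is None or not callable(fn):
        raise ValueError(f"{name!r} is not a function in jax.nn")
    # Smoke test: an elementwise activation should map a scalar to a scalar.
    out = fn(jnp.float32(0.5))
    if jnp.ndim(out) != 0:
        raise ValueError(f"jax.nn.{name} does not behave like a scalar activation")
    return fn


silu = lookup_activation("silu")
```

With apax itself, the documented call would be get_activation_fn("silu"), which per the description above returns the matching jax.nn function.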

class apax.layers.empirical.EmpiricalEnergyTerm(dtype: Any = <class 'jax.numpy.float32'>, parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
dtype

alias of float32

class apax.layers.empirical.ExponentialRepulsion(dtype: Any = <class 'jax.numpy.float32'>, r_max: float = 2.0, apply_mask: bool = True, parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
setup()

Initializes a Module lazily (similar to a lazy __init__).

setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.

This can happen in three cases:

  1. Immediately when invoking apply(), init() or init_and_output().

  2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):

    >>> import flax.linen as nn
    >>> class MyModule(nn.Module):
    ...   def setup(self):
    ...     submodule = nn.Conv(features=8, kernel_size=(3,))
    ...     # Accessing `submodule` attributes does not yet work here.
    ...
    ...     # The following line invokes `self.__setattr__`, which gives
    ...     # `submodule` the name "conv1".
    ...     self.conv1 = submodule
    ...
    ...     # Accessing `submodule` attributes or methods is now safe and
    ...     # may cause setup() to be called once.

  3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.

class apax.layers.empirical.LatentEwald(dtype: Any = <class 'jax.numpy.float32'>, kgrid: list[int] = <factory>, sigma: float = 1.0, apply_mask: bool = True, use_property: str = 'charges', parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Latent Ewald summation, following Cheng (https://arxiv.org/abs/2408.15165). Requires a property head that predicts a charge per atom; the property used is selected by use_property (default 'charges').
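For orientation, the standard reciprocal-space Ewald energy of Gaussian charges of width sigma is E = (2π/V) Σ_{k≠0} exp(−σ²k²/2) / k² · |S(k)|², with structure factor S(k) = Σ_i q_i exp(i k·r_i). The sketch below evaluates that sum in JAX for given per-atom charges. It is not apax's implementation: the units (Coulomb constant set to 1), the reading of kgrid as the largest integer index along each reciprocal direction, and the function name are all assumptions.

```python
import itertools

import jax.numpy as jnp


def reciprocal_ewald_energy(positions, charges, cell, kgrid=(2, 2, 2), sigma=1.0):
    """Reciprocal-space Ewald energy of Gaussian-smeared point charges.

    Hypothetical sketch: Gaussian units (Coulomb constant = 1), rows of `cell`
    are lattice vectors, and kgrid is assumed to be the largest integer index
    per reciprocal direction. apax's LatentEwald may use other conventions.
    """
    # Reciprocal vectors b_j with a_i . b_j = 2*pi*delta_ij, stored as rows.
    recip = 2.0 * jnp.pi * jnp.linalg.inv(cell).T
    # All integer triples in the grid except the k = 0 term.
    ns = [n for n in itertools.product(*(range(-m, m + 1) for m in kgrid)) if any(n)]
    k = jnp.asarray(ns, dtype=positions.dtype) @ recip  # (K, 3)
    k2 = jnp.sum(k**2, axis=-1)  # (K,)
    # Structure factor S(k) = sum_i q_i exp(i k . r_i), as real and imaginary parts.
    phase = positions @ k.T  # (N, K)
    s_re = charges @ jnp.cos(phase)
    s_im = charges @ jnp.sin(phase)
    volume = jnp.abs(jnp.linalg.det(cell))
    # E = (2 pi / V) * sum_{k != 0} exp(-sigma^2 k^2 / 2) / k^2 * |S(k)|^2
    weights = jnp.exp(-0.5 * sigma**2 * k2) / k2
    return 2.0 * jnp.pi / volume * jnp.sum(weights * (s_re**2 + s_im**2))


pos = jnp.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
q = jnp.array([1.0, -1.0])
cell = 4.0 * jnp.eye(3)
energy = reciprocal_ewald_energy(pos, q, cell)
```

Because |S(k)|² is unchanged by a rigid translation of all atoms, this energy is translation invariant, and it scales quadratically with the charges.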

class apax.layers.empirical.ZBLRepulsion(dtype: Any = <class 'jax.numpy.float32'>, r_max: float = 2.0, apply_mask: bool = True, parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
setup()

Lazily initializes the module; see ExponentialRepulsion.setup() for the full flax.linen.Module.setup() description.

apax.layers.initializers.uniform_range(minval, maxval, dtype: Any = <class 'jax.numpy.float64'>) → Initializer

Builds an initializer that returns real, uniformly distributed random arrays with values between minval and maxval.
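A minimal initializer with this behaviour can be written with jax.random.uniform, following the (key, shape, dtype) calling convention of jax.nn.initializers. uniform_range_sketch is a hypothetical stand-in, not apax's source. It defaults to float32 because JAX only produces float64 arrays when the jax_enable_x64 flag is set, which is relevant to the float64 default in the signature above.

```python
import jax
import jax.numpy as jnp


def uniform_range_sketch(minval, maxval, dtype=jnp.float32):
    """Hypothetical re-implementation of a uniform-in-range initializer."""

    def init(key, shape, dtype=dtype):
        # Same calling convention as jax.nn.initializers: (key, shape, dtype).
        # jax.random.uniform samples from the half-open interval [minval, maxval).
        return jax.random.uniform(key, shape, dtype, minval=minval, maxval=maxval)

    return init


init_fn = uniform_range_sketch(-0.5, 0.5)
w = init_fn(jax.random.PRNGKey(0), (4, 3))
```

An initializer of this form can be passed, for example, as kernel_init to flax.linen.Dense.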

class apax.layers.ntk_linear.NTKLinear(units: int, w_init: str = 'normal', b_init: str = 'zeros', use_ntk: bool = True, dtype: str = 'fp32', parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Linear layer with activation. With its default arguments (“normal” for w_init, “zeros” for b_init, and use_ntk set to True), it corresponds to an NTK layer.

class apax.layers.properties.PropertyHead(pname: str, readout: Module = AtomisticReadout(units=[32, 32], activation_fn=silu, w_init='normal', b_init='zeros', use_ntk=True, n_shallow_ensemble=0, is_feature_fn=False, dtype=float32), aggregation: str = 'none', mode: str = 'l0', apply_mask: bool = True, parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

The readout is currently limited to a single number.

setup()

Lazily initializes the module; see ExponentialRepulsion.setup() for the full flax.linen.Module.setup() description.

apax.layers.properties.stress_times_vol(energy_fn, position: Array, box, **kwargs) → Array

Computes the internal stress of a system multiplied by the box volume, for use in training.

Parameters:
  • energy_fn – A function that computes the energy of the system. It must accept an argument named perturbation, which perturbs the box shape. Any energy function constructed with smap, or defined in energy.py, with a standard space satisfies this requirement.

  • position – An array of particle positions.

  • box – A box specifying the shape of the simulation volume. Used to infer the volume of the unit cell.

Returns:

An array holding the stress of the system multiplied by the box volume.

Return type:

Array
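A common way to obtain stress times volume is to differentiate the energy with respect to a homogeneous strain ε at ε = 0, σV = −∂E/∂ε. The sketch below does this with jax.grad for a toy two-particle harmonic energy. The helper names, the convention that the perturbation I + ε right-multiplies the positions, and the sign convention are assumptions; codes differ on the sign of stress.

```python
import jax
import jax.numpy as jnp


def harmonic_pair_energy(positions, box, perturbation=None, k=1.0, r0=1.0):
    """Hypothetical toy energy: one harmonic bond, no periodic images (box unused)."""
    if perturbation is not None:
        # Apply the homogeneous deformation I + eps to every position.
        positions = positions @ perturbation
    r = jnp.linalg.norm(positions[1] - positions[0])
    return 0.5 * k * (r - r0) ** 2


def stress_times_volume_sketch(energy_fn, positions, box, **kwargs):
    """sigma * V = -dE/d(eps), evaluated at zero strain (sketch, not apax's code)."""
    dim = positions.shape[1]
    eye = jnp.eye(dim, dtype=positions.dtype)

    def energy_of_strain(eps):
        return energy_fn(positions, box, perturbation=eye + eps, **kwargs)

    return -jax.grad(energy_of_strain)(jnp.zeros((dim, dim), positions.dtype))


positions = jnp.array([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]])
box = 10.0 * jnp.eye(3)
sv = stress_times_volume_sketch(harmonic_pair_energy, positions, box)
```

For a bond of length 1.5 along x with r0 = 1 and k = 1, only the xx entry is nonzero: −(r − r0)·r = −0.75.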

class apax.layers.readout.AtomisticReadout(units: List[int] = <factory>, activation_fn: Callable = <PjitFunction of <function silu>>, w_init: str = 'normal', b_init: str = 'zeros', use_ntk: bool = True, n_shallow_ensemble: int = 0, is_feature_fn: bool = False, dtype: Any = <class 'jax.numpy.float32'>, parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
activation_fn(x) → Array

SiLU (aka swish) activation function.

Computes the element-wise function:

\[\mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}\]

swish() and silu() are both aliases for the same function.

Parameters:

x – input array

Returns:

An array.

See also

sigmoid()
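The formula can be checked numerically against jax.nn.silu and its alias jax.nn.swish; the sample points below are arbitrary.

```python
import jax
import jax.numpy as jnp

x = jnp.linspace(-4.0, 4.0, 9)
# silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
manual = x / (1.0 + jnp.exp(-x))
via_sigmoid = x * jax.nn.sigmoid(x)
library = jax.nn.silu(x)
alias = jax.nn.swish(x)
```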

dtype

alias of float32

setup()

Lazily initializes the module; see ExponentialRepulsion.setup() for the full flax.linen.Module.setup() description.

class apax.layers.scaling.PerElementScaleShift(n_species: int = 119, scale: Union[array, float] = 1.0, shift: Union[array, float] = 0.0, dtype: Any = <class 'jax.numpy.float32'>, parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
dtype

alias of float32

setup()

Lazily initializes the module; see ExponentialRepulsion.setup() for the full flax.linen.Module.setup() description.