Layers

class apax.layers.empirical.EmpiricalEnergyTerm(dtype: Any = <class 'jax.numpy.float32'>, parent: Union[Type[flax.linen.module.Module], flax.core.scope.Scope, Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)

class apax.layers.empirical.ZBLRepulsion(dtype: Any = <class 'jax.numpy.float32'>, r_max: float = 6.0, apply_mask: bool = True, parent: Union[Type[flax.linen.module.Module], flax.core.scope.Scope, Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
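
The class name refers to the Ziegler-Biersack-Littmark (ZBL) universal screened-Coulomb repulsion. For illustration only, the NumPy sketch below evaluates the standard ZBL pair energy (eV) for a distance in Å. It is not apax's implementation: zbl_pair_energy is a hypothetical name, and the sketch ignores r_max and apply_mask, whose exact behavior (cutoff handling, neighbor masking) is not documented here.

```python
import numpy as np

# Coulomb constant e^2 / (4 pi eps0) in eV * Angstrom
KE = 14.399645
# Coefficients and exponents of the ZBL universal screening function
ZBL_COEFFS = np.array([0.18175, 0.50986, 0.28022, 0.02817])
ZBL_EXPS = np.array([3.19980, 0.94229, 0.40290, 0.20162])


def zbl_pair_energy(r, z_i, z_j):
    """Hypothetical helper: ZBL repulsion (eV) between nuclei z_i, z_j at distance r (Angstrom).

    No cutoff is applied; apax's r_max / apply_mask handling is not reproduced.
    """
    r = np.asarray(r, dtype=float)
    # Universal screening length a = 0.46850 / (Z_i^0.23 + Z_j^0.23) Angstrom
    a = 0.46850 / (z_i**0.23 + z_j**0.23)
    x = r[..., None] / a
    # Screening function phi(x) = sum_k c_k * exp(-d_k * x); phi(0) = 1
    phi = np.sum(ZBL_COEFFS * np.exp(-ZBL_EXPS * x), axis=-1)
    return KE * z_i * z_j / r * phi
```

Because the coefficients sum to 1, the energy approaches the bare Coulomb repulsion KE·Z_i·Z_j/r at very short distances and is screened more strongly as r grows.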
setup()

Initializes a Module lazily (similar to a lazy __init__).

setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.

This can happen in three cases:

  1. Immediately when invoking apply(), init() or init_and_output().

  2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):

    >>> class MyModule(nn.Module):
    ...   def setup(self):
    ...     submodule = nn.Conv(...)
    ...
    ...     # Accessing `submodule` attributes does not yet work here.
    ...
    ...     # The following line invokes `self.__setattr__`, which gives
    ...     # `submodule` the name "conv1".
    ...     self.conv1 = submodule
    ...
    ...     # Accessing `submodule` attributes or methods is now safe; the
    ...     # first such access triggers submodule.setup(), which runs only once.

  3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.

apax.layers.initializers.uniform_range(minval, maxval, dtype: Any = <class 'jax.numpy.float64'>) → Initializer

Builds an initializer that returns real, uniformly distributed random arrays with values between minval and maxval.
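A minimal sketch of such an initializer, assuming it follows JAX's initializer convention (a callable taking key, shape and dtype) and wraps jax.random.uniform. The name uniform_range_sketch is hypothetical and this is not apax's source. The documented default dtype is float64, which JAX only honors when x64 mode is enabled (jax.config.update("jax_enable_x64", True)), so the sketch defaults to float32 instead.

```python
import jax
import jax.numpy as jnp


def uniform_range_sketch(minval, maxval, dtype=jnp.float32):
    """Hypothetical stand-in for uniform_range: returns an initializer closure."""

    def init(key, shape, dtype=dtype):
        # jax.random.uniform samples from the half-open interval [minval, maxval)
        return jax.random.uniform(key, shape, dtype, minval=minval, maxval=maxval)

    return init


init_fn = uniform_range_sketch(-0.5, 0.5)
w = init_fn(jax.random.PRNGKey(0), (4, 3))
```

Since the returned closure has the (key, shape, dtype) signature, it can be passed wherever Flax expects an initializer, e.g. as the init function of self.param.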

class apax.layers.ntk_linear.NTKLinear(units: int, w_init: str = 'normal', b_init: str = 'zeros', use_ntk: bool = True, dtype: Any = <class 'jax.numpy.float32'>, parent: Type[flax.linen.module.Module] | flax.core.scope.Scope | Type[flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Linear layer with activation. With w_init="normal", b_init="zeros" and use_ntk=True (the defaults), it corresponds to a layer in the neural tangent kernel (NTK) parameterization.
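To make the parameterization concrete, here is a NumPy sketch of the forward pass, assuming the standard NTK convention: weights are drawn from N(0, 1) and the matrix product is rescaled by 1/sqrt(fan_in) at call time, whereas a standard layer folds that factor into its initialization. The function names are hypothetical, and apax may apply further factors (for example to the bias) that are not shown.

```python
import numpy as np


def ntk_dense(x, w, b):
    """Hypothetical NTK-style dense layer: 1/sqrt(fan_in) is applied in the forward pass."""
    fan_in = x.shape[-1]
    return x @ w / np.sqrt(fan_in) + b


def standard_dense(x, w, b):
    """Standard dense layer: any 1/sqrt(fan_in) scaling must live in the initializer."""
    return x @ w + b


rng = np.random.default_rng(0)
fan_in, units = 16, 4
x = rng.normal(size=(8, fan_in))
w_ntk = rng.normal(size=(fan_in, units))  # "normal": N(0, 1) weights
b_ntk = np.zeros(units)                   # "zeros": zero bias

# Same function at initialization if the standard weights are pre-scaled.
y_ntk = ntk_dense(x, w_ntk, b_ntk)
y_std = standard_dense(x, w_ntk / np.sqrt(fan_in), b_ntk)
```

The two outputs agree at initialization. They diverge during training because the gradient with respect to the N(0, 1) weights carries the 1/sqrt(fan_in) prefactor, so the same learning rate produces different updates in the two parameterizations.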

apax.layers.properties.stress_times_vol(energy_fn, position: Array, box, **kwargs) → Array

Computes the internal stress of a system multiplied by the box volume, for use in training.

Parameters:
  • energy_fn – A function that computes the energy of the system. It must accept a perturbation argument that perturbs (strains) the box shape. Any energy function constructed with jax_md's smap, or taken from jax_md's energy.py with a standard space, satisfies this requirement.

  • position – An array of particle positions.

  • box – A box specifying the shape of the simulation volume. Used to infer the volume of the unit cell.

Returns:

An array containing the stress of the system multiplied by the box volume.

Return type:

Array
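A minimal sketch of the technique the parameters describe, similar in spirit to jax_md's quantity.stress: the energy is evaluated under a homogeneous strain perturbation = I + ε and differentiated with respect to ε at ε = 0. The sign convention (σV = -∂U/∂ε, kinetic contribution omitted) follows jax_md and is an assumption about apax. toy_dimer_energy and stress_times_vol_sketch are hypothetical names, and the toy energy applies the strain directly to Cartesian positions rather than through a space or box transform.

```python
import jax
import jax.numpy as jnp


def toy_dimer_energy(position, box, perturbation=None, k=1.0, r0=1.0):
    """Hypothetical harmonic dimer; `box` is accepted only to mirror the expected signature."""
    if perturbation is not None:
        # Homogeneous strain: each position (row vector) r becomes (I + eps) @ r.
        position = position @ perturbation.T
    d = jnp.linalg.norm(position[1] - position[0])
    return 0.5 * k * (d - r0) ** 2


def stress_times_vol_sketch(energy_fn, position, box, **kwargs):
    """Return -dU/d(eps) at eps = 0, i.e. the potential stress times the volume."""
    dim = position.shape[1]
    eye = jnp.eye(dim, dtype=position.dtype)
    zero = jnp.zeros((dim, dim), dtype=position.dtype)

    def strained_energy(eps):
        return energy_fn(position, box=box, perturbation=eye + eps, **kwargs)

    return -jax.grad(strained_energy)(zero)


position = jnp.array([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]])
box = 10.0 * jnp.eye(3)
sigma_v = stress_times_vol_sketch(toy_dimer_energy, position, box)
```

For this dimer stretched along x (k = 1, r0 = 1, d = 1.5), only the xx entry is nonzero, equal to -k(d - r0)d = -0.75.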

class apax.layers.readout.AtomisticReadout(units: List[int] = <factory>, activation_fn: Callable = <function swish>, w_init: str = 'normal', b_init: str = 'zeros', use_ntk: bool = True, n_shallow_ensemble: int = 0, is_feature_fn: bool = False, dtype: Any = <class 'jax.numpy.float32'>, parent: Union[Type[flax.linen.module.Module], flax.core.scope.Scope, Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
setup()

Initializes the module lazily, following the flax.linen.Module.setup() semantics described under ZBLRepulsion.setup() above.

class apax.layers.scaling.PerElementScaleShift(n_species: int = 119, scale: Union[array, float] = 1.0, shift: Union[array, float] = 0.0, dtype: Any = <class 'jax.numpy.float32'>, parent: Union[Type[flax.linen.module.Module], flax.core.scope.Scope, Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
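
The signature suggests per-element rescaling of per-atom outputs, with one scale and one shift entry per atomic number (n_species = 119 covers Z = 0 to 118). The NumPy sketch below assumes the common form out_i = scale[Z_i] * x_i + shift[Z_i], with a scalar scale or shift applied to every element. This is an assumption, not apax's documented behavior; per_element_scale_shift is a hypothetical name and the numbers are arbitrary.

```python
import numpy as np


def per_element_scale_shift(x, atomic_numbers, scale=1.0, shift=0.0, n_species=119):
    """Hypothetical sketch: rescale per-atom values x by element-dependent factors."""
    # Broadcast scalar scale/shift to one entry per species (index = atomic number).
    scale_table = np.broadcast_to(np.asarray(scale, dtype=float), (n_species,))
    shift_table = np.broadcast_to(np.asarray(shift, dtype=float), (n_species,))
    # Gather each atom's factors by its atomic number.
    return scale_table[atomic_numbers] * x + shift_table[atomic_numbers]


# Arbitrary illustrative values: an O, H, H system with raw per-atom outputs.
atomic_numbers = np.array([8, 1, 1])
raw = np.array([0.5, -0.2, 0.1])
shift = np.zeros(119)
shift[1], shift[8] = -1.0, -5.0
out = per_element_scale_shift(raw, atomic_numbers, scale=2.0, shift=shift)
```

Indexing the lookup tables directly by atomic number keeps the operation a single gather, at the cost of unused entries for elements absent from the training data.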
setup()

Initializes the module lazily, following the flax.linen.Module.setup() semantics described under ZBLRepulsion.setup() above.