Layers
- apax.layers.activation.get_activation_fn(name: str) → Callable[[float], float]
Get the activation function from jax.nn. Also performs several checks to ensure that it is a valid activation function.
- Parameters:
name (str) – name of the activation function in jax.nn.
- Returns:
Activation function that takes in a float and returns a float.
- Return type:
activation_fn (Callable)
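A minimal sketch of such a lookup, assuming the checks amount to confirming that the name exists in jax.nn, refers to a callable, and maps a single float to a single float. The name get_activation_fn_sketch and the particular checks are hypothetical; apax's actual validation may differ.

```python
import jax
import jax.numpy as jnp


def get_activation_fn_sketch(name: str):
    """Hypothetical lookup: fetch `name` from jax.nn and validate it."""
    if not hasattr(jax.nn, name):
        raise ValueError(f"'{name}' is not defined in jax.nn")
    fn = getattr(jax.nn, name)
    if not callable(fn):
        raise ValueError(f"jax.nn.{name} is not callable")
    # Assumed check: the function must map a single float to a single float.
    try:
        out = jnp.asarray(fn(jnp.float32(0.5)))
    except TypeError as err:
        raise ValueError(f"jax.nn.{name} cannot be called on a single float") from err
    if out.shape != ():
        raise ValueError(f"jax.nn.{name} does not return a single float")
    return fn


activation = get_activation_fn_sketch("silu")
```

Functions that need extra required arguments (for example jax.nn.one_hot, which needs num_classes) are rejected by the sketch because they cannot be called on a single float.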
- class apax.layers.empirical.EmpiricalEnergyTerm(dtype: Any = <class 'jax.numpy.float32'>, parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
- dtype
alias of float32
- class apax.layers.empirical.ExponentialRepulsion(dtype: Any = <class 'jax.numpy.float32'>, r_max: float = 2.0, apply_mask: bool = True, parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
- setup()
Initializes a Module lazily (similar to a lazy __init__).
setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.
This can happen in three cases:
1. Immediately when invoking apply(), init() or init_and_output().
2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):
   >>> class MyModule(nn.Module):
   ...   def setup(self):
   ...     submodule = nn.Conv(...)
   ...     # Accessing `submodule` attributes does not yet work here.
   ...     # The following line invokes `self.__setattr__`, which gives
   ...     # `submodule` the name "conv1".
   ...     self.conv1 = submodule
   ...     # Accessing `submodule` attributes or methods is now safe and
   ...     # either causes setup() to be called once.
3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.
- class apax.layers.empirical.LatentEwald(dtype: Any = <class 'jax.numpy.float32'>, kgrid: list[int] = <factory>, sigma: float = 1.0, apply_mask: bool = True, use_property: str = 'charges', parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Latent Ewald summation by Cheng (https://arxiv.org/abs/2408.15165). Requires a property head that predicts a charge per atom (the property name is given by use_property, 'charges' by default).
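The reciprocal-space energy behind this kind of layer can be sketched as follows. This is a hypothetical illustration, not apax's implementation: it assumes an orthorhombic box, interprets kgrid as the largest integer index per direction (k = 2π n / L with |n_d| ≤ kgrid[d] and n ≠ 0), uses Gaussian-smeared charges of width sigma, and works in units where 1/(4πε₀) = 1, so that E = (2π/V) Σ_k exp(-σ²|k|²/2) |S(k)|² / |k|² with S(k) = Σ_i q_i exp(i k·r_i).

```python
import itertools

import jax.numpy as jnp


def latent_ewald_energy_sketch(positions, charges, box_lengths, kgrid, sigma=1.0):
    """Hypothetical reciprocal-space Ewald energy for Gaussian-smeared latent charges.

    Assumes an orthorhombic box, integer k-indices |n_d| <= kgrid[d] (n != 0),
    and units in which 1 / (4 pi epsilon_0) = 1.
    """
    box_lengths = jnp.asarray(box_lengths, dtype=positions.dtype)
    volume = jnp.prod(box_lengths)
    # All integer index vectors n in the grid, excluding n = 0.
    ranges = [range(-m, m + 1) for m in kgrid]
    n = jnp.array([v for v in itertools.product(*ranges) if any(v)], dtype=positions.dtype)
    k = 2.0 * jnp.pi * n / box_lengths          # reciprocal vectors, shape (n_k, 3)
    k2 = jnp.sum(k**2, axis=-1)                 # |k|^2 for every k
    phase = positions @ k.T                     # k . r_i, shape (n_atoms, n_k)
    # Structure factor S(k) = sum_i q_i exp(i k . r_i), via its real and imaginary parts.
    s_re = charges @ jnp.cos(phase)
    s_im = charges @ jnp.sin(phase)
    weights = jnp.exp(-0.5 * sigma**2 * k2) / k2
    return 2.0 * jnp.pi / volume * jnp.sum(weights * (s_re**2 + s_im**2))


positions = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 2.0, 1.0]])
charges = jnp.array([0.5, -0.3, -0.2])
energy = latent_ewald_energy_sketch(positions, charges, [10.0, 10.0, 10.0], kgrid=[2, 2, 2])
```

Every term in the sum is non-negative, and because a common translation of all atoms only multiplies S(k) by a phase factor, |S(k)|² and hence the energy are translation invariant.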
- class apax.layers.empirical.ZBLRepulsion(dtype: Any = <class 'jax.numpy.float32'>, r_max: float = 2.0, apply_mask: bool = True, parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
- setup()
- apax.layers.initializers.uniform_range(minval, maxval, dtype: Any = <class 'jax.numpy.float64'>) → Initializer
Builds an initializer that returns real uniformly-distributed random arrays in a specified range.
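A minimal sketch of such an initializer, written as a hypothetical uniform_range_sketch that follows the Flax initializer convention init(key, shape, dtype). It defaults to float32 rather than the float64 shown in the signature, because JAX only produces float64 arrays when jax_enable_x64 is set.

```python
import jax
import jax.numpy as jnp


def uniform_range_sketch(minval, maxval, dtype=jnp.float32):
    """Hypothetical stand-in for uniform_range: returns an init(key, shape, dtype) function."""

    def init(key, shape, dtype=dtype):
        # jax.random.uniform samples from the half-open interval [minval, maxval).
        return jax.random.uniform(key, shape, dtype=dtype, minval=minval, maxval=maxval)

    return init


init = uniform_range_sketch(-0.5, 0.5)
weights = init(jax.random.PRNGKey(0), (3, 4))
```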
- class apax.layers.ntk_linear.NTKLinear(units: int, w_init: str = 'normal', b_init: str = 'zeros', use_ntk: bool = True, dtype: str = 'fp32', parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Linear layer with activation. With w_init='normal', b_init='zeros' and use_ntk=True (the defaults), it corresponds to an NTK layer.
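A sketch of the standard NTK parameterization that the description refers to, assuming the usual form: weights drawn from a unit normal, zero bias, and the output of the matrix product rescaled by 1/sqrt(fan_in) in the forward pass. The helpers init_ntk_linear and ntk_linear_apply are hypothetical, the activation is omitted, and apax may use additional scaling factors (for example on the bias).

```python
import jax
import jax.numpy as jnp


def init_ntk_linear(key, fan_in, units, dtype=jnp.float32):
    """Hypothetical init: 'normal' weights with unit variance and 'zeros' bias."""
    w = jax.random.normal(key, (fan_in, units), dtype=dtype)
    b = jnp.zeros((units,), dtype=dtype)
    return w, b


def ntk_linear_apply(params, x, use_ntk=True):
    """Hypothetical forward pass; the activation is omitted."""
    w, b = params
    # NTK parameterization: scale by 1/sqrt(fan_in) in the forward pass
    # instead of shrinking the initial weights.
    scale = 1.0 / jnp.sqrt(x.shape[-1]) if use_ntk else 1.0
    return scale * (x @ w) + b


params = init_ntk_linear(jax.random.PRNGKey(0), fan_in=16, units=8)
x = jnp.ones((5, 16))
y = ntk_linear_apply(params, x)
```

Keeping the weights at unit variance and moving the 1/sqrt(fan_in) factor into the forward pass gives the same initial outputs as a LeCun-normal initialization, but changes how gradient updates scale with layer width.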
- class apax.layers.properties.PropertyHead(pname: str, readout: Module = AtomisticReadout(units=[32, 32], activation_fn=silu, w_init='normal', b_init='zeros', use_ntk=True, n_shallow_ensemble=0, is_feature_fn=False, dtype=float32), aggregation: str = 'none', mode: str = 'l0', apply_mask: bool = True, parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
The readout is currently limited to a single number.
- setup()
- apax.layers.properties.stress_times_vol(energy_fn, position: Array, box, **kwargs) → Array
Computes the internal stress of a system multiplied by the box volume. Intended for use during training.
- Parameters:
energy_fn – A function that computes the energy of the system. It must accept a perturbation argument that perturbs the box shape. Any energy function constructed with smap or from energy.py (JAX MD) using a standard space satisfies this requirement.
position – An array of particle positions.
box – A box specifying the shape of the simulation volume. Used to infer the volume of the unit cell.
- Returns:
An array containing the stress of the system multiplied by the box volume.
- Return type:
Array
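A sketch of the idea behind this quantity: differentiate the energy with respect to a homogeneous strain ε, passed to the energy function as the perturbation I + ε, and evaluate at ε = 0. The toy energy toy_pair_energy and the helper strain_derivative are hypothetical; the sketch returns dE/dε, and whether apax returns this or its negative (sign conventions for stress vary) is not asserted here. The toy energy has no periodic images, so the box argument is accepted but unused.

```python
import jax
import jax.numpy as jnp


def toy_pair_energy(position, box, perturbation=None):
    """Hypothetical energy: 0.5 * sum over pairs i<j of |r_i - r_j|^2 (no periodic images)."""
    if perturbation is not None:
        # Apply the homogeneous deformation r -> (I + eps) r to every position.
        position = position @ perturbation.T
    diff = position[:, None, :] - position[None, :, :]
    # The double sum counts every pair twice, hence 0.25 instead of 0.5.
    return 0.25 * jnp.sum(diff**2)


def strain_derivative(energy_fn, position, box):
    """dE/d(eps) at eps = 0, where the perturbation passed to energy_fn is I + eps."""
    dim = position.shape[1]
    eye = jnp.eye(dim, dtype=position.dtype)

    def energy_of_strain(eps):
        return energy_fn(position, box, perturbation=eye + eps)

    return jax.grad(energy_of_strain)(jnp.zeros((dim, dim), dtype=position.dtype))


position = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.5, 0.0], [0.2, 1.0, 0.3]])
box = 5.0 * jnp.eye(3)
virial = strain_derivative(toy_pair_energy, position, box)
```

For this toy energy the result equals the pair virial, 0.5 Σ_ij (r_i - r_j) ⊗ (r_i - r_j).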
- class apax.layers.readout.AtomisticReadout(units: List[int] = <factory>, activation_fn: Callable = jax.nn.silu, w_init: str = 'normal', b_init: str = 'zeros', use_ntk: bool = True, n_shallow_ensemble: int = 0, is_feature_fn: bool = False, dtype: Any = <class 'jax.numpy.float32'>, parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
- activation_fn(x) → Array
SiLU (aka swish) activation function.
Computes the element-wise function:
\[\mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}\]
swish() and silu() are both aliases for the same function.
- Parameters:
x – input array
- Returns:
An array.
See also
sigmoid()
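The SiLU formula above can be checked numerically against the public jax.nn functions; this exercises only jax.nn, not apax.

```python
import jax
import jax.numpy as jnp

# Sample points on which to compare the forms of the formula.
x = jnp.linspace(-5.0, 5.0, 11)

silu_builtin = jax.nn.silu(x)            # library implementation
silu_sigmoid = x * jax.nn.sigmoid(x)     # x * sigmoid(x)
silu_direct = x / (1.0 + jnp.exp(-x))    # x / (1 + e^{-x})
```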
- dtype
alias of float32
- setup()
- class apax.layers.scaling.PerElementScaleShift(n_species: int = 119, scale: Union[array, float] = 1.0, shift: Union[array, float] = 0.0, dtype: Any = <class 'jax.numpy.float32'>, parent: Module | Scope | _Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
- dtype
alias of float32
- setup()