Model
- class apax.model.gmnn.AtomisticModel(
      descriptor: flax.linen.module.Module = GaussianMomentDescriptor(radial_fn=RadialFunction(n_radial=5, basis_fn=GaussianBasis(n_basis=7, r_min=0.5, r_max=6.0, dtype=float32), n_species=119, emb_init='uniform', dtype=float32), n_contr=8, dtype=float32, apply_mask=True),
      readout: flax.linen.module.Module = AtomisticReadout(units=[512, 512], activation_fn=swish, b_init='normal', dtype=float32),
      scale_shift: flax.linen.module.Module = PerElementScaleShift(n_species=119, scale=1.0, shift=0.0, dtype=float32),
      mask_atoms: bool = True,
      parent: typing.Type[flax.linen.module.Module] | flax.core.scope.Scope | typing.Type[flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>,
      name: str | None = None)
Most basic prediction model. Assembles the descriptor, the readout (neural networks) and the output scale-shift into a single module.
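To make the three stages concrete, here is a minimal pure-Python sketch of the data flow, assuming per-atom features from a Gaussian radial basis, a linear readout standing in for the neural-network readout, and a per-element affine scale-shift. All function names (`toy_descriptor`, `toy_readout`, `toy_scale_shift`, `toy_atomistic_model`) and numeric choices are hypothetical illustrations, not apax's Gaussian-moment implementation or API.

```python
import math

# All names below are hypothetical stand-ins, not apax's layers.


def toy_descriptor(positions, r_max=6.0, centers=(1.0, 2.0, 3.0), width=0.5):
    """Per-atom features: Gaussian radial basis summed over neighbours within r_max."""
    features = []
    for i, r_i in enumerate(positions):
        f = [0.0] * len(centers)
        for j, r_j in enumerate(positions):
            if i == j:
                continue
            d = math.dist(r_i, r_j)
            if d < r_max:
                for k, c in enumerate(centers):
                    f[k] += math.exp(-(((d - c) / width) ** 2))
        features.append(f)
    return features


def toy_readout(features, weights):
    """Linear readout giving one scalar per atom (stands in for the NN readout)."""
    return [sum(w * x for w, x in zip(weights, f)) for f in features]


def toy_scale_shift(outputs, species, scale, shift):
    """Per-element affine map: E_i = scale[Z_i] * y_i + shift[Z_i]."""
    return [scale[z] * y + shift[z] for y, z in zip(outputs, species)]


def toy_atomistic_model(positions, species, weights, scale, shift):
    """Descriptor -> readout -> scale-shift, returning per-atom energies."""
    features = toy_descriptor(positions)
    outputs = toy_readout(features, weights)
    return toy_scale_shift(outputs, species, scale, shift)


# Two atoms of element Z=1 placed 1.0 apart.
per_atom = toy_atomistic_model(
    positions=[(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)],
    species=[1, 1],
    weights=[1.0, 0.0, 0.0],
    scale={1: 2.0},
    shift={1: -0.5},
)
print(per_atom)  # [1.5, 1.5]
```

Keeping the scale-shift as a separate, per-element stage means the readout can predict values of order one while the element-specific scale and offset absorb the physical energy range.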
- class apax.model.gmnn.EnergyDerivativeModel(
      energy_model: apax.model.gmnn.EnergyModel = EnergyModel(atomistic_model=AtomisticModel(), corrections=[], init_box=array([0., 0., 0.]), inference_disp_fn=None),
      calc_stress: bool = False,
      parent: typing.Type[flax.linen.module.Module] | flax.core.scope.Scope | typing.Type[flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>,
      name: str | None = None)
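Transforms an EnergyModel into one that also predicts derivatives of the total energy. It can calculate forces (the negative gradient of the energy with respect to the atomic positions) and, when calc_stress is True, the stress tensor.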
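The relation between energy and forces can be illustrated without apax. The sketch below assumes a hypothetical harmonic pair energy (`harmonic_pair_energy`) and computes F = -∂E/∂R with central finite differences (`finite_difference_forces`, also hypothetical). A JAX-based model would more typically obtain the same derivative through automatic differentiation, so treat this only as a numerical check of the definition. The stress tensor is the analogous derivative with respect to a homogeneous strain of the simulation cell, divided by the cell volume, and is omitted here.

```python
import math


def harmonic_pair_energy(positions, k=1.0, r0=1.0):
    """Hypothetical toy total energy: a harmonic spring between every atom pair."""
    energy = 0.0
    for i in range(len(positions)):
        for j in range(i + 1, len(positions)):
            d = math.dist(positions[i], positions[j])
            energy += 0.5 * k * (d - r0) ** 2
    return energy


def finite_difference_forces(energy_fn, positions, h=1e-5):
    """Approximate F[i][a] = -dE/dR[i][a] with central differences."""
    forces = []
    for i in range(len(positions)):
        f_i = []
        for a in range(3):
            plus = [list(p) for p in positions]
            minus = [list(p) for p in positions]
            plus[i][a] += h
            minus[i][a] -= h
            f_i.append(-(energy_fn(plus) - energy_fn(minus)) / (2.0 * h))
        forces.append(f_i)
    return forces


positions = [(0.0, 0.0, 0.0), (1.5, 0.0, 0.0)]  # spring stretched by 0.5
forces = finite_difference_forces(harmonic_pair_energy, positions)
# Analytically F_0 = (+0.5, 0, 0) and F_1 = (-0.5, 0, 0): the spring pulls the atoms together.
print(forces)
```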
- class apax.model.gmnn.EnergyModel(
      atomistic_model: apax.model.gmnn.AtomisticModel = AtomisticModel(),
      corrections: list[apax.layers.empirical.EmpiricalEnergyTerm] = <factory>,
      init_box: numpy.array = <factory>,
      inference_disp_fn: typing.Any = None,
      parent: typing.Type[flax.linen.module.Module] | flax.core.scope.Scope | typing.Type[flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>,
      name: str | None = None)
Model that post-processes the output of an atomistic model and adds empirical energy terms (the corrections argument).
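A minimal sketch of how the pieces might combine, assuming the atomistic model yields per-atom energies that are summed into a total and that each empirical term contributes one structure-level energy. `toy_repulsion`, `total_energy` and the `Correction` callable type are hypothetical; the real EmpiricalEnergyTerm interface is not documented in this section.

```python
import math
from typing import Callable, List, Sequence

Position = Sequence[float]
# Hypothetical interface: a correction maps (positions, species) to an energy.
Correction = Callable[[Sequence[Position], Sequence[int]], float]


def toy_repulsion(positions, species, a=1.0, cutoff=2.0):
    """Hypothetical empirical term: a / r repulsion for every pair closer than cutoff.

    `species` is accepted to match the Correction signature but is unused here.
    """
    energy = 0.0
    for i in range(len(positions)):
        for j in range(i + 1, len(positions)):
            r = math.dist(positions[i], positions[j])
            if r < cutoff:
                energy += a / r
    return energy


def total_energy(per_atom_energies, positions, species, corrections: List[Correction]):
    """Sum the per-atom ML energies, then add each empirical correction."""
    e_ml = sum(per_atom_energies)
    e_corr = sum(term(positions, species) for term in corrections)
    return e_ml + e_corr


positions = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)]
species = [1, 1]
print(total_energy([1.5, 1.5], positions, species, corrections=[]))               # 3.0
print(total_energy([1.5, 1.5], positions, species, corrections=[toy_repulsion]))  # 4.0
```

With an empty corrections list (the default shown in the signature) the total reduces to the pure ML prediction.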
- setup()
Initializes a Module lazily (similar to a lazy __init__).
setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed. This can happen in three cases:
1. Immediately when invoking apply(), init() or init_and_output().
2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):

       >>> import flax.linen as nn
       >>> class MyModule(nn.Module):
       ...     def setup(self):
       ...         submodule = nn.Conv(...)
       ...         # Accessing `submodule` attributes does not yet work here.
       ...         # The following line invokes `self.__setattr__`, which gives
       ...         # `submodule` the name "conv1".
       ...         self.conv1 = submodule
       ...         # Accessing `submodule` attributes or methods is now safe and
       ...         # causes its setup() to be called once.

3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.
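The lazy-initialization behaviour described above can be mimicked in plain Python. The sketch below is only an analogy: `LazyModule`, `Scaler`, `forward` and the `setup_calls` counter are hypothetical, and flax's actual mechanism (binding, scopes, naming) is considerably more involved. It shows setup() running exactly once, either on the first call or on the first access to an attribute that setup() defines.

```python
class LazyModule:
    """Hypothetical plain-Python analogy of lazy setup (not flax's implementation)."""

    def __init__(self):
        self._is_setup = False
        self.setup_calls = 0  # counts how often setup() has run

    def setup(self):
        """Subclasses create their attributes here."""

    def _ensure_setup(self):
        # Run setup() at most once per instance.
        if not self._is_setup:
            self._is_setup = True
            self.setup_calls += 1
            self.setup()

    def __getattr__(self, name):
        # Python only calls __getattr__ when normal lookup fails, e.g. for an
        # attribute that setup() has not created yet.
        if name.startswith("_") or self.__dict__.get("_is_setup", False):
            raise AttributeError(name)
        self._ensure_setup()
        return getattr(self, name)

    def __call__(self, *args, **kwargs):
        # Calling the module also triggers setup() first.
        self._ensure_setup()
        return self.forward(*args, **kwargs)


class Scaler(LazyModule):
    def setup(self):
        self.factor = 3.0  # attribute defined lazily in setup()

    def forward(self, x):
        return self.factor * x


m = Scaler()
print(m.setup_calls)  # 0: nothing has run yet
print(m.factor)       # 3.0: first access to a setup-defined attribute runs setup()
print(m(2.0))         # 6.0: setup() is not run again
print(m.setup_calls)  # 1
```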