Model

class apax.model.gmnn.AtomisticModel(
    descriptor: ~flax.linen.module.Module = GaussianMomentDescriptor(
        # attributes
        radial_fn = RadialFunction(
            # attributes
            n_radial = 5
            basis_fn = GaussianBasis(
                # attributes
                n_basis = 7
                r_min = 0.5
                r_max = 6.0
                dtype = float32
            )
            n_species = 119
            emb_init = 'uniform'
            use_embed_norm = True
            one_sided_dist = False
            dtype = float32
        )
        n_contr = 8
        dtype = float32
        apply_mask = True
    ),
    readout: ~flax.linen.module.Module = AtomisticReadout(
        # attributes
        units = [512, 512]
        activation_fn = swish
        w_init = 'normal'
        b_init = 'zeros'
        use_ntk = True
        n_shallow_ensemble = 0
        is_feature_fn = False
        dtype = float32
    ),
    scale_shift: ~flax.linen.module.Module = PerElementScaleShift(
        # attributes
        n_species = 119
        scale = 1.0
        shift = 0.0
        dtype = float32
    ),
    mask_atoms: bool = True,
    parent: ~typing.Type[~flax.linen.module.Module] | ~flax.core.scope.Scope | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>,
    name: str | None = None
)

The most basic prediction model. It assembles the descriptor, the readout (neural networks) and the output scale-shifting.
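The descriptor → readout → scale-shift composition can be sketched in plain NumPy. This is a minimal, hypothetical illustration, not apax's implementation: toy_descriptor, toy_readout, scale_shift and the params layout are invented for the example, and the masking behaviour (padding atoms neither contribute energy nor act as neighbours) is an assumption about what mask_atoms and apply_mask control. Only the defaults n_basis = 7, r_min = 0.5, r_max = 6.0, n_species = 119 and the swish activation are taken from the signature.

```python
import numpy as np

# Toy stand-ins for the three stages. These are NOT apax's
# GaussianMomentDescriptor / AtomisticReadout / PerElementScaleShift;
# the basis defaults merely mirror the GaussianBasis defaults above.

def toy_descriptor(positions, mask, n_basis=7, r_min=0.5, r_max=6.0):
    """Per-atom radial features: Gaussians of neighbour distances, summed."""
    diff = positions[:, None, :] - positions[None, :, :]   # (n, n, 3)
    dist = np.linalg.norm(diff, axis=-1)                    # (n, n)
    np.fill_diagonal(dist, np.inf)                          # exclude self-pairs
    centers = np.linspace(r_min, r_max, n_basis)            # (n_basis,)
    gauss = np.exp(-((dist[..., None] - centers) ** 2))     # (n, n, n_basis)
    keep = (dist < r_max) & mask[None, :]                   # cutoff; drop padding neighbours
    return (gauss * keep[..., None]).sum(axis=1)            # (n, n_basis)

def swish(x):
    return x / (1.0 + np.exp(-x))

def toy_readout(features, w1, b1, w2, b2):
    """Two-layer MLP mapping each atom's features to one scalar."""
    hidden = swish(features @ w1 + b1)
    return (hidden @ w2 + b2)[:, 0]                         # (n,)

def scale_shift(atomic_out, numbers, scale, shift):
    """E_i = scale[Z_i] * out_i + shift[Z_i], indexed by atomic number."""
    return scale[numbers] * atomic_out + shift[numbers]

def atomistic_energies(positions, numbers, mask, params):
    feats = toy_descriptor(positions, mask)
    out = toy_readout(feats, *params["readout"])
    energies = scale_shift(out, numbers, params["scale"], params["shift"])
    return np.where(mask, energies, 0.0)                    # padding atoms contribute 0

rng = np.random.default_rng(0)
params = {
    "readout": (rng.normal(size=(7, 16)), np.zeros(16),
                rng.normal(size=(16, 1)), np.zeros(1)),
    "scale": np.ones(119),   # n_species = 119 as in the defaults above
    "shift": np.zeros(119),
}
positions = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0],
                      [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]])
numbers = np.array([8, 1, 1, 0])                 # last row is a padding atom
mask = np.array([True, True, True, False])
energies = atomistic_energies(positions, numbers, mask, params)
```

Swapping the two hydrogens leaves the summed energy unchanged, and moving the padding atom changes nothing for the real atoms.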

class apax.model.gmnn.EnergyDerivativeModel(
    energy_model: ~apax.model.gmnn.EnergyModel = EnergyModel(
        # attributes
        atomistic_model = AtomisticModel(
            # attributes
            descriptor = GaussianMomentDescriptor(
                # attributes
                radial_fn = RadialFunction(
                    # attributes
                    n_radial = 5
                    basis_fn = GaussianBasis(
                        # attributes
                        n_basis = 7
                        r_min = 0.5
                        r_max = 6.0
                        dtype = float32
                    )
                    n_species = 119
                    emb_init = 'uniform'
                    use_embed_norm = True
                    one_sided_dist = False
                    dtype = float32
                )
                n_contr = 8
                dtype = float32
                apply_mask = True
            )
            readout = AtomisticReadout(
                # attributes
                units = [512, 512]
                activation_fn = swish
                w_init = 'normal'
                b_init = 'zeros'
                use_ntk = True
                n_shallow_ensemble = 0
                is_feature_fn = False
                dtype = float32
            )
            scale_shift = PerElementScaleShift(
                # attributes
                n_species = 119
                scale = 1.0
                shift = 0.0
                dtype = float32
            )
            mask_atoms = True
        )
        corrections = []
        init_box = array([0., 0., 0.])
        inference_disp_fn = None
    ),
    calc_stress: bool = False,
    parent: ~typing.Type[~flax.linen.module.Module] | ~flax.core.scope.Scope | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>,
    name: str | None = None
)

Transforms an EnergyModel into one that also predicts derivatives of the total energy. It can calculate forces and, if calc_stress is enabled, stress tensors.
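Forces are the negative gradient of the total energy, F_i = -∂E/∂R_i. A JAX model would presumably obtain this gradient by automatic differentiation (for example jax.grad); the sketch below swaps in central finite differences on a hypothetical Lennard-Jones energy so that it runs with NumPy alone. lj_energy and forces_central_difference are illustrative names, not apax functions. The stress tensor enabled by calc_stress is the analogous derivative with respect to a homogeneous strain of the cell, divided by the cell volume, and is not sketched here.

```python
import numpy as np

def lj_energy(positions, sigma=1.0, eps=1.0):
    """Hypothetical stand-in for an EnergyModel: Lennard-Jones total energy."""
    diff = positions[:, None, :] - positions[None, :, :]
    r = np.linalg.norm(diff, axis=-1)
    iu = np.triu_indices(len(positions), k=1)       # count each pair once
    sr6 = (sigma / r[iu]) ** 6
    return float(np.sum(4.0 * eps * (sr6 ** 2 - sr6)))

def forces_central_difference(energy_fn, positions, h=1e-5):
    """F = -dE/dR, approximated component by component with central differences."""
    forces = np.zeros_like(positions)
    for idx in np.ndindex(positions.shape):
        plus, minus = positions.copy(), positions.copy()
        plus[idx] += h
        minus[idx] -= h
        forces[idx] = -(energy_fn(plus) - energy_fn(minus)) / (2.0 * h)
    return forces

# A dimer along x: the forces are equal and opposite (Newton's third law).
dimer = np.array([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]])
f = forces_central_difference(lj_energy, dimer)
```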

class apax.model.gmnn.EnergyModel(
    atomistic_model: ~apax.model.gmnn.AtomisticModel = AtomisticModel(
        # attributes
        descriptor = GaussianMomentDescriptor(
            # attributes
            radial_fn = RadialFunction(
                # attributes
                n_radial = 5
                basis_fn = GaussianBasis(
                    # attributes
                    n_basis = 7
                    r_min = 0.5
                    r_max = 6.0
                    dtype = float32
                )
                n_species = 119
                emb_init = 'uniform'
                use_embed_norm = True
                one_sided_dist = False
                dtype = float32
            )
            n_contr = 8
            dtype = float32
            apply_mask = True
        )
        readout = AtomisticReadout(
            # attributes
            units = [512, 512]
            activation_fn = swish
            w_init = 'normal'
            b_init = 'zeros'
            use_ntk = True
            n_shallow_ensemble = 0
            is_feature_fn = False
            dtype = float32
        )
        scale_shift = PerElementScaleShift(
            # attributes
            n_species = 119
            scale = 1.0
            shift = 0.0
            dtype = float32
        )
        mask_atoms = True
    ),
    corrections: list[~apax.layers.empirical.EmpiricalEnergyTerm] = <factory>,
    init_box: ~numpy.array = <factory>,
    inference_disp_fn: ~typing.Any = None,
    parent: ~typing.Type[~flax.linen.module.Module] | ~flax.core.scope.Scope | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>,
    name: str | None = None
)

Model that post-processes the output of an atomistic model and adds empirical energy terms (the corrections).
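In the simplest reading of this description, the total energy is the sum of the per-atom energies plus the value of every entry in corrections. The sketch below assumes exactly that; the EnergyTerm callable interface and toy_repulsion are hypothetical and do not reproduce apax.layers.empirical.EmpiricalEnergyTerm.

```python
from typing import Callable, Sequence

import numpy as np

# Hypothetical interface for an empirical term: (positions, numbers) -> energy.
# This is not the apax EmpiricalEnergyTerm API.
EnergyTerm = Callable[[np.ndarray, np.ndarray], float]

def total_energy(atomic_energies: np.ndarray, positions: np.ndarray,
                 numbers: np.ndarray, corrections: Sequence[EnergyTerm] = ()) -> float:
    """Sum the per-atom energies, then add every empirical correction."""
    energy = float(np.sum(atomic_energies))
    for term in corrections:
        energy += term(positions, numbers)
    return energy

def toy_repulsion(positions: np.ndarray, numbers: np.ndarray,
                  strength: float = 0.1) -> float:
    """Hypothetical short-range pair repulsion: strength / r**12 per pair."""
    diff = positions[:, None, :] - positions[None, :, :]
    r = np.linalg.norm(diff, axis=-1)
    iu = np.triu_indices(len(positions), k=1)
    return float(np.sum(strength / r[iu] ** 12))

positions = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
numbers = np.array([8, 1, 1])
atomic = np.array([-1.0, -0.5, -0.5])
e_plain = total_energy(atomic, positions, numbers)                        # -2.0
e_corrected = total_energy(atomic, positions, numbers, [toy_repulsion])  # -2.0 + 0.1 * (1 + 1 + 1/64)
```

With an empty corrections list, as in the default, the result reduces to the plain sum of atomic energies.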

setup()

Initializes a Module lazily (similar to a lazy __init__).

setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.

This can happen in three cases:

  1. Immediately when invoking apply(), init() or init_and_output().

  2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):

    >>> import flax.linen as nn
    >>> class MyModule(nn.Module):
    ...   def setup(self):
    ...     submodule = nn.Conv(features=8, kernel_size=(3,))
    ...
    ...     # Accessing `submodule` attributes does not yet work here.
    ...
    ...     # The following line invokes `self.__setattr__`, which gives
    ...     # `submodule` the name "conv1".
    ...     self.conv1 = submodule
    ...
    ...     # Accessing `submodule` attributes or methods is now safe and
    ...     # causes its setup() to be called (at most once).

  3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.

class apax.model.gmnn.FeatureModel(
    descriptor: ~flax.linen.module.Module = GaussianMomentDescriptor(
        # attributes
        radial_fn = RadialFunction(
            # attributes
            n_radial = 5
            basis_fn = GaussianBasis(
                # attributes
                n_basis = 7
                r_min = 0.5
                r_max = 6.0
                dtype = float32
            )
            n_species = 119
            emb_init = 'uniform'
            use_embed_norm = True
            one_sided_dist = False
            dtype = float32
        )
        n_contr = 8
        dtype = float32
        apply_mask = True
    ),
    readout: ~flax.linen.module.Module = AtomisticReadout(
        # attributes
        units = [512, 512]
        activation_fn = swish
        w_init = 'normal'
        b_init = 'zeros'
        use_ntk = True
        n_shallow_ensemble = 0
        is_feature_fn = False
        dtype = float32
    ),
    should_average: bool = False,
    init_box: ~numpy.array = <factory>,
    inference_disp_fn: ~typing.Any = None,
    mask_atoms: bool = True,
    parent: ~typing.Type[~flax.linen.module.Module] | ~flax.core.scope.Scope | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>,
    name: str | None = None
)

Model that wraps a submodel (e.g. a descriptor) and supplies it with the distance computation.
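A distance computation of this kind typically builds pairwise displacement vectors and, for periodic systems, applies the minimum-image convention. The sketch below assumes that an all-zero box (like the init_box default array([0., 0., 0.]) in the EnergyModel repr above) means a non-periodic system, that periodic cells are orthorhombic, and that should_average averages the wrapped model's output over atoms; all three are assumptions, and displacements and toy_features are hypothetical helpers. The minimum-image convention is only valid when the cutoff is smaller than half the shortest box edge, which is why the example pairs a box edge of 20 with a cutoff of 6.

```python
import numpy as np

def displacements(positions, box):
    """Pairwise displacement vectors R_j - R_i, shape (n, n, 3).

    Assumption: an all-zero box means "not periodic"; otherwise `box` holds the
    edge lengths of an orthorhombic cell and the minimum-image convention is applied.
    """
    disp = positions[None, :, :] - positions[:, None, :]
    if np.all(box > 0):
        disp = disp - box * np.round(disp / box)
    return disp

def toy_features(positions, box, should_average=False, r_cut=6.0):
    """Hypothetical per-atom feature (neighbour count), optionally averaged over atoms."""
    dist = np.linalg.norm(displacements(positions, box), axis=-1)
    np.fill_diagonal(dist, np.inf)                                        # no self-pairs
    per_atom = np.sum(dist < r_cut, axis=1, keepdims=True).astype(float)  # (n, 1)
    return per_atom.mean(axis=0) if should_average else per_atom

positions = np.array([[0.5, 0.0, 0.0], [19.5, 0.0, 0.0]])
no_pbc = np.zeros(3)                   # all-zero box: treated as non-periodic
cubic = np.array([20.0, 20.0, 20.0])   # periodic: the two atoms are 1.0 apart
```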

setup()

Initializes a Module lazily (similar to a lazy __init__). This is the standard Flax setup() hook; it is invoked under the same three conditions listed for EnergyModel.setup() above.

class apax.model.gmnn.ShallowEnsembleModel(
    energy_model: ~apax.model.gmnn.EnergyModel = EnergyModel(
        # attributes
        atomistic_model = AtomisticModel(
            # attributes
            descriptor = GaussianMomentDescriptor(
                # attributes
                radial_fn = RadialFunction(
                    # attributes
                    n_radial = 5
                    basis_fn = GaussianBasis(
                        # attributes
                        n_basis = 7
                        r_min = 0.5
                        r_max = 6.0
                        dtype = float32
                    )
                    n_species = 119
                    emb_init = 'uniform'
                    use_embed_norm = True
                    one_sided_dist = False
                    dtype = float32
                )
                n_contr = 8
                dtype = float32
                apply_mask = True
            )
            readout = AtomisticReadout(
                # attributes
                units = [512, 512]
                activation_fn = swish
                w_init = 'normal'
                b_init = 'zeros'
                use_ntk = True
                n_shallow_ensemble = 0
                is_feature_fn = False
                dtype = float32
            )
            scale_shift = PerElementScaleShift(
                # attributes
                n_species = 119
                scale = 1.0
                shift = 0.0
                dtype = float32
            )
            mask_atoms = True
        )
        corrections = []
        init_box = array([0., 0., 0.])
        inference_disp_fn = None
    ),
    calc_stress: bool = False,
    force_variance: bool = True,
    chunk_size: int | None = None,
    parent: ~typing.Type[~flax.linen.module.Module] | ~flax.core.scope.Scope | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>,
    name: str | None = None
)

Transforms an EnergyModel into one that also predicts derivatives of the total energy. It can calculate forces and stress tensors.
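A shallow ensemble is commonly implemented as several inexpensive output heads that share one learned representation, so that the spread of their predictions serves as an uncertainty estimate; the n_shallow_ensemble attribute of AtomisticReadout suggests this reading, but it is an assumption here. The sketch computes per-member energies, per-member forces (again by central finite differences in place of autodiff), and the ensemble mean and variance of both. member_energies and ensemble_prediction are hypothetical. The force_variance flag presumably controls whether the per-member force spread is computed, and chunk_size presumably limits how many members are differentiated at once to bound memory; neither is confirmed by this page.

```python
import numpy as np

def member_energies(positions, weights):
    """Hypothetical shallow ensemble: shared features, one linear head per member.

    weights has shape (n_members, 2); the result has shape (n_members,).
    """
    diff = positions[:, None, :] - positions[None, :, :]
    r = np.linalg.norm(diff, axis=-1)[np.triu_indices(len(positions), k=1)]
    shared = np.array([np.exp(-r).sum(), np.exp(-2.0 * r).sum()])  # shared by all members
    return weights @ shared

def ensemble_prediction(positions, weights, h=1e-5):
    """Ensemble mean and population variance (ddof=0, an arbitrary choice here)."""
    energies = member_energies(positions, weights)
    forces = np.zeros((len(weights),) + positions.shape)             # (n_members, n, 3)
    for idx in np.ndindex(positions.shape):
        plus, minus = positions.copy(), positions.copy()
        plus[idx] += h
        minus[idx] -= h
        # Central difference of every member's energy at once: F_k = -dE_k/dR.
        forces[(slice(None),) + idx] = -(member_energies(plus, weights)
                                         - member_energies(minus, weights)) / (2.0 * h)
    return {
        "energy": energies.mean(), "energy_var": energies.var(),
        "forces": forces.mean(axis=0), "forces_var": forces.var(axis=0),
    }

positions = np.array([[0.0, 0.0, 0.0], [1.2, 0.0, 0.0], [0.0, 0.9, 0.0]])
weights = np.array([[1.0, 0.5], [0.8, 0.7], [1.1, 0.3]])   # three hypothetical heads
pred = ensemble_prediction(positions, weights)
```

Identical heads give zero variance; because each head here is linear, the ensemble-mean forces equal the forces of a single head with averaged weights.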