Source code for apax.model.gmnn

import logging
from dataclasses import field
from typing import Any, Callable, Optional, Tuple, Union

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax import Array

from apax.layers.descriptor.gaussian_moment_descriptor import GaussianMomentDescriptor
from apax.layers.distances import make_distance_fn
from apax.layers.empirical import EmpiricalEnergyTerm
from apax.layers.masking import mask_by_atom
from apax.layers.properties import stress_times_vol
from apax.layers.readout import AtomisticReadout
from apax.layers.scaling import PerElementScaleShift
from apax.utils.jax_md_reduced import partition
from apax.utils.math import fp64_sum

DisplacementFn = Callable[[Array, Array], Array]
MDModel = Tuple[partition.NeighborFn, Callable, Callable]

log = logging.getLogger(__name__)


[docs] class AtomisticModel(nn.Module): """Most basic prediction model. Allesmbles descriptor, readout (NNs) and output scale-shifting. """ descriptor: nn.Module = GaussianMomentDescriptor() readout: nn.Module = AtomisticReadout() scale_shift: nn.Module = PerElementScaleShift() mask_atoms: bool = True def __call__( self, dr_vec: Array, Z: Array, idx: Array, ) -> Array: gm = self.descriptor(dr_vec, Z, idx) h = jax.vmap(self.readout)(gm) output = self.scale_shift(h, Z) if self.mask_atoms: output = mask_by_atom(output, Z) return output
[docs] class FeatureModel(nn.Module): """Model wrapps some submodel (e.g. a descriptor) to supply distance computation.""" descriptor: nn.Module = GaussianMomentDescriptor() readout: nn.Module = AtomisticReadout() should_average: bool = False init_box: np.array = field(default_factory=lambda: np.array([0.0, 0.0, 0.0])) inference_disp_fn: Any = None mask_atoms: bool = True
[docs] def setup(self): self.compute_distances = make_distance_fn(self.init_box, self.inference_disp_fn)
def __call__( self, R: Array, Z: Array, neighbor: Union[partition.NeighborList, Array], box, offsets, perturbation=None, ): dr_vec, idx = self.compute_distances( R, neighbor, box, offsets, perturbation, ) gm = self.descriptor(dr_vec, Z, idx) features = jax.vmap(self.readout)(gm) if self.mask_atoms: features = mask_by_atom(features, Z) if self.should_average: features = jnp.mean(features, axis=0) return features
[docs] class EnergyModel(nn.Module): """Model which post processes the output of an atomistic model and adds empirical energy terms. """ atomistic_model: AtomisticModel = AtomisticModel() corrections: list[EmpiricalEnergyTerm] = field(default_factory=lambda: []) init_box: np.array = field(default_factory=lambda: np.array([0.0, 0.0, 0.0])) inference_disp_fn: Any = None
[docs] def setup(self): self.compute_distances = make_distance_fn(self.init_box, self.inference_disp_fn)
def __call__( self, R: Array, Z: Array, neighbor: Union[partition.NeighborList, Array], box, offsets, perturbation=None, ): dr_vec, idx = self.compute_distances( R, neighbor, box, offsets, perturbation, ) # Model Core # shape Natoms # shape shallow ens: Natoms x Nensemble atomic_energies = self.atomistic_model(dr_vec, Z, idx) # check for shallow ensemble is_shallow_ensemble = atomic_energies.shape[1] > 1 if is_shallow_ensemble: total_energies_ensemble = fp64_sum(atomic_energies, axis=0) # shape Nensemble result = total_energies_ensemble else: # shape () result = fp64_sum(atomic_energies) # Corrections for correction in self.corrections: energy_correction = correction(dr_vec, Z, idx) result = result + energy_correction # TODO think of nice abstraction for predicting additional properties return result
[docs] class EnergyDerivativeModel(nn.Module): """Transforms an EnergyModel into one that also predicts derivatives the total energy. Can calculate forces and stress tensors. """ # Alternatively, should this be a function transformation? energy_model: EnergyModel = EnergyModel() calc_stress: bool = False def __call__( self, R: Array, Z: Array, neighbor: Union[partition.NeighborList, Array], box, offsets, ): energy, neg_forces = jax.value_and_grad(self.energy_model)( R, Z, neighbor, box, offsets ) forces = -neg_forces prediction = {"energy": energy, "forces": forces} if self.calc_stress: stress = stress_times_vol( self.energy_model, R, box, Z=Z, neighbor=neighbor, offsets=offsets ) prediction["stress"] = stress return prediction
def make_mean_energy_fn(energy_fn): def mean_energy_fn( R: Array, Z: Array, neighbor: Union[partition.NeighborList, Array], box, offsets, perturbation=None, ): e_ens = energy_fn(R, Z, neighbor, box, offsets, perturbation) E_mean = jnp.mean(e_ens) return E_mean return mean_energy_fn def make_single_member_gradient(energy_model, idx): def energy_i_fn(R, Z, neighbor, box, offsets): Ei = energy_model(R, Z, neighbor, box, offsets)[idx] return Ei grad_i_fn = jax.grad(energy_i_fn) return grad_i_fn def make_member_chunk_jac(energy_model, start, end): def energy_chunk_fn(R, Z, neighbor, box, offsets): Ei = energy_model(R, Z, neighbor, box, offsets)[start:end] return Ei grad_i_fn = jax.jacrev(energy_chunk_fn) return grad_i_fn
[docs] class ShallowEnsembleModel(nn.Module): """Transforms an EnergyModel into one that also predicts derivatives the total energy. Can calculate forces and stress tensors. """ energy_model: EnergyModel = EnergyModel() calc_stress: bool = False force_variance: bool = True chunk_size: Optional[int] = None def __call__( self, R: Array, Z: Array, neighbor: Union[partition.NeighborList, Array], box, offsets, ): energy_ens = self.energy_model(R, Z, neighbor, box, offsets) mean_energy_fn = make_mean_energy_fn(self.energy_model) n_ens = energy_ens.shape[0] divisor = 1 / (n_ens - 1) energy_mean = jnp.mean(energy_ens) energy_variance = divisor * fp64_sum((energy_ens - energy_mean) ** 2) prediction = { "energy": energy_mean, "energy_ensemble": energy_ens, "energy_uncertainty": jnp.sqrt(energy_variance), } if self.force_variance: if not self.chunk_size: forces_ens = -jax.jacrev(self.energy_model)(R, Z, neighbor, box, offsets) else: with jax.ensure_compile_time_eval(): if not n_ens % self.chunk_size == 0: m = "the chunksize needs to be a factor of the number of ensemble memebrs" raise ValueError(m) forces_ens = [] start = 0 for _ in range(n_ens // self.chunk_size): end = start + self.chunk_size jac_i_fn = make_member_chunk_jac(self.energy_model, start, end) force_i = -jac_i_fn(R, Z, neighbor, box, offsets) forces_ens.append(force_i) start = end n_atoms = R.shape[0] forces_ens = jnp.array(forces_ens) forces_ens = np.reshape(forces_ens, (n_ens, n_atoms, 3)) forces_mean = jnp.mean(forces_ens, axis=0) forces_variance = divisor * fp64_sum((forces_ens - forces_mean) ** 2, axis=0) prediction["forces"] = forces_mean prediction["forces_uncertainty"] = jnp.sqrt(forces_variance) prediction["forces_ensemble"] = forces_ens else: forces_mean = -jax.grad(mean_energy_fn)(R, Z, neighbor, box, offsets) prediction["forces"] = forces_mean if self.calc_stress: stress = stress_times_vol( mean_energy_fn, R, box, Z=Z, neighbor=neighbor, offsets=offsets ) prediction["stress"] = stress return prediction