import logging
from dataclasses import field
from typing import Any, Callable, Optional, Tuple, Union
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax import Array
from apax.layers.descriptor.gaussian_moment_descriptor import GaussianMomentDescriptor
from apax.layers.distances import make_distance_fn
from apax.layers.empirical import EmpiricalEnergyTerm
from apax.layers.masking import mask_by_atom
from apax.layers.properties import stress_times_vol
from apax.layers.readout import AtomisticReadout
from apax.layers.scaling import PerElementScaleShift
from apax.utils.jax_md_reduced import partition
from apax.utils.math import fp64_sum
from apax.utils.transform import make_energy_only_model
# Type alias: a callable that takes two arrays and returns an array.
DisplacementFn = Callable[[Array, Array], Array]
# Type alias: a triple of a neighbor-list function and two further callables.
MDModel = Tuple[partition.NeighborFn, Callable, Callable]
# Module-level logger named after this module.
log = logging.getLogger(__name__)
# (stray "[docs]" link text from the rendered documentation page; not part of the module)
class FeatureModel(nn.Module):
    """Wrap a submodel (e.g. a descriptor) and supply its distance computation.

    Calling the model computes displacement vectors and neighbor indices,
    evaluates the representation on them and then optionally applies the
    readout, the atom mask and an average over axis 0.

    Attributes:
        representation: Module called as ``representation(dr_vec, Z, idx)``.
        readout: Module mapped over axis 0 of the features with ``jax.vmap``.
            The readout step is skipped when this attribute is falsy.
        should_average: If true, the features are averaged over axis 0.
        init_box: Box handed to ``make_distance_fn`` in ``setup``.
        inference_disp_fn: Displacement function handed to
            ``make_distance_fn`` in ``setup``.
        mask_atoms: If true, ``mask_by_atom(features, Z)`` is applied.
    """

    representation: nn.Module = GaussianMomentDescriptor()
    readout: nn.Module = AtomisticReadout()
    should_average: bool = False
    # ``np.ndarray`` is the array type; ``np.array`` is only the constructor.
    init_box: np.ndarray = field(default_factory=lambda: np.array([0.0, 0.0, 0.0]))
    inference_disp_fn: Any = None
    mask_atoms: bool = True

    def setup(self):
        """Create the distance function from ``init_box`` and ``inference_disp_fn``."""
        self.compute_distances = make_distance_fn(self.init_box, self.inference_disp_fn)

    def __call__(
        self,
        R: Array,
        Z: Array,
        neighbor: Union[partition.NeighborList, Array],
        box,
        offsets,
        perturbation=None,
    ):
        """Return the features for one structure.

        Args:
            R: Atomic positions.
            Z: Per-atom values used by the representation and the atom mask
                (presumably atomic numbers -- confirm against callers).
            neighbor: Neighbor list object or index array.
            box: Simulation box forwarded to the distance function.
            offsets: Offsets forwarded to the distance function.
            perturbation: Optional perturbation forwarded to the distance
                function.

        Returns:
            The (optionally read-out, masked and averaged) features.
        """
        dr_vec, idx = self.compute_distances(
            R,
            neighbor,
            box,
            offsets,
            perturbation,
        )
        features = self.representation(dr_vec, Z, idx)

        if self.readout:
            # The readout acts on a single atom, so map it over the atom axis.
            features = jax.vmap(self.readout)(features)

        if self.mask_atoms:
            features = mask_by_atom(features, Z)

        if self.should_average:
            features = jnp.mean(features, axis=0)

        return features
# (stray "[docs]" link text from the rendered documentation page; not part of the module)
class EnergyModel(nn.Module):
    """Post-process the output of an atomistic model into a total energy.

    Per-atom outputs of the readout are scaled/shifted, optionally masked and
    summed over atoms. Additional property heads are evaluated and empirical
    energy corrections are added to the summed energy.

    Attributes:
        representation: Module called as ``representation(dr_vec, Z, idx)``.
        readout: Module mapped over axis 0 of the representation output.
        scale_shift: Module called as ``scale_shift(h, Z)`` on the readout
            output.
        property_heads: Modules called as
            ``head(g, R, dr_vec, Z, idx, box)``; each returns a dict that is
            merged into the returned properties.
        corrections: Energy terms called as
            ``correction(R, dr_vec, Z, idx, box, properties)``; each result is
            added to the energy.
        init_box: Box handed to ``make_distance_fn`` in ``setup``.
        mask_atoms: If true, ``mask_by_atom(E_i, Z)`` is applied before the
            sum over atoms.
        inference_disp_fn: Displacement function handed to
            ``make_distance_fn`` in ``setup``.
    """

    representation: nn.Module = GaussianMomentDescriptor()
    readout: nn.Module = AtomisticReadout()
    scale_shift: nn.Module = PerElementScaleShift()
    property_heads: list[nn.Module] = field(default_factory=list)
    corrections: list[EmpiricalEnergyTerm] = field(default_factory=list)
    # ``np.ndarray`` is the array type; ``np.array`` is only the constructor.
    init_box: np.ndarray = field(default_factory=lambda: np.array([0.0, 0.0, 0.0]))
    mask_atoms: bool = True
    inference_disp_fn: Any = None

    def setup(self):
        """Create the distance function from ``init_box`` and ``inference_disp_fn``."""
        self.compute_distances = make_distance_fn(self.init_box, self.inference_disp_fn)

    def __call__(
        self,
        R: Array,
        Z: Array,
        neighbor: Union[partition.NeighborList, Array],
        box,
        offsets,
        perturbation=None,
    ):
        """Return the total energy and the additional properties.

        Args:
            R: Atomic positions.
            Z: Per-atom values used by the representation, the scale/shift
                layer and the atom mask (presumably atomic numbers -- confirm
                against callers).
            neighbor: Neighbor list object or index array.
            box: Simulation box.
            offsets: Offsets forwarded to the distance function.
            perturbation: Optional perturbation forwarded to the distance
                function.

        Returns:
            A tuple ``(energy, properties)``. ``energy`` has one entry per
            ensemble member for shallow ensembles and is the sum over all
            entries of ``E_i`` otherwise; ``properties`` is the merged dict of
            all property-head results.
        """
        dr_vec, idx = self.compute_distances(
            R,
            neighbor,
            box,
            offsets,
            perturbation,
        )

        # Model core
        g = self.representation(dr_vec, Z, idx)
        h = jax.vmap(self.readout)(g)
        # Two-dimensional: Natoms x 1, or Natoms x Nensemble for shallow ensembles.
        E_i = self.scale_shift(h, Z)

        if self.mask_atoms:
            E_i = mask_by_atom(E_i, Z)

        # More than one output column per atom marks a shallow ensemble.
        is_shallow_ensemble = E_i.shape[1] > 1
        if is_shallow_ensemble:
            # Sum over atoms only, keeping one energy per member (shape Nensemble).
            energy = fp64_sum(E_i, axis=0)
        else:
            # Sum over every entry so the single-model energy has shape ();
            # a sum over axis 0 alone would leave a trailing length-1 axis.
            energy = fp64_sum(E_i)

        properties = {}
        for property_head in self.property_heads:
            properties.update(property_head(g, R, dr_vec, Z, idx, box))

        # Empirical corrections are added on top of the model energy.
        for correction in self.corrections:
            energy = energy + correction(R, dr_vec, Z, idx, box, properties)

        return energy, properties
# (stray "[docs]" link text from the rendered documentation page; not part of the module)
class EnergyDerivativeModel(nn.Module):
    """Extend an ``EnergyModel`` with derivatives of the total energy.

    Forces are the negative gradient of the energy with respect to the first
    argument (the positions ``R``); the stress entry is added on request.

    Attributes:
        energy_model: Model returning ``(energy, properties)``.
        calc_stress: If true, ``stress_times_vol`` is evaluated and stored
            under the ``"stress"`` key.
    """

    # Open design question: could this be a function transformation instead?
    energy_model: EnergyModel = EnergyModel()
    calc_stress: bool = False

    def __call__(
        self,
        R: Array,
        Z: Array,
        neighbor: Union[partition.NeighborList, Array],
        box,
        offsets,
    ):
        """Return a dict with ``"energy"``, ``"forces"``, the model's extra
        properties and, if enabled, ``"stress"``.
        """
        # has_aux=True: only the energy is differentiated, the properties are
        # passed through untouched.
        energy_and_grad = jax.value_and_grad(self.energy_model, has_aux=True)
        (energy, properties), dE_dR = energy_and_grad(R, Z, neighbor, box, offsets)

        # Keys coming from ``properties`` take precedence, as with dict.update.
        prediction = {"energy": energy, "forces": -dE_dR, **properties}

        if not self.calc_stress:
            return prediction

        energy_only_fn = make_energy_only_model(self.energy_model)
        prediction["stress"] = stress_times_vol(
            energy_only_fn,
            R,
            box,
            Z=Z,
            neighbor=neighbor,
            offsets=offsets,
        )
        return prediction
def make_mean_energy_fn(energy_fn):
    """Wrap ``energy_fn`` so that it returns only the mean of its energies.

    Args:
        energy_fn: Callable invoked as
            ``energy_fn(R, Z, neighbor, box, offsets, perturbation)`` and
            returning a pair whose first element holds the energies.

    Returns:
        A function with the same arguments that returns ``jnp.mean`` of the
        energies and discards the second element of the pair.
    """

    def mean_energy_fn(
        R: Array,
        Z: Array,
        neighbor: Union[partition.NeighborList, Array],
        box,
        offsets,
        perturbation=None,
    ):
        ensemble_energies, _dropped = energy_fn(
            R, Z, neighbor, box, offsets, perturbation
        )
        return jnp.mean(ensemble_energies)

    return mean_energy_fn
def make_member_chunk_jac(energy_model, start, end):
    """Build the Jacobian function for a slice of the model output.

    Args:
        energy_model: Callable invoked as
            ``energy_model(R, Z, neighbor, box, offsets)`` whose result can be
            sliced.
        start: First index of the slice.
        end: Index one past the last element of the slice.

    Returns:
        ``jax.jacrev`` of the function that evaluates ``energy_model`` and
        keeps only ``[start:end]`` of its result; differentiation is with
        respect to the first argument.
    """
    member_slice = slice(start, end)

    def energy_chunk_fn(R, Z, neighbor, box, offsets):
        return energy_model(R, Z, neighbor, box, offsets)[member_slice]

    return jax.jacrev(energy_chunk_fn)
# (stray "[docs]" link text from the rendered documentation page; not part of the module)
class ShallowEnsembleModel(nn.Module):
    """Turn a shallow-ensemble ``EnergyModel`` into a predictor with uncertainties.

    The wrapped model returns one energy per ensemble member. This module
    reports the ensemble mean and the sample standard deviation of the
    energy and, optionally, the per-member forces with their mean and
    standard deviation as well as the stress of the mean energy.

    Note:
        The sample variance uses the factor ``1 / (n_ens - 1)``, so the
        wrapped model has to provide more than one ensemble member.

    Attributes:
        energy_model: Model returning ``(energy_ensemble, properties)``.
        calc_stress: If true, ``stress_times_vol`` of the mean energy is
            stored under the ``"stress"`` key.
        force_variance: If true, forces are computed for every ensemble
            member; otherwise only the gradient of the mean energy is taken.
        chunk_size: Number of ensemble members differentiated per Jacobian
            call. A falsy value (``None`` or 0) differentiates all members at
            once. Must divide the number of ensemble members.
    """

    energy_model: EnergyModel = EnergyModel()
    calc_stress: bool = False
    force_variance: bool = True
    chunk_size: Optional[int] = None

    def __call__(
        self,
        R: Array,
        Z: Array,
        neighbor: Union[partition.NeighborList, Array],
        box,
        offsets,
    ):
        """Return the prediction dict for one structure.

        The dict always holds ``"energy"``, ``"energy_ensemble"``,
        ``"energy_uncertainty"``, the model's extra properties and
        ``"forces"``. With ``force_variance`` it additionally holds
        ``"forces_uncertainty"`` and ``"forces_ensemble"`` (n_atoms x 3 x
        n_members); with ``calc_stress`` it also holds ``"stress"``.

        Raises:
            ValueError: If ``chunk_size`` is set and does not divide the
                number of ensemble members.
        """
        energy_ens, properties = self.energy_model(R, Z, neighbor, box, offsets)

        # The two functions below drop the calculation of properties
        mean_energy_fn = make_mean_energy_fn(self.energy_model)
        energy_fn = make_energy_only_model(self.energy_model)

        n_ens = energy_ens.shape[0]
        # Factor of the unbiased sample variance.
        divisor = 1 / (n_ens - 1)

        energy_mean = jnp.mean(energy_ens)
        energy_variance = divisor * fp64_sum((energy_ens - energy_mean) ** 2)

        prediction = {
            "energy": energy_mean,
            "energy_ensemble": energy_ens,
            "energy_uncertainty": jnp.sqrt(energy_variance),
        }
        prediction.update(properties)

        if self.force_variance:
            if not self.chunk_size:
                # One Jacobian over all ensemble members at once.
                forces_ens = -jax.jacrev(energy_fn)(R, Z, neighbor, box, offsets)
            else:
                # Shapes are static, so this check runs while tracing.
                with jax.ensure_compile_time_eval():
                    if n_ens % self.chunk_size != 0:
                        m = "the chunksize needs to be a factor of the number of ensemble members"
                        raise ValueError(m)

                # Differentiate the members chunk by chunk.
                force_chunks = []
                for start in range(0, n_ens, self.chunk_size):
                    end = start + self.chunk_size
                    jac_i_fn = make_member_chunk_jac(energy_fn, start, end)
                    force_chunks.append(-jac_i_fn(R, Z, neighbor, box, offsets))

                n_atoms = R.shape[0]
                # Stack to (n_chunks, chunk_size, n_atoms, 3) and merge the two
                # leading axes; jnp keeps the operation on JAX arrays.
                forces_ens = jnp.reshape(jnp.array(force_chunks), (n_ens, n_atoms, 3))

            forces_mean = jnp.mean(forces_ens, axis=0)
            forces_variance = divisor * fp64_sum(
                (forces_ens - forces_mean) ** 2, axis=0
            )

            prediction["forces"] = forces_mean
            prediction["forces_uncertainty"] = jnp.sqrt(forces_variance)
            # n_members x n_atoms x 3  ->  n_atoms x 3 x n_members
            prediction["forces_ensemble"] = jnp.transpose(forces_ens, (1, 2, 0))
        else:
            # Only the mean forces are needed: differentiate the mean energy.
            forces_mean = -jax.grad(mean_energy_fn)(R, Z, neighbor, box, offsets)
            prediction["forces"] = forces_mean

        if self.calc_stress:
            stress = stress_times_vol(
                mean_energy_fn, R, box, Z=Z, neighbor=neighbor, offsets=offsets
            )
            prediction["stress"] = stress

        return prediction