# Source code for apax.layers.properties

import flax.linen as nn
import jax
import jax.numpy as jnp
from jax import Array

from apax.layers.masking import mask_by_atom
from apax.layers.readout import AtomisticReadout
from apax.utils.math import fp64_sum


def stress_times_vol(energy_fn, position: Array, box, **kwargs) -> Array:
    """Compute the strain derivative of the energy ("stress times volume").

    The box is deformed by ``perturbation = identity + eps`` and the gradient
    of the energy with respect to ``eps`` is evaluated at ``eps = 0``. The
    result is not divided by the cell volume, hence the name. Intended for
    training purposes.

    Parameters
    ----------
    energy_fn:
        Function computing the energy of the system. It is called as
        ``energy_fn(position, box=box, perturbation=..., **kwargs)`` and must
        therefore accept a ``perturbation`` keyword that deforms the box
        shape. Energy functions built with `smap` or in `energy.py` with a
        standard space satisfy this property.
    position:
        Array of particle positions; ``position.shape[1]`` is taken as the
        spatial dimension ``dim``.
    box:
        Simulation box. It is forwarded unchanged to ``energy_fn``; this
        function does not compute the box volume itself.
    **kwargs:
        Extra keyword arguments forwarded to ``energy_fn``.

    Returns
    -------
    Array
        A ``(dim, dim)`` array holding ``dU/d(eps)`` evaluated at ``eps = 0``.
    """
    dim = position.shape[1]
    # The gradient is taken at the undeformed box (eps = 0). A zero matrix is
    # already symmetric, so no explicit symmetrization is required.
    zero = jnp.zeros((dim, dim), position.dtype)
    identity = jnp.eye(dim, dtype=position.dtype)

    def strained_energy(eps):
        return energy_fn(position, box=box, perturbation=(identity + eps), **kwargs)

    return jax.grad(strained_energy)(zero)


def _unit_vectors_from_centroid(R):
    """Return unit vectors from the (unweighted) centroid of ``R`` to each row.

    Rows sitting exactly at the centroid get a zero vector. The squared norm is
    replaced by 1.0 for those rows *before* the square root and the division,
    so that gradients with respect to ``R`` stay finite; a single
    ``jnp.where`` around ``Rc / norm`` would still back-propagate NaNs.

    Returns
    -------
    r_hat:
        Array of the same shape as ``R`` holding the unit vectors.
    has_direction:
        Boolean array of shape ``(R.shape[0], 1)``; False for rows at the
        centroid.
    """
    Rc = R - jnp.mean(R, axis=0, keepdims=True)
    sq_norm = jnp.sum(Rc**2, axis=1, keepdims=True)
    has_direction = sq_norm > 0
    safe_norm = jnp.sqrt(jnp.where(has_direction, sq_norm, 1.0))
    r_hat = jnp.where(has_direction, Rc / safe_norm, 0.0)
    return r_hat, has_direction


def _apply_mode(p_i, R, mode):
    """Turn per-atom scalars ``p_i`` into the tensor rank requested by ``mode``.

    Raises
    ------
    KeyError
        If ``mode`` is not one of the supported symmetry options.
    """
    if mode == "l0":
        return p_i
    if mode not in ("l1", "symmetric_l2", "symmetric_traceless_l2"):
        raise KeyError("unknown symmetry option")

    r_hat, has_direction = _unit_vectors_from_centroid(R)
    if mode == "l1":
        return p_i * r_hat

    r_rt = jnp.einsum("ni, nj -> nij", r_hat, r_hat)
    if mode == "symmetric_l2":
        return p_i[..., None] * r_rt

    # symmetric_traceless_l2: 3 r r^T - I is traceless for unit r (3D assumed
    # by the hard-coded 3). Atoms without a direction (r = 0) get a zero
    # tensor instead of -I, keeping the result traceless and consistent with
    # the other modes.
    eye = jnp.eye(3)
    symmetrized = 3 * r_rt - eye * has_direction[..., None]
    return p_i[..., None] * symmetrized


def _aggregate(p_i, natoms, aggregation):
    """Reduce per-atom contributions over axis 0 according to ``aggregation``.

    Raises
    ------
    KeyError
        If ``aggregation`` is not "none", "sum" or "mean".
    """
    if aggregation == "none":
        return p_i
    if aggregation == "sum":
        return fp64_sum(p_i, axis=0)
    if aggregation == "mean":
        return fp64_sum(p_i, axis=0) / natoms
    raise KeyError("unknown aggregation")


class PropertyHead(nn.Module):
    """Per-atom property readout with optional directional mode and aggregation.

    The readout is currently limited to a single number per atom (per
    ensemble member). That number is scaled and shifted by learnable
    per-element parameters. It is optionally combined with the atom's
    direction from the centroid of ``R`` (``mode``), optionally masked, and
    aggregated over atoms (``aggregation``).

    Attributes
    ----------
    pname:
        Key under which the prediction is stored in the output dict.
    readout:
        Module applied to each atom's descriptor via ``jax.vmap``. If it
        returns more than one column, the columns are treated as ensemble
        members.
    aggregation:
        "none" (per-atom output), "sum" or "mean" over atoms.
    mode:
        "l0" (scalar), "l1" (scalar times unit vector from the centroid),
        "symmetric_l2" (scalar times r r^T) or "symmetric_traceless_l2"
        (scalar times 3 r r^T - I).
    apply_mask:
        Whether to pass per-atom contributions through ``mask_by_atom``.
    """

    pname: str
    readout: nn.Module = AtomisticReadout()
    aggregation: str = "none"
    mode: str = "l0"
    apply_mask: bool = True

    def setup(self):
        # One row per species, indexed by the per-atom ``Z`` values in
        # __call__; 119 rows cover Z = 0..118.
        n_species = 119
        scale_init = nn.initializers.constant(1.0)
        self.scale = self.param(
            "scale_per_element", scale_init, (n_species, 1), jnp.float64
        )
        shift_init = nn.initializers.constant(0.0)
        self.shift_param = self.param(
            "shift_per_element", shift_init, (n_species, 1), jnp.float64
        )

    def __call__(self, g, R, dr_vec, Z, idx, box):
        """Predict the property for one structure.

        Parameters
        ----------
        g:
            Per-atom descriptors; ``self.readout`` is vmapped over axis 0.
        R:
            Atomic positions, one row per atom (3D assumed by the l2 modes).
        dr_vec, idx, box:
            Not used by this head; accepted for interface compatibility.
        Z:
            Per-atom species, used to index the per-element scale/shift and
            for masking.

        Returns
        -------
        dict
            ``{pname: prediction}``. For ensembles the prediction is the
            ensemble mean, and ``pname + "_uncertainty"`` holds the unbiased
            sample variance across members (a variance, not a standard
            deviation).

        Raises
        ------
        KeyError
            If ``mode`` or ``aggregation`` is not recognised.
        """
        h = jax.vmap(self.readout)(g)

        # More than one readout column is interpreted as an ensemble.
        n_ens = jnp.size(h, axis=1)
        is_ensemble = n_ens > 1
        if is_ensemble:
            # (natoms, nens) -> (nens, natoms, 1) so that the (natoms, 1)
            # per-element scale/shift broadcasts across members.
            h = jnp.transpose(h[..., None], (1, 0, 2))

        p_i = h * self.scale[Z] + self.shift_param[Z]
        p_i = _apply_mode(p_i, R, self.mode)

        if is_ensemble:
            p_i = jnp.swapaxes(p_i, 0, 1)  # natoms, nens, features...

        if self.apply_mask:
            p_i = mask_by_atom(p_i, Z)

        # NOTE(review): the centroid above and the "mean" divisor here use
        # every row of R, including any rows mask_by_atom zeroes out; confirm
        # this is intended for padded inputs.
        result = _aggregate(p_i, R.shape[0], self.aggregation)
        output = {self.pname: result}

        if is_ensemble:
            if self.aggregation == "none":
                # Move ensemble members back to axis 0: (nens, natoms, ...).
                result = jnp.swapaxes(result, 0, 1)
            mean = jnp.mean(result, axis=0)
            # n_ens >= 2 on this branch, so the unbiased divisor is finite.
            divisor = 1 / (n_ens - 1)
            uncertainty = divisor * fp64_sum((mean - result) ** 2, axis=0)
            output[self.pname] = mean
            output[self.pname + "_uncertainty"] = uncertainty

        return output