Source code for apax.layers.properties

import jax
import jax.numpy as jnp
from jax import Array


def stress_times_vol(energy_fn, position: Array, box, **kwargs) -> Array:
    """Differentiate the energy with respect to a strain applied to the box.

    The result is the gradient of ``energy_fn`` with respect to its
    ``perturbation`` argument, evaluated at the unperturbed state
    (``perturbation`` equal to the identity matrix). It is intended to be
    used as the stress multiplied with the box volume, e.g. for training;
    no division by the volume is performed in this function.

    Parameters
    ----------
    energy_fn:
        Callable returning the scalar energy of the system. It is invoked as
        ``energy_fn(position, box=box, perturbation=..., **kwargs)`` and must
        therefore accept a ``perturbation`` keyword argument.
    position:
        Array of particle positions. Its second axis (``position.shape[1]``)
        sets the spatial dimension ``dim``.
    box:
        Box describing the simulation volume. It is passed through to
        ``energy_fn`` unchanged.
    **kwargs:
        Additional keyword arguments forwarded verbatim to ``energy_fn``.

    Returns
    -------
    Array
        Array of shape ``(dim, dim)`` holding the derivative of the energy
        with respect to the strain, evaluated at zero strain.
    """
    n_dim = position.shape[1]
    unit_perturbation = jnp.eye(n_dim, dtype=position.dtype)

    # Zero-strain reference point at which the derivative is evaluated.
    # The symmetrisation leaves the values at zero; it is retained so the
    # dtype of the evaluation point is the same as before.
    null_strain = jnp.zeros((n_dim, n_dim), position.dtype)
    null_strain = 0.5 * (null_strain + null_strain.T)

    def strained_energy(strain):
        # Energy of the system under a perturbation of (identity + strain).
        return energy_fn(
            position,
            box=box,
            perturbation=(unit_perturbation + strain),
            **kwargs,
        )

    return jax.grad(strained_energy)(null_strain)