Source code for apax.md.dynamics_checks

from typing import Literal, Union

import jax.numpy as jnp
from pydantic import BaseModel, TypeAdapter

from apax.utils.jax_md_reduced.space import distance


class DynamicsCheckBase(BaseModel):
    """Base class for checks evaluated on molecular-dynamics frames.

    Concrete checks declare their configuration as pydantic fields and
    override :meth:`check`.
    """

    def check(self, predictions, positions, box):
        """Evaluate the check on a single frame.

        Args:
            predictions: Mapping of model outputs for the frame.
            positions: Atomic positions for the frame.
            box: Simulation box for the frame.

        Returns:
            A boolean (a JAX scalar in the concrete subclasses of this module)
            that is True when the check passes.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        # Fail loudly instead of implicitly returning None, which a caller
        # could mistake for a (falsy) check result.
        raise NotImplementedError(
            f"{type(self).__name__} does not implement check()."
        )
class EnergyUncertaintyCheck(DynamicsCheckBase, extra="forbid"):
    """Check that the predicted energy uncertainty stays below a threshold.

    ``extra="forbid"`` makes pydantic reject unknown configuration keys.
    """

    # Tag identifying this check in a configuration mapping.
    name: Literal["energy_uncertainty"] = "energy_uncertainty"
    # Exclusive upper bound on the (optionally per-atom) energy uncertainty.
    threshold: float
    # When True, the uncertainty is divided by the atom count before comparing.
    per_atom: bool = True

    def check(self, predictions, positions, box):
        """Return whether the energy uncertainty is below ``threshold``.

        Args:
            predictions: Mapping of model outputs. It must contain
                ``"energy_uncertainty"``. When ``per_atom`` is set, it must
                also contain ``"forces"``, whose first dimension is used as
                the atom count.
            positions: Not used by this check.
            box: Not used by this check.

        Returns:
            A JAX boolean scalar that is True when every uncertainty value is
            strictly less than ``threshold``.

        Raises:
            ValueError: If ``"energy_uncertainty"`` is absent from
                ``predictions``.
        """
        if "energy_uncertainty" not in predictions:
            raise ValueError(
                "No energy uncertainty found. Are you using a model ensemble?"
            )

        uncertainty = predictions["energy_uncertainty"]
        if self.per_atom:
            # Express the uncertainty per atom; the atom count is taken from
            # the leading axis of the force prediction.
            uncertainty = uncertainty / predictions["forces"].shape[0]

        return jnp.all(uncertainty < self.threshold)
class ForceUncertaintyCheck(DynamicsCheckBase, extra="forbid"):
    """Check that every predicted force uncertainty stays below a threshold.

    ``extra="forbid"`` makes pydantic reject unknown configuration keys.
    """

    # Tag identifying this check in a configuration mapping.
    name: Literal["forces_uncertainty"] = "forces_uncertainty"
    # Exclusive upper bound applied element-wise to the force uncertainties.
    threshold: float

    def check(self, predictions, positions, box):
        """Return whether all force uncertainties are below ``threshold``.

        Args:
            predictions: Mapping of model outputs that must contain
                ``"forces_uncertainty"``.
            positions: Not used by this check.
            box: Not used by this check.

        Returns:
            A JAX boolean scalar that is True when every element of
            ``predictions["forces_uncertainty"]`` is strictly less than
            ``threshold``.

        Raises:
            ValueError: If ``"forces_uncertainty"`` is absent from
                ``predictions``.
        """
        try:
            uncertainty = predictions["forces_uncertainty"]
        except KeyError:
            raise ValueError(
                "No force uncertainties found. Are you using a model ensemble?"
            ) from None

        return jnp.all(uncertainty < self.threshold)
class ReflectionCheck(DynamicsCheckBase, extra="forbid"):
    """Check that no atom reaches a horizontal cutoff plane.

    ``extra="forbid"`` makes pydantic reject unknown configuration keys.
    """

    # Tag identifying this check in a configuration mapping.
    name: Literal["reflection"] = "reflection"
    # Exclusive upper bound on every atom's third (z) Cartesian coordinate.
    cutoff_plane_height: float

    def check(self, predictions, positions, box):
        """Return whether every atom lies below ``cutoff_plane_height``.

        Args:
            predictions: Not used by this check.
            positions: Per-atom position array, right-multiplied by ``box``.
            box: Matrix applied as ``positions @ box``.

        Returns:
            A JAX boolean scalar that is True when the z component (column 2)
            of ``positions @ box`` is strictly less than
            ``cutoff_plane_height`` for all atoms.
        """
        # NOTE(review): assumes ``positions @ box`` yields Cartesian
        # coordinates for the caller's position/box convention — confirm.
        z_coords = (positions @ box)[:, 2]
        return jnp.all(z_coords < self.cutoff_plane_height)
class RadiusCheck(DynamicsCheckBase, extra="forbid"):
    """Check that every atom stays within a cutoff radius.

    ``extra="forbid"`` makes pydantic reject unknown configuration keys.
    """

    # Tag identifying this check in a configuration mapping.
    name: Literal["radius"] = "radius"
    # Exclusive upper bound on the per-atom value returned by ``distance``.
    cutoff_radius: float

    def check(self, predictions, positions, box):
        """Return whether all atoms lie inside ``cutoff_radius``.

        Args:
            predictions: Not used by this check.
            positions: Per-atom position array, right-multiplied by ``box``.
            box: Matrix applied as ``positions @ box``.

        Returns:
            A JAX boolean scalar that is True when ``distance`` of
            ``positions @ box`` is strictly less than ``cutoff_radius``
            everywhere.
        """
        # NOTE(review): assumes ``positions @ box`` yields Cartesian
        # coordinates, and that ``distance`` is the Euclidean norm over the
        # last axis (as in jax_md.space.distance), i.e. each atom's distance
        # from the coordinate origin — confirm.
        radii = distance(positions @ box)
        return jnp.all(radii < self.cutoff_radius)
# Validator that converts a raw configuration object into one of the check
# models below; it is the bound ``validate_python`` method of a TypeAdapter
# over their union (member order kept as-is).
DynamicsChecks = TypeAdapter(
    Union[
        EnergyUncertaintyCheck,
        ForceUncertaintyCheck,
        ReflectionCheck,
        RadiusCheck,
    ]
).validate_python