Source code for apax.md.function_transformations

import dataclasses

import jax
import jax.numpy as jnp
from jax_md import quantity


def make_biased_energy_force_fn(bias_fn):
    """Turn a bias-energy function into a biased energy/force function.

    Parameters
    ----------
    bias_fn : callable
        Called as ``bias_fn(positions, Z, idx, box, offsets)``. It must return a
        tuple ``(E_bias, results)``:
        - ``E_bias`` is the bias energy; it is differentiated with respect to
          ``positions``.
        - ``results`` is the dictionary of model predictions and must contain
          ``"energy"`` and ``"forces"``.

    Returns
    -------
    callable
        A function with the same signature as ``bias_fn``. It returns
        ``results`` with ``E_bias`` added to ``"energy"`` and the bias force
        (the negative gradient of ``E_bias``) added to ``"forces"``. If
        ``results`` does not yet contain ``"energy_unbiased"``, the pre-bias
        energy and forces are stored under ``"energy_unbiased"`` and
        ``"forces_unbiased"``.
    """
    # Gradient is taken w.r.t. argument 0 (positions) only; the predictions
    # dictionary rides along as auxiliary output.
    value_and_grad_bias = jax.value_and_grad(bias_fn, has_aux=True)

    def biased_energy_force_fn(positions, Z, idx, box, offsets):
        (bias_energy, results), bias_gradient = value_and_grad_bias(
            positions, Z, idx, box, offsets
        )

        # Record the unbiased predictions only the first time, so stacking
        # several biases still exposes the model's original output.
        if "energy_unbiased" not in results:
            results["energy_unbiased"] = results["energy"]
            results["forces_unbiased"] = results["forces"]

        results["energy"] = results["energy"] + bias_energy
        # Force is the negative gradient of the bias energy.
        results["forces"] = results["forces"] - bias_gradient
        return results

    return biased_energy_force_fn


@dataclasses.dataclass
class UncertaintyDrivenDynamics:
    """Uncertainty-driven dynamics (UDD) bias.

    UDD requires an uncertainty aware model. The bias lowers the energy of
    configurations with a large energy uncertainty, which drives the dynamics
    towards those regions. The bias energy is bounded below by ``-height``:

        E_udd = height * (exp(-sigma_E**2 / (N_models * N_atoms * width**2)) - 1)

    https://doi.org/10.1038/s43588-023-00406-5

    Parameters
    ----------
    height : float
        Maximum bias potential that can be applied.
    width : float
        Width of the Gaussian bias.
    """

    height: float
    width: float

    def apply(self, model):
        """Wrap ``model`` so that its energy and forces include the UDD bias.

        ``model(positions, Z, idx, box, offsets)`` must return a dict containing
        ``"energy"``, ``"forces"``, ``"energy_ensemble"`` (axis 0 is taken as the
        number of ensemble members) and ``"energy_uncertainty"``.
        """

        def udd_energy(positions, Z, idx, box, offsets):
            results = model(positions, Z, idx, box, offsets)

            num_atoms = positions.shape[0]
            num_models = results["energy_ensemble"].shape[0]
            variance = results["energy_uncertainty"] ** 2

            normalization = num_models * num_atoms * self.width**2
            bias_energy = self.height * (jnp.exp(-variance / normalization) - 1)
            return bias_energy, results

        return make_biased_energy_force_fn(udd_energy)
@dataclasses.dataclass
class GaussianAcceleratedMolecularDynamics:
    """Gaussian accelerated molecular dynamics (GaMD) boost potential.

    Applies a harmonic boost potential that pulls the system towards a target
    energy. The boost is only active while the potential energy ``E`` is below
    ``energy_target``:

        E_gamd = 0.5 * spring_constant * (min(E, energy_target) - energy_target)**2

    For ``E >= energy_target`` both the boost and its force vanish.

    https://pubs.acs.org/doi/10.1021/acs.jctc.5b00436

    Parameters
    ----------
    energy_target : float
        Target potential energy below which to apply the boost potential.
    spring_constant : float
        Spring constant of the boost potential.
    """

    energy_target: float
    spring_constant: float

    def apply(self, model):
        """Wrap ``model`` so that its energy and forces include the GaMD boost.

        ``model(positions, Z, idx, box, offsets)`` must return a dict containing
        ``"energy"`` and ``"forces"``.
        """

        def gamd_energy(positions, Z, idx, box, offsets):
            results = model(positions, Z, idx, box, offsets)
            # Clamp from above so energies at or above the target give zero boost.
            # jnp.minimum is used instead of jnp.clip(..., a_max=...): the a_max
            # keyword is deprecated in recent JAX, and older JAX lacks max=.
            energy = jnp.minimum(results["energy"], self.energy_target)
            E_gamd = 0.5 * self.spring_constant * (energy - self.energy_target) ** 2
            return E_gamd, results

        return make_biased_energy_force_fn(gamd_energy)
@dataclasses.dataclass
class GlobalCalibration:
    """Global rescaling of energy and force uncertainties.

    The energy uncertainty is multiplied by ``energy_factor``. Each energy
    ensemble member ``E_i`` is moved away from (or towards) the ensemble mean
    ``E`` as ``E + energy_factor * (E_i - E)``, following EQ 7 in
    https://doi.org/10.1063/5.0036522. If the model reports
    ``"forces_uncertainty"``, it is multiplied by ``forces_factor``.

    Parameters
    ----------
    energy_factor : float
        Global calibration factor by which to scale the energy uncertainty.
    forces_factor : float
        Global calibration factor by which to scale the force uncertainties.
    """

    energy_factor: float
    forces_factor: float

    def apply(self, model):
        """Return a callable that runs ``model`` and calibrates its uncertainties.

        ``model(positions, Z, idx, box, offsets)`` must return a dict containing
        ``"energy"``, ``"energy_ensemble"`` and ``"energy_uncertainty"``. The
        dict is updated in place and returned.
        """

        def calibrated_model(positions, Z, idx, box, offsets):
            results = model(positions, Z, idx, box, offsets)
            e_factor = self.energy_factor

            results["energy_uncertainty"] = results["energy_uncertainty"] * e_factor

            mean_energy = results["energy"]
            members = results["energy_ensemble"]
            results["energy_ensemble"] = mean_energy + e_factor * (
                members - mean_energy
            )

            # Force uncertainties are optional in the model output.
            if "forces_uncertainty" in results:
                results["forces_uncertainty"] = (
                    results["forces_uncertainty"] * self.forces_factor
                )

            return results

        return calibrated_model
@dataclasses.dataclass
class ProcessStress:
    """Divide stress predictions by the cell volume.

    Every model output whose key starts with ``"stress"`` is transposed and
    divided by the volume of ``box``, as computed by
    ``jax_md.quantity.volume(3, box)``. All other outputs are passed through
    unchanged. A new dict is returned; the model's dict is not modified.
    """

    def apply(self, model):
        """Return a callable that runs ``model`` and rescales its stress outputs."""

        def corrected_model(positions, Z, idx, box, offsets):
            raw = model(positions, Z, idx, box, offsets)
            cell_volume = quantity.volume(3, box)

            processed = {}
            for key, value in raw.items():
                if key.startswith("stress"):
                    # We should properly check whether CP2K uses the ASE cell convention
                    # for tetragonal strain, it doesn't matter whether we transpose or not
                    # NOTE(review): ``.T`` reverses *all* axes, so a stacked
                    # (>2-D) stress ensemble gets its axis order reversed too — confirm
                    # this is intended for "stress_ensemble"-style keys.
                    processed[key] = value.T / cell_volume
                else:
                    processed[key] = value
            return processed

        return corrected_model
# Lookup table from short string keys to the transformation classes defined in
# this module.
available_transformations = dict(
    udd=UncertaintyDrivenDynamics,
    gamd=GaussianAcceleratedMolecularDynamics,
    global_cal=GlobalCalibration,
    process_stress=ProcessStress,
)