"""Bias energy potentials and helpers that add them to apax model outputs (``apax.md.bias``)."""
from typing import Any, Callable, Union
import jax
import jax.numpy as jnp
from pydantic import BaseModel, TypeAdapter
from apax.utils.jax_md_reduced.space import distance
# [docs] (Sphinx link marker from the rendered source page)
class BiasEnergyBase(BaseModel):
    """Base class for bias potentials added on top of a model's energy.

    Subclasses are pydantic models whose fields hold the bias parameters;
    each concrete bias must override :meth:`energy`.
    """

    def energy(self, R, neighbor, box, pertubation=None):
        """Return the bias energy for positions ``R``.

        The ``pertubation`` keyword (spelling kept for compatibility with
        existing callers) is part of the shared bias signature; subclasses
        may ignore it.

        Raises:
            NotImplementedError: always; concrete biases must override this.
        """
        raise NotImplementedError
def apply_bias_energy(bias: BiasEnergyBase, model) -> Callable[..., Any]:
    """Wrap an energy function so its result includes a bias energy.

    Args:
        bias: Bias potential whose ``energy`` is added to the model energy.
        model: Energy function, called as ``model(R=..., neighbor=..., box=...)``.

    Returns:
        A function ``energy_fn(R, neighbor, box, pertubation=None)`` returning
        ``model(...) + bias.energy(...)``. It returns the summed energy value
        itself, not a prediction dict.
    """

    def energy_fn(R, neighbor, box, pertubation=None):
        # ``pertubation`` (spelling kept for keyword callers) is forwarded
        # only to the bias; ``model`` is called with positions, neighbors
        # and box alone.
        energy = model(R=R, neighbor=neighbor, box=box)
        E_bias = bias.energy(R, neighbor, box, pertubation=pertubation)
        return energy + E_bias

    return energy_fn
def apply_bias_auxiliary(bias: BiasEnergyBase, model) -> Callable[..., dict[str, Any]]:
    """Wrap a prediction function so its biased outputs include ``bias``.

    The returned ``aux_fn(R, Z, neighbor, box, offsets)`` calls ``model`` and
    post-processes the prediction dict it returns, in place:

    * ``energy_unbiased`` / ``forces_unbiased`` are copied from ``energy`` /
      ``forces`` when ``energy_unbiased`` is not already present, preserving
      the raw model outputs.
    * Keys containing ``"unbiased"`` or ``"uncertainty"`` are left untouched.
    * Keys containing ``"forces"`` get the bias force added; for keys that
      also contain ``"ensemble"`` the bias force gets a trailing axis so it
      broadcasts (assumes ensemble members are stacked on the last axis of
      the forces array — TODO confirm).
    * Remaining keys containing ``"energy"`` get the bias energy added.
    """

    def aux_fn(R, Z, neighbor, box, offsets):
        # value_and_grad returns the bias energy and dE/dR, which is the
        # negative bias force.
        bias_energy, bias_grad = jax.value_and_grad(bias.energy)(R, neighbor, box)
        out = model(R=R, Z=Z, neighbor=neighbor, box=box, offsets=offsets)

        if "energy_unbiased" not in out:
            out["energy_unbiased"] = out["energy"]
            out["forces_unbiased"] = out["forces"]

        def _with_bias(name, value):
            # Return ``value`` adjusted by the bias contribution for ``name``.
            if "unbiased" in name or "uncertainty" in name:
                return value
            if "forces" in name:
                # Subtracting dE/dR adds the bias force.
                grad = bias_grad[:, :, None] if "ensemble" in name else bias_grad
                return value - grad
            if "energy" in name:
                return value + bias_energy
            return value

        # Updating existing keys only rebinds values; key order is unchanged.
        out.update({name: _with_bias(name, value) for name, value in out.items()})
        return out

    return aux_fn
# [docs] (Sphinx link marker from the rendered source page)
class SphericalWall(BiasEnergyBase):
    """Harmonic wall that penalizes positions lying beyond ``radius``.

    Fields:
        radius: threshold on ``distance(R)`` beyond which the wall acts.
        spring_constant: harmonic constant ``k`` of the restoring penalty.
    """

    radius: float
    spring_constant: float

    def energy(self, R, neighbor, box, pertubation=None):
        """Return ``0.5 * k * sum(max(distance(R) - radius, 0) ** 2)``.

        ``neighbor``, ``box`` and ``pertubation`` are accepted only to match
        the bias interface and are not used.
        """
        # Presumably the per-row norm of R (distance from the origin), as in
        # jax_md's ``space.distance`` — confirm against jax_md_reduced.
        radial = distance(R)
        # Only the part beyond the wall contributes; inside it is zero.
        overshoot = jnp.clip(radial - self.radius, min=0.0)
        penalty = jnp.sum(overshoot**2)
        return 0.5 * self.spring_constant * penalty
# Validator that builds a concrete bias object from plain Python input
# (e.g. a config dict) via pydantic's TypeAdapter.validate_python.
# NOTE(review): the Union looks intended as the registry of available biases —
# add new BiasEnergyBase subclasses to it. With one member, Union[SphericalWall]
# is simply SphericalWall.
BiasEnergies = TypeAdapter(Union[SphericalWall]).validate_python