Source code for apax.md.constraints
from typing import Callable, Literal, Union
import jax.numpy as jnp
from pydantic import BaseModel, TypeAdapter
from apax.md.sim_utils import System
from apax.utils.math import center_of_mass
[docs]
class ConstraintBase(BaseModel):
"""Base class for constraints.
Constraints work by implementing a create method.
This method accepts a reference state which to
compare to subsequent ones and returns a callable which applies
the constraint during simulations.
"""
def create(self, system) -> Callable:
pass
[docs]
class FixAtoms(ConstraintBase, extra="forbid"):
name: Literal["fixatoms"] = "fixatoms"
indices: list[int]
def create(self, system: System) -> Callable:
indices = jnp.array(self.indices, dtype=jnp.int64)
ref_position = system.positions[indices]
def fn(state):
position = state.position
position = position.at[indices].set(ref_position)
force = state.force
zero_force = jnp.zeros_like(ref_position)
force = force.at[indices].set(zero_force)
momenta = state.momentum
zero_momenta = jnp.zeros_like(ref_position)
momenta = momenta.at[indices].set(zero_momenta)
state = state.set(position=position, force=force, momentum=momenta)
return state
return fn, indices
[docs]
class FixCenterOfMass(ConstraintBase, extra="forbid"):
name: Literal["fixcenterofmass"] = "fixcenterofmass"
position: Union[Literal["initial", "origin"], list[float]] = "initial"
def create(self, system: System) -> Callable:
if isinstance(self.position, str):
if self.position.lower() == "initial":
ref_com = center_of_mass(system.positions, system.masses)
elif self.position.lower() == "origin":
ref_com = jnp.array([0, 0, 0])
else:
ref_com = jnp.array(self.position)
def fn(state):
masses = state.mass[:, 0]
position = state.position
position += ref_com - center_of_mass(position, masses)
momenta = state.momentum
velocity_com = jnp.sum(momenta, axis=0) / jnp.sum(masses)
momenta -= masses[:, None] * velocity_com
# Eqs. (3) and (7) in https://doi.org/10.1021/jp9722824
# Have not explicitly tested this yet.
force = state.force
force -= (
masses[:, None]
/ jnp.sum(masses**2)
* jnp.sum(masses[:, None] * force, axis=0)
)
state = state.set(position=position, force=force, momentum=momenta)
return state
# We return 0 as a constrained idx, to make sure that the
# integrator knows that we have 3 dof less.
return fn, [0]
[docs]
class FixRotation(ConstraintBase, extra="forbid"):
name: Literal["fixrotation"] = "fixrotation"
def create(self, system: System) -> Callable:
raise NotImplementedError()
class FixLayer(ConstraintBase, extra="forbid"):
""" """
name: Literal["fixlayer"] = "fixlayer"
upper_limit: float
lower_limit: float
def create(self, system) -> Callable:
if jnp.any(system.box > 10e-4):
cart_pos = system.positions @ system.box
z_coordinates = cart_pos[:, 2]
indices = jnp.where(
(self.lower_limit <= z_coordinates) & (z_coordinates <= self.upper_limit)
)[0]
ref_position = system.positions[indices]
def fn(state):
position = state.position
position = position.at[indices].set(ref_position)
force = state.force
zero_force = jnp.zeros_like(ref_position)
force = force.at[indices].set(zero_force)
momenta = state.momentum
zero_momenta = jnp.zeros_like(ref_position)
momenta = momenta.at[indices].set(zero_momenta)
state = state.set(position=position, force=force, momentum=momenta)
return state
return fn, indices
Constraint = TypeAdapter(
Union[FixAtoms, FixCenterOfMass, FixRotation, FixLayer]
).validate_python