"""Configuration models for apax molecular dynamics simulations (``apax.config.md_config``)."""

import os
from typing import Literal, Union

import yaml
from pydantic import BaseModel, Field, NonNegativeInt, PositiveFloat, PositiveInt
from typing_extensions import Annotated

from apax.utils.helpers import APAX_PROPERTIES


class ConstantTempSchedule(BaseModel, extra="forbid"):
    """Temperature schedule that keeps the temperature fixed at ``T0``.

    The other temperature schedules in this module subclass it and
    inherit the ``T0`` field.

    Attributes
    ----------
    name : Literal["constant"], default = "constant"
        Tag identifying this schedule type (used as union discriminator).
    T0 : PositiveFloat, default = 298.15
        Initial (and, for this schedule, constant) temperature in Kelvin (K).
    """

    name: Literal["constant"] = "constant"
    T0: PositiveFloat = 298.15  # Kelvin

    def get_schedule(self):
        """Build the runtime schedule object described by this config.

        Returns
        -------
        ConstantTSchedule
            Schedule constructed from ``T0``.
        """
        # Imported here rather than at module level so that importing this
        # config module does not pull in apax.md.schedules.
        from apax.md.schedules import ConstantTSchedule as _Schedule

        return _Schedule(self.T0)


class PiecewiseLinearTempSchedule(ConstantTempSchedule, extra="forbid"):
    """Piecewise linear temperature schedule.

    The temperature is linearly interpolated between ``T0`` and the values
    in ``temperatures`` at the time steps given by ``durations``.

    Attributes
    ----------
    name : Literal["piecewise"], default = "piecewise"
        Tag identifying this schedule type (used as union discriminator).
    T0 : PositiveFloat, default = 298.15
        Starting temperature in Kelvin, inherited from ConstantTempSchedule.
    temperatures : list[PositiveFloat]
        Temperatures to interpolate between.
    durations : list[PositiveInt]
        Duration, in time steps, of the interpolation between two
        subsequent values of ``temperatures``.
    """

    name: Literal["piecewise"] = "piecewise"
    temperatures: list[PositiveFloat]
    durations: list[PositiveInt]

    def get_schedule(self):
        """Build the runtime piecewise-linear schedule from ``T0``,
        ``temperatures`` and ``durations`` (passed positionally, in that order).
        """
        from apax.md.schedules import PieceWiseLinearTSchedule

        return PieceWiseLinearTSchedule(self.T0, self.temperatures, self.durations)


class OscillatingRampTempSchedule(ConstantTempSchedule, extra="forbid"):
    """Linear ramp from ``T0`` to ``Tend`` with a superimposed temperature oscillation.

    Mostly intended for sampling purposes.

    Attributes
    ----------
    name : Literal["oscillating_ramp"], default = "oscillating_ramp"
        Tag identifying this schedule type (used as union discriminator).
    T0 : PositiveFloat, default = 298.15
        Starting temperature in Kelvin, inherited from ConstantTempSchedule.
    Tend : PositiveFloat
        Final temperature in Kelvin.
    amplitude : PositiveFloat
        Amplitude of the temperature oscillations.
    num_oscillations : PositiveInt
        Number of oscillations to occur during the simulation.
    total_steps : PositiveInt
        Total number of steps of the schedule; afterwards ``Tend`` is kept.
    """

    name: Literal["oscillating_ramp"] = "oscillating_ramp"
    Tend: PositiveFloat
    amplitude: PositiveFloat
    num_oscillations: PositiveInt
    total_steps: PositiveInt

    def get_schedule(self):
        """Build the runtime oscillating-ramp schedule from this config."""
        from apax.md.schedules import OscillatingRampTSchedule

        # Positional order expected by OscillatingRampTSchedule.
        ramp_args = (
            self.T0,
            self.Tend,
            self.amplitude,
            self.num_oscillations,
            self.total_steps,
        )
        return OscillatingRampTSchedule(*ramp_args)


# All available temperature schedule configurations. The alias itself has no
# discriminator; fields typed with it (e.g. Integrator.temperature_schedule)
# attach one via ``Field(discriminator="name")``.
TemperatureSchedule = Union[
    ConstantTempSchedule,
    PiecewiseLinearTempSchedule,
    OscillatingRampTempSchedule,
]


class NHCOptions(BaseModel, extra="forbid"):
    """
    Settings of a Nose-Hoover chain (used for thermostat and barostat chains).

    Parameters
    ----------
    chain_length : PositiveInt, default = 3
        Number of thermostats in the chain.
    chain_steps : PositiveInt, default = 2
        Number of steps per chain.
    sy_steps : PositiveInt, default = 3
        Number of steps for Suzuki-Yoshida integration.
    tau : PositiveFloat, default = 100
        Relaxation time parameter.
    """

    chain_length: PositiveInt = 3  # thermostats in the chain
    chain_steps: PositiveInt = 2  # steps per chain
    sy_steps: PositiveInt = 3  # Suzuki-Yoshida steps
    tau: PositiveFloat = 100  # relaxation time


class Integrator(BaseModel, extra="forbid"):
    """
    Base options shared by all molecular dynamics integrators.

    Parameters
    ----------
    name : str
        Name of the ensemble; subclasses narrow it to a literal tag.
    dt : PositiveFloat, default = 0.5
        Time step in femtoseconds.
    temperature_schedule : TemperatureSchedule, default = constant 298.15 K
        Temperature schedule used throughout the simulation. For NVE it is
        only used for velocity initialization and ignored afterwards.
    """

    name: str
    dt: PositiveFloat = 0.5  # femtoseconds
    # The concrete schedule class is selected by its ``name`` tag.
    temperature_schedule: TemperatureSchedule = Field(
        default=ConstantTempSchedule(name="constant", T0=298.15),
        discriminator="name",
    )


class NVEOptions(Integrator, extra="forbid"):
    """
    Integrator options for NVE (microcanonical) simulations.

    ``dt`` and ``temperature_schedule`` are inherited from Integrator; here
    the temperature schedule only serves to initialize velocities.

    Attributes
    ----------
    name : Literal["nve"]
        Ensemble tag; must be "nve".
    """

    name: Literal["nve"]  # narrows Integrator.name to the "nve" tag


class NVTOptions(NVEOptions, extra="forbid"):
    """
    Integrator options for NVT (canonical) simulations.

    Adds a Nose-Hoover thermostat chain on top of the inherited
    ``dt`` and ``temperature_schedule`` fields.

    Parameters
    ----------
    name : Literal["nvt"]
        Ensemble tag; must be "nvt".
    thermostat_chain : NHCOptions, default = NHCOptions()
        Settings of the thermostat chain.
    """

    name: Literal["nvt"]  # overrides the inherited "nve" tag
    thermostat_chain: NHCOptions = NHCOptions()


class NPTOptions(NVTOptions, extra="forbid"):
    """
    Options for NPT ensemble simulations.

    Inherits ``dt``, ``temperature_schedule`` and ``thermostat_chain``
    from NVTOptions and adds a target pressure and a barostat chain.

    Parameters
    ----------
    name : Literal["npt"]
        Name of the ensemble.
    pressure : PositiveFloat, default = 1.01325
        Pressure in bar.
    barostat_chain : NHCOptions, default = NHCOptions(tau=1000)
        Barostat chain options.
    """

    name: Literal["npt"]
    pressure: PositiveFloat = 1.01325  # bar
    barostat_chain: NHCOptions = NHCOptions(tau=1000)
class EnergyUncertaintyCheck(BaseModel, extra="forbid"):
    """Dynamics check on the model's energy uncertainty.

    Attributes
    ----------
    name : Literal["energy_uncertainty"], default = "energy_uncertainty"
        Tag identifying this check (used as union discriminator).
    threshold : PositiveFloat
        Uncertainty threshold of the check. Presumably the simulation is
        terminated once it is exceeded — TODO confirm against the MD loop.
    per_atom : bool, default = True
        Presumably whether the threshold refers to the per-atom rather than
        the total energy uncertainty — TODO confirm.
    """

    name: Literal["energy_uncertainty"] = "energy_uncertainty"
    threshold: PositiveFloat
    per_atom: bool = True


class ForcesUncertaintyCheck(BaseModel, extra="forbid"):
    """Dynamics check on the model's force uncertainty.

    Attributes
    ----------
    name : Literal["forces_uncertainty"], default = "forces_uncertainty"
        Tag identifying this check (used as union discriminator).
    threshold : PositiveFloat
        Uncertainty threshold of the check (presumably the simulation is
        terminated once it is exceeded — TODO confirm).
    """

    name: Literal["forces_uncertainty"] = "forces_uncertainty"
    threshold: PositiveFloat


class ReflectionCheck(BaseModel, extra="forbid"):
    """Dynamics check based on a cutoff plane.

    Attributes
    ----------
    name : Literal["reflection"], default = "reflection"
        Tag identifying this check (used as union discriminator).
    cutoff_plane_height : float
        Height of the cutoff plane. NOTE(review): axis and units (presumably
        z, in Angstrom) are not visible here — confirm.
    """

    name: Literal["reflection"] = "reflection"
    cutoff_plane_height: float


# Any of the available dynamics checks, selected by its ``name`` tag.
DynamicsCheck = Annotated[
    Union[EnergyUncertaintyCheck, ForcesUncertaintyCheck, ReflectionCheck],
    Field(discriminator="name"),
]


class SphericalWallEnergy(BaseModel, extra="forbid"):
    """Bias energy to prevent atoms from moving outside of radius from origin.

    Attributes:
        name (Literal['spherical_wall']): tag identifying this bias.
        radius (PositiveFloat): radius where to put spherical wall in Angstrom.
        spring_constant (float): spring constant in eV/Angstrom^2.
    """

    name: Literal["spherical_wall"] = "spherical_wall"
    radius: PositiveFloat
    spring_constant: float


# Any of the available bias energies, selected by its ``name`` tag.
# A single-member Union is kept so further biases can be added later.
BiasEnergy = Annotated[
    Union[SphericalWallEnergy],
    Field(discriminator="name"),
]


class FixAtomsConstraint(BaseModel, extra="forbid"):
    """Fix atoms in cartesian coordinates.

    Attributes:
        name (Literal['fixatoms']): tag identifying this constraint.
        indices (list[int]): list of atom indices to fix.
    """

    name: Literal["fixatoms"] = "fixatoms"
    indices: list[int]


class FixCenterOfMassConstraint(BaseModel, extra="forbid"):
    """Fix the center of mass of the system.

    Attributes:
        name (Literal['fixcenterofmass']): tag identifying this constraint.
        position (Union[Literal['initial', 'origin'], list[float]]):
            position where to keep the center of mass fixed.
    """

    name: Literal["fixcenterofmass"] = "fixcenterofmass"
    position: Union[Literal["initial", "origin"], list[float]] = "initial"


class FixLayerConstraint(BaseModel, extra="forbid"):
    """Fix a layer of atoms.

    Attributes:
        name (Literal['fixlayer']): tag identifying this constraint.
        upper_limit (float): upper bound of the layer.
        lower_limit (float): lower bound of the layer.

    NOTE(review): presumably atoms whose coordinate lies between
    ``lower_limit`` and ``upper_limit`` are fixed (axis/units not visible
    here) — confirm against the constraint implementation.
    """

    name: Literal["fixlayer"] = "fixlayer"
    upper_limit: float
    lower_limit: float


# Any of the available constraints, selected by its ``name`` tag.
Constraint = Annotated[
    Union[FixAtomsConstraint, FixCenterOfMassConstraint, FixLayerConstraint],
    Field(discriminator="name"),
]


class H5MDOptions(BaseModel, extra="forbid"):
    """
    Options for configuring H5MD (Hierarchical Data Format for Molecular Data)
    file output.

    Attributes
    ----------
    compression : str | None, default = "gzip"
        | Compression algorithm to use for the HDF5 file (e.g., "gzip").
    compression_opts : int, default = 4
        | Compression level or options for the chosen algorithm.
    store : Literal["time", "linear"], default = "time"
        | Storage mode for the data. Can be "time" or "linear".
    author : str, default = "N/A"
        | Name of the author of the file.
    author_email : str, default = "N/A"
        | Email address of the author.
    """

    compression: str | None = "gzip"
    compression_opts: int = 4
    store: Literal["time", "linear"] = "time"
    author: str = "N/A"
    author_email: str = "N/A"
class MDConfig(BaseModel, frozen=True, extra="forbid"):
    """
    Configuration for a NHC molecular dynamics simulation.

    Full config :ref:`here <md_config>`:

    Parameters
    ----------
    seed : int, default = 1
        | Random seed for momentum initialization.
    ensemble : NVEOptions | NVTOptions | NPTOptions, default = NVTOptions(name="nvt")
        | Options for integrating the EoM and which ensemble to use.
        | The time step ``dt`` (in fs, default 0.5) and the temperature
        | schedule are set here, not at the top level of this config.
    duration : float, required
        | Total simulation time in fs.
    n_inner : int, default = 500
        | Number of compiled simulation steps (i.e. number of iterations of the
        | `jax.lax.fori_loop` loop). Also determines atoms buffer size.
    sampling_rate : int, default = 10
        | Interval between saving frames.
    buffer_size : int, default = 2500
        | Number of collected frames to be dumped at once.
    dr_threshold : float, default = 0.5
        | Skin of the neighborlist.
    extra_capacity : int, default = 0
        | JaxMD allocates a maximal number of neighbors. This argument lets you add
        | additional capacity to avoid recompilation. The default is usually fine.
    biases : list[BiasEnergy], default = []
        | List of bias energies. Currently a spherical wall potential is available.
    dynamics_checks : list[DynamicsCheck], default = []
        | List of termination criteria. Currently energy uncertainty, force
        | uncertainty and reflection checks are available.
    constraints : list[Constraint], default = []
        | List of constraints. Currently fixed atoms ("fixatoms"), fixed center
        | of mass ("fixcenterofmass") and fixed layer ("fixlayer") are available.
    properties : list[str], default = APAX_PROPERTIES
        | Whitelist of properties to be saved in the trajectory.
        | This does not affect what the model will calculate, e.g.
        | an ensemble will still calculate uncertainties.
    h5md_options : H5MDOptions, default = H5MDOptions()
        | Compression and metadata options for the H5MD trajectory file.
    initial_structure : str, required
        | Path to the starting structure of the simulation.
    load_momenta : bool, default = False
        | Whether to load initial momenta instead of drawing them randomly
        | (presumably from `initial_structure` — TODO confirm).
    sim_dir : str, default = "."
        | Directory where simulation files will be stored.
    traj_name : str, default = "md.h5"
        | Name of the trajectory file.
    restart : bool, default = True
        | Whether the simulation should restart from the latest configuration in
        | `traj_name`.
    wrapped : bool, default = True
        | Whether the atoms in the simulation should be wrapped back into the box.
    checkpoint_interval : int, default = 50_000
        | Number of time steps between saving full simulation state checkpoints.
        | These will be loaded with the `restart` option.
    disable_pbar : bool, default = False
        | Disables the MD progressbar.
    """

    seed: int = 1

    # https://docs.pydantic.dev/latest/usage/types/unions/#discriminated-unions-aka-tagged-unions
    ensemble: Union[NVEOptions, NVTOptions, NPTOptions] = Field(
        NVTOptions(name="nvt"), discriminator="name"
    )

    duration: PositiveFloat
    n_inner: PositiveInt = 500
    sampling_rate: PositiveInt = 10
    buffer_size: PositiveInt = 2500
    dr_threshold: PositiveFloat = 0.5
    extra_capacity: NonNegativeInt = 0

    # Pydantic copies these defaults per instance, so the literal lists are
    # not shared between configs.
    biases: list[BiasEnergy] = []
    dynamics_checks: list[DynamicsCheck] = []
    constraints: list[Constraint] = []

    properties: list[str] = APAX_PROPERTIES
    h5md_options: H5MDOptions = H5MDOptions()

    initial_structure: str
    load_momenta: bool = False
    sim_dir: str = "."
    traj_name: str = "md.h5"
    restart: bool = True
    wrapped: bool = True
    checkpoint_interval: int = 50_000
    disable_pbar: bool = False

    def dump_config(self):
        """
        Writes the current config as ``md_config.yaml`` into ``sim_dir``.

        ``sim_dir`` must already exist; no directory is created here.
        """
        config_path = os.path.join(self.sim_dir, "md_config.yaml")
        # Explicit encoding so the output does not depend on the platform locale.
        with open(config_path, "w", encoding="utf-8") as conf:
            yaml.dump(self.model_dump(), conf, default_flow_style=False)