# Source code for apax.config.md_config

import os

# from types import UnionType
from typing import Literal, Union

import yaml
from pydantic import BaseModel, Field, NonNegativeInt, PositiveFloat, PositiveInt


class NHCOptions(BaseModel, extra="forbid"):
    """
    Options for a Nose-Hoover chain (NHC).

    The same model is used for the thermostat chain of the NVT/NPT ensembles
    and for the barostat chain of the NPT ensemble. Unknown keys are rejected
    (``extra="forbid"``).

    Parameters
    ----------
    chain_length : PositiveInt, default = 3
        Number of thermostats in the chain.
    chain_steps : PositiveInt, default = 2
        Number of steps per chain.
    sy_steps : PositiveInt, default = 3
        Number of steps for Suzuki-Yoshida integration.
    tau : PositiveFloat, default = 100
        Relaxation time parameter. Units are not stated here; presumably the
        same time unit as ``Integrator.dt`` (fs) — TODO confirm.
    """

    chain_length: PositiveInt = 3
    chain_steps: PositiveInt = 2
    sy_steps: PositiveInt = 3
    tau: PositiveFloat = 100


class Integrator(BaseModel, extra="forbid"):
    """
    Molecular dynamics integrator options.

    Base class of the ensemble option models; it contributes the time step
    that every ensemble shares. Unknown keys are rejected
    (``extra="forbid"``).

    Parameters
    ----------
    dt : PositiveFloat, default = 0.5
        Time step size in femtoseconds (fs).
    """

    dt: PositiveFloat = 0.5  # fs


class NVEOptions(Integrator, extra="forbid"):
    """
    Options for NVE (microcanonical) ensemble simulations.

    Inherits ``dt`` (time step in fs) from :class:`Integrator`.

    Parameters
    ----------
    name : Literal["nve"]
        Name of the ensemble; required and must be ``"nve"``. It acts as the
        discriminator when this model is chosen for ``MDConfig.ensemble``.
    init_temperature : PositiveFloat, default = 298.15
        Initialisation temperature in Kelvin (K). This model has no
        thermostat options, so presumably the value only sets the initial
        momenta — TODO confirm.
    """

    name: Literal["nve"]
    init_temperature: PositiveFloat = 298.15  # K


class NVTOptions(Integrator, extra="forbid"):
    """
    Options for NVT (canonical) ensemble simulations with a Nose-Hoover
    chain thermostat.

    Inherits ``dt`` (time step in fs) from :class:`Integrator`.

    Parameters
    ----------
    name : Literal["nvt"]
        Name of the ensemble; required and must be ``"nvt"``. It acts as the
        discriminator when this model is chosen for ``MDConfig.ensemble``.
    temperature : PositiveFloat, default = 298.15
        Thermostat temperature in Kelvin (K).
    thermostat_chain : NHCOptions, default = NHCOptions()
        Thermostat chain options.
    """

    name: Literal["nvt"]
    temperature: PositiveFloat = 298.15  # K
    thermostat_chain: NHCOptions = NHCOptions()


class NPTOptions(NVTOptions, extra="forbid"):
    """
    Options for NPT ensemble simulations.

    Extends :class:`NVTOptions`, so ``dt`` (fs), ``temperature`` (K) and
    ``thermostat_chain`` are available as well; ``name`` is narrowed to
    ``"npt"``.

    Parameters
    ----------
    name : Literal["npt"]
        Name of the ensemble; required and must be ``"npt"``. It acts as the
        discriminator when this model is chosen for ``MDConfig.ensemble``.
    pressure : PositiveFloat, default = 1.01325
        Pressure in bar (the default equals 1 atm).
    barostat_chain : NHCOptions, default = NHCOptions(tau=1000)
        Barostat chain options.
    """

    name: Literal["npt"]
    pressure: PositiveFloat = 1.01325  # bar
    barostat_chain: NHCOptions = NHCOptions(tau=1000)
class MDConfig(BaseModel, frozen=True, extra="forbid"):
    """
    Configuration for a NHC molecular dynamics simulation.
    Full config :ref:`here <md_config>`:

    Instances are immutable (``frozen=True``) and unknown keys are rejected
    (``extra="forbid"``).

    Parameters
    ----------
    seed : int, default = 1
        | Random seed for momentum initialization.
    ensemble : NVEOptions | NVTOptions | NPTOptions, default = NVTOptions(name="nvt")
        | Ensemble options, selected by their ``name`` field ("nve", "nvt" or
        | "npt"). The time step ``dt`` (fs) and the temperature
        | (``temperature`` or ``init_temperature``, K) are configured here.
    duration : float, required
        | Total simulation time in fs.
    n_inner : int, default = 100
        | Number of compiled simulation steps (i.e. number of iterations of the
        | `jax.lax.fori_loop` loop). Also determines atoms buffer size.
    sampling_rate : int, default = 10
        | Interval between saving frames.
    buffer_size : int, default = 100
        | Number of collected frames to be dumped at once.
    dr_threshold : float, default = 0.5
        | Skin of the neighborlist.
    extra_capacity : int, default = 0
        | JaxMD allocates a maximal number of neighbors. This argument lets you add
        | additional capacity to avoid recompilation. The default is usually fine.
    initial_structure : str, required
        | Path to the starting structure of the simulation.
    load_momenta : bool, default = False
        | Presumably whether the initial momenta are taken from
        | `initial_structure` instead of being sampled (TODO confirm).
    sim_dir : str, default = "."
        | Directory where simulation files will be stored.
    traj_name : str, default = "md.h5"
        | Name of the trajectory file.
    restart : bool, default = True
        | Whether the simulation should restart from the latest configuration in
        | `traj_name`.
    checkpoint_interval : int, default = 50_000
        | Number of time steps between saving full simulation state checkpoints.
        | These will be loaded with the `restart` option.
    disable_pbar : bool, default = False
        | Disables the MD progress bar.
    """

    seed: int = 1

    # https://docs.pydantic.dev/latest/usage/types/unions/#discriminated-unions-aka-tagged-unions
    # The `name` literal on each options model decides which variant is parsed.
    ensemble: Union[NVEOptions, NVTOptions, NPTOptions] = Field(
        NVTOptions(name="nvt"), discriminator="name"
    )

    duration: PositiveFloat
    n_inner: PositiveInt = 100
    sampling_rate: PositiveInt = 10
    buffer_size: PositiveInt = 100
    dr_threshold: PositiveFloat = 0.5
    extra_capacity: NonNegativeInt = 0

    initial_structure: str
    load_momenta: bool = False
    sim_dir: str = "."
    traj_name: str = "md.h5"
    restart: bool = True
    # NOTE(review): plain ``int``, so zero/negative values pass validation;
    # confirm whether PositiveInt was intended.
    checkpoint_interval: int = 50_000
    disable_pbar: bool = False

    def dump_config(self) -> None:
        """
        Write the current config to ``md_config.yaml`` inside ``sim_dir``.

        An existing file is overwritten. ``sim_dir`` is not created here, so
        it must already exist.
        """
        config_path = os.path.join(self.sim_dir, "md_config.yaml")
        # Explicit encoding keeps the output independent of the platform locale.
        with open(config_path, "w", encoding="utf-8") as conf:
            yaml.dump(self.model_dump(), conf, default_flow_style=False)