import logging
import time
from functools import partial
from pathlib import Path
import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint as ocp
from ase import units
from ase.io import read
from jax.experimental import io_callback
from jax_md import partition, quantity, simulate, space
from tqdm import trange
from tqdm.contrib.logging import logging_redirect_tqdm
from apax.config import Config, MDConfig, parse_config
from apax.config.md_config import Integrator
from apax.md.ase_calc import make_ensemble, maybe_vmap
from apax.md.bias import (
BiasEnergies,
apply_bias_auxiliary,
apply_bias_energy,
)
from apax.md.constraints import Constraint, ConstraintBase
from apax.md.dynamics_checks import DynamicsCheckBase, DynamicsChecks
from apax.md.io import H5TrajHandler, TrajHandler, truncate_trajectory_to_checkpoint
from apax.md.md_checkpoint import load_md_state
from apax.md.sim_utils import SimulationFunctions, System
from apax.train.checkpoints import (
canonicalize_energy_model_parameters,
restore_parameters,
)
from apax.train.run import setup_logging
log = logging.getLogger(__name__)
def create_energy_fn(
    model,
    params,
    numbers,
    n_models,
    shallow=False,
):
    """Bind model parameters and atomic numbers into a scalar energy function.

    Parameters
    ----------
    model :
        Callable invoked as
        ``model(params, R, Z, neighbor, box, offsets, perturbation)``.
        It must return a 2-tuple whose first entry is the energy (or, for
        ensembles, the per-member energies); the second entry is discarded.
    params :
        Model parameters, bound as the first positional argument.
    numbers :
        Atomic numbers, bound as the ``Z`` keyword.
    n_models : int
        Number of ensemble members. Values ``<= 1`` select the plain model.
    shallow : bool, default = False
        Only consulted when ``n_models > 1``. If true, the model is called
        once and its energies are averaged; otherwise the model is vmapped
        over the leading axis of ``params`` before averaging.

    Returns
    -------
    energy_fn :
        ``functools.partial`` with ``params``, ``Z`` and zero ``offsets``
        already bound; the remaining arguments are ``R``, ``neighbor``,
        ``box`` and optionally ``perturbation``.
    """

    def _vmapped_ensemble_mean(
        params, R, Z, neighbor, box, offsets, perturbation=None, **kwargs
    ):
        # Only the parameters carry a member axis; all other inputs are shared.
        batched_model = jax.vmap(model, (0, None, None, None, None, None, None), 0)
        member_energies, _ = batched_model(
            params, R, Z, neighbor, box, offsets, perturbation
        )
        return jnp.mean(member_energies)

    def _shallow_ensemble_mean(
        params, R, Z, neighbor, box, offsets, perturbation=None, **kwargs
    ):
        member_energies, _ = model(params, R, Z, neighbor, box, offsets, perturbation)
        return jnp.mean(member_energies)

    def _single_model_energy(
        params, R, Z, neighbor, box, offsets, perturbation=None, **kwargs
    ):
        total_energy, _ = model(params, R, Z, neighbor, box, offsets, perturbation)
        return total_energy

    if n_models > 1:
        selected = _shallow_ensemble_mean if shallow else _vmapped_ensemble_mean
    else:
        selected = _single_model_energy

    return partial(
        selected,
        params,
        Z=numbers,
        offsets=jnp.array([0.0, 0.0, 0.0]),
    )
def heights_of_box_sides(box):
    """Compute, for every pair of box vectors, the two in-plane heights.

    For each unordered pair ``(i, j)`` with ``i < j`` the area of the
    parallelogram spanned by ``box[i]`` and ``box[j]`` is divided first by
    the length of ``box[i]`` and then by the length of ``box[j]``.

    Parameters
    ----------
    box :
        Sequence of box vectors.

    Returns
    -------
    np.ndarray
        Flat array with two heights per vector pair, in pair order.
    """
    n_vectors = len(box)
    # Each vector length is needed for several pairs, so compute it once.
    lengths = [np.linalg.norm(box[k]) for k in range(n_vectors)]

    heights = []
    for first in range(n_vectors):
        for second in range(first + 1, n_vectors):
            spanned_area = np.linalg.norm(np.cross(box[first], box[second]))
            heights.append(spanned_area / lengths[first])
            heights.append(spanned_area / lengths[second])
    return np.array(heights)
def nbr_update_options_default(state):
    """Return the extra neighbor-list update kwargs: none, ``state`` is ignored."""
    return dict()
def nbr_update_options_npt(state):
    """Return neighbor-list update kwargs carrying the box from ``simulate.npt_box``."""
    return {"box": simulate.npt_box(state)}
def get_ensemble(ensemble: Integrator, sim_fns, constaint_idxs=None):
    """Create the integrator init/apply functions for the requested ensemble.

    Parameters
    ----------
    ensemble : Integrator
        Ensemble configuration. ``name`` selects the integrator
        (``"nve"``, ``"nvt"`` or ``"npt"``); ``dt`` is multiplied by
        ``ase.units.fs``; ``temperature_schedule.get_schedule()`` provides
        the ``kT`` schedule. ``thermostat_chain`` (NVT/NPT),
        ``barostat_chain`` and ``pressure`` (NPT) are read where needed.
    sim_fns :
        Object exposing ``energy_fn`` and ``shift_fn``.
    constaint_idxs : optional
        Indices of constrained atoms, forwarded to the NVT integrator as
        ``constrained_idxs``. Not supported for NPT.

    Returns
    -------
    init_fn, apply_fn :
        Integrator initialisation and step functions.
    kT :
        Temperature schedule; called with a step index.
    nbr_options :
        Function mapping a state to extra kwargs for the neighbor-list update.

    Raises
    ------
    NotImplementedError
        For an unknown ensemble name, or when constraints are combined
        with NPT.
    """
    energy, shift = sim_fns.energy_fn, sim_fns.shift_fn

    dt = ensemble.dt * units.fs
    nbr_options = nbr_update_options_default

    kT = ensemble.temperature_schedule.get_schedule()
    if ensemble.name == "nve":
        # NOTE(review): argument order (kT(0) before dt) is kept as found;
        # confirm it matches the signature of the installed `simulate.nve`.
        init_fn, apply_fn = simulate.nve(energy, shift, kT(0), dt)
    elif ensemble.name == "nvt":
        # Copy so that rescaling `tau` does not modify the config object.
        thermostat_chain = dict(ensemble.thermostat_chain)
        thermostat_chain["tau"] *= dt

        # The thermostat settings were previously computed but never handed
        # to the integrator, so it silently ran with its built-in defaults.
        init_fn, apply_fn = simulate.nvt_nose_hoover(
            energy,
            shift,
            dt,
            kT(0),
            constrained_idxs=constaint_idxs,
            **thermostat_chain,
        )

    elif ensemble.name == "npt":
        if constaint_idxs:
            raise NotImplementedError(
                "Constraining atoms in NPT simulations is not implemented."
            )
        pressure = ensemble.pressure * units.bar
        thermostat_chain = dict(ensemble.thermostat_chain)
        barostat_chain = dict(ensemble.barostat_chain)

        thermostat_chain["tau"] *= dt
        barostat_chain["tau"] *= dt

        init_fn, apply_fn = simulate.npt_nose_hoover(
            energy,
            shift,
            dt,
            pressure,
            kT(0),
            thermostat_kwargs=thermostat_chain,
            barostat_kwargs=barostat_chain,
        )
        # The box changes during NPT, so the neighbor list needs the
        # current box on every update.
        nbr_options = nbr_update_options_npt
    else:
        raise NotImplementedError(
            "Only the NVE and Nose Hoover NVT/NPT thermostats are currently interfaced."
        )

    return init_fn, apply_fn, kT, nbr_options
def handle_checkpoints(state, step, system, load_momenta, ckpt_dir, should_load_ckpt):
    """Pick the starting state: checkpoint, initial momenta, or unchanged.

    A checkpoint takes precedence: when ``should_load_ckpt`` is true, state
    and step are restored through ``load_md_state`` from ``ckpt_dir``.
    Otherwise, if ``load_momenta`` is true, the momenta of ``system`` are
    written into the state. If neither applies, the inputs are returned
    as they came in.

    Returns
    -------
    state, step :
        The (possibly replaced) simulation state and step counter.
    """
    if should_load_ckpt:
        state, step = load_md_state(state, ckpt_dir.resolve())
    elif load_momenta:
        log.info("loading momenta from starting configuration")
        state = state.set(momentum=system.momenta)
    return state, step
def create_evaluation_functions(traj_handler, aux_fn, Z, neighbor, dynamics_checks):
    """Build the two branches used when deciding whether to sample a step.

    Parameters
    ----------
    traj_handler :
        Object whose ``step`` method receives
        ``(state, predictions, nbr_kwargs)`` through ``io_callback``.
    aux_fn :
        Called as ``aux_fn(positions, Z, neighbor, box, offsets)`` to
        obtain the predictions.
    Z :
        Atomic numbers passed to ``aux_fn``.
    neighbor :
        Neighbor list; only ``neighbor.idx.shape[1]`` is used here, to size
        the zero offsets.
    dynamics_checks :
        Iterable of objects with a ``check(predictions, positions, box)``
        method.

    Returns
    -------
    on_eval, no_eval :
        Functions with the signature ``(state, neighbor, box, nbr_kwargs)``.
        ``on_eval`` evaluates predictions, runs all checks and records the
        frame; ``no_eval`` does nothing. Both return whether all checks
        passed.
    """
    zero_offsets = jnp.zeros((neighbor.idx.shape[1], 3))

    def on_eval(state, neighbor, box, nbr_kwargs):
        current_positions = state.position
        predictions = aux_fn(current_positions, Z, neighbor, box, zero_offsets)

        passed_so_far = True
        for dynamics_check in dynamics_checks:
            passed_so_far = passed_so_far & dynamics_check.check(
                predictions, current_positions, box
            )

        # Hand the frame to the host-side trajectory handler.
        io_callback(traj_handler.step, None, (state, predictions, nbr_kwargs))
        return passed_so_far

    def no_eval(state, neighbor, box, nbr_kwargs):
        # Nothing is evaluated, hence nothing can fail.
        return True

    return on_eval, no_eval
def check_unique_idxs(constraind_idxs):
    """Flatten groups of indices into a duplicate-free list of ``int``.

    Parameters
    ----------
    constraind_idxs :
        Iterable of iterables of index-like values.

    Returns
    -------
    list[int]
        Every index exactly once, in order of first appearance.
    """
    flattened = (int(val) for idxs in constraind_idxs for val in idxs)
    # dict keys keep insertion order, which gives an ordered de-duplication.
    return list(dict.fromkeys(flattened))
def create_constraint_function(constraints: list[ConstraintBase], system):
    """Combine all constraints into one state-transforming function.

    Parameters
    ----------
    constraints : list[ConstraintBase]
        Each entry's ``create(system)`` must return a pair
        ``(constrain_fn, idx)``.
    system :
        Passed through to every ``constraint.create`` call.

    Returns
    -------
    apply_constraints :
        Function applying every constraint function to a state, in the
        order the constraints were given.
    constraind_idxs :
        De-duplicated flat list of constrained indices, or an empty list
        when there are no constraints.
    """
    created = [constraint.create(system) for constraint in constraints]
    constrain_fns = [constrain_fn for constrain_fn, _ in created]
    constraind_idxs = [idx for _, idx in created]

    if constraind_idxs:
        constraind_idxs = check_unique_idxs(constraind_idxs)

    def apply_constraints(state):
        for constrain_fn in constrain_fns:
            state = constrain_fn(state)
        return state

    return apply_constraints, constraind_idxs
def check_for_nans(state, step):
    """Abort the simulation if positions or velocities contain NaN.

    Parameters
    ----------
    state :
        Object with ``position`` and ``velocity`` array attributes.
    step : int
        Zero-based step counter; the error message reports ``step + 1``.

    Raises
    ------
    ValueError
        If any entry of ``state.position`` or ``state.velocity`` is NaN.
    """
    # The generator keeps the lazy evaluation: velocities are only looked at
    # when the positions are clean.
    found_nan = any(
        np.any(np.isnan(getattr(state, field))) for field in ("position", "velocity")
    )
    if found_nan:
        raise ValueError(f"NaN encountered, simulation aborted after {step + 1} steps.")
def handle_overflow(neighbor_fn, state, traj_handler, step):
    """React to a neighbor-list overflow by reallocating the list.

    Logs a warning, resets the trajectory handler's buffer and allocates a
    new neighbor list for the positions of ``state``.

    Returns
    -------
    neighbor :
        The newly allocated neighbor list.
    """
    with logging_redirect_tqdm():
        log.warning("step %d: neighbor list overflowed, reallocating.", step)
    traj_handler.reset_buffer()
    reallocated = neighbor_fn.allocate(state.position)
    return reallocated
def maybe_save_checkpoint(mngr, state, step, checkpoint_interval, sim_time_per_step):
    """Save a checkpoint when ``step`` is a multiple of ``checkpoint_interval``.

    Parameters
    ----------
    mngr :
        Checkpoint manager; its ``save`` method is called with the step and
        an ``ocp.args.StandardSave`` holding ``state`` and ``step``.
    state :
        Simulation state to store.
    step : int
        Current step counter.
    checkpoint_interval : int
        Interval, in units of ``step``, between checkpoints.
    sim_time_per_step : float
        Simulated time per unit of ``step``; only used for the log message,
        which reports it as ps.
    """
    if step % checkpoint_interval != 0:
        return

    simulated_time = step * sim_time_per_step
    with logging_redirect_tqdm():
        log.info("saving checkpoint at %.1f ps - step: %d", simulated_time, step)
    payload = {"state": state, "step": step}
    mngr.save(step, args=ocp.args.StandardSave(payload))
def maybe_update_pbar(
    sim_pbar, step, pbar_update_freq, pbar_increment, current_temperature
):
    """Refresh the progress bar every ``pbar_update_freq`` steps.

    On a refresh the temperature is shown as postfix (one decimal, in K)
    and the bar is advanced by ``pbar_increment``; otherwise nothing
    happens.
    """
    if step % pbar_update_freq != 0:
        return

    temperature_label = f"{current_temperature:.1f} K"
    sim_pbar.set_postfix(T=temperature_label)
    sim_pbar.update(pbar_increment)
[docs]
def run_sim(
    system: System,
    sim_fns: SimulationFunctions,
    ensemble,
    sim_dir: Path,
    n_steps: int,
    n_inner: int,
    extra_capacity: int,
    rng_key: int,
    traj_handler: TrajHandler,
    sampling_rate: int = 10,
    load_momenta: bool = False,
    restart: bool = True,
    checkpoint_interval: int = 50_000,
    dynamics_checks: "list[DynamicsCheckBase] | None" = None,
    constraints: "list[ConstraintBase] | None" = None,
    disable_pbar: bool = False,
):
    """
    Run a molecular dynamics simulation in the given ensemble.

    The simulation advances in outer steps, each consisting of ``n_inner``
    JIT compiled integration steps. After every outer step the result is
    inspected on the host (dynamics checks, neighbor-list overflow, NaNs)
    before it is accepted.

    Parameters
    ----------
    system : System
        Initial structure; positions, box, masses, atomic numbers and
        (optionally) momenta are read from it.
    sim_fns : SimulationFunctions
        Energy, auxiliary, shift and neighbor-list functions.
    ensemble :
        Thermodynamic ensemble configuration, see ``get_ensemble``.
    sim_dir : Path
        Directory where the simulation checkpoints are saved
        (in the ``ckpts`` subdirectory).
    n_steps : int
        Total time steps.
    n_inner : int
        Number of time steps in the JIT compiled inner loop.
    extra_capacity : int
        Extra capacity for the neighborlist.
    rng_key :
        RNG key used to initialize the simulation.
    traj_handler : TrajHandler
        Receives the sampled frames and writes the trajectory.
    sampling_rate : int, default = 10
        A frame is evaluated and recorded whenever the time step index is
        a multiple of this value.
    load_momenta : bool, default = False
        Use the momenta of ``system`` as initial momenta. Ignored when a
        checkpoint is loaded.
    restart : bool, default = True
        If true and a checkpoint exists in ``sim_dir / "ckpts"``, the state
        and step are restored from it and the trajectory is truncated to
        the checkpointed length.
    checkpoint_interval : int, default = 50_000
        Interval between full simulation state checkpoints. It is compared
        against the outer step counter.
    dynamics_checks : list[DynamicsCheckBase], optional
        Checks run on every sampled frame; a failed check stops the run.
    constraints : list[ConstraintBase], optional
        Constraints applied after every integration step.
    disable_pbar : bool, default = False
        Disable the progress bar.

    Raises
    ------
    ValueError
        If the positions or velocities of a new state contain NaN.
    """
    # `None` defaults avoid sharing one mutable list between calls.
    dynamics_checks = [] if dynamics_checks is None else dynamics_checks
    constraints = [] if constraints is None else constraints

    neighbor_fn = sim_fns.neighbor_fn
    ckpt_dir = sim_dir / "ckpts"
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    apply_constraints, constrained_idxs = create_constraint_function(
        constraints,
        system,
    )

    log.info("initializing simulation")
    init_fn, apply_fn, kT, nbr_options = get_ensemble(ensemble, sim_fns, constrained_idxs)

    neighbor = sim_fns.neighbor_fn.allocate(
        system.positions, extra_capacity=extra_capacity
    )

    state = init_fn(
        rng_key,
        system.positions,
        box=system.box,
        mass=system.masses,
        neighbor=neighbor,
    )

    step = 0
    options = ocp.CheckpointManagerOptions(max_to_keep=1, save_interval_steps=1)
    mngr = ocp.CheckpointManager(ckpt_dir.resolve(), options=options)
    ckpts_exist = mngr.latest_step() is not None
    should_load_ckpt = restart and ckpts_exist
    state, step = handle_checkpoints(
        state, step, system, load_momenta, ckpt_dir, should_load_ckpt
    )
    if should_load_ckpt:
        # Drop frames written after the checkpoint was taken.
        length = step * n_inner
        truncate_trajectory_to_checkpoint(traj_handler.traj_path, length)

    initial_step = step  # used for measuring time correctly

    n_outer = int(np.ceil(n_steps / n_inner))
    pbar_update_freq = int(np.ceil(500 / n_inner))
    pbar_increment = n_inner * pbar_update_freq

    on_eval, no_eval = create_evaluation_functions(
        traj_handler,
        sim_fns.auxiliary_fn,
        system.atomic_numbers,
        neighbor,
        dynamics_checks,
    )

    @jax.jit
    def sim(state, outer_step, neighbor):  # TODO make more modular
        def body_fn(i, state):
            state, outer_step, neighbor, all_checks_passed = state
            step = i + outer_step * n_inner

            apply_fn_kwargs = {}
            if isinstance(state, simulate.NPTNoseHooverState):
                # In NPT the box is part of the state and not passed explicitly.
                box = state.box
            else:
                box = system.box
                apply_fn_kwargs = {"box": box}

            apply_fn_kwargs["kT"] = kT(step)  # Get current Temperature

            state = apply_fn(state, neighbor=neighbor, **apply_fn_kwargs)
            state = apply_constraints(state)

            nbr_kwargs = nbr_options(state)
            neighbor = neighbor.update(state.position, **nbr_kwargs)

            condition = step % sampling_rate == 0
            checks_passed = jax.lax.cond(
                condition, on_eval, no_eval, state, neighbor, box, nbr_kwargs
            )
            all_checks_passed = all_checks_passed & checks_passed
            return state, outer_step, neighbor, all_checks_passed

        all_checks_passed = True
        state, outer_step, neighbor, all_checks_passed = jax.lax.fori_loop(
            0, n_inner, body_fn, (state, outer_step, neighbor, all_checks_passed)
        )
        current_temperature = (
            quantity.temperature(velocity=state.velocity, mass=state.mass) / units.kB
        )
        return state, neighbor, current_temperature, all_checks_passed

    start = time.time()
    total_sim_time = n_steps * ensemble.dt / 1000
    log.info("running simulation for %.1f ps", total_sim_time)
    initial_time = step * n_inner
    sim_pbar = trange(
        initial_time,
        n_steps,
        initial=initial_time,
        total=n_steps,
        desc="Simulation",
        ncols=100,
        disable=disable_pbar,
        leave=True,
    )
    sim_time_per_step = n_inner * ensemble.dt / 1000
    with mngr:
        while step < n_outer:
            new_state, neighbor, current_temperature, all_checks_passed = sim(
                state, step, neighbor
            )

            if not all_checks_passed:
                with logging_redirect_tqdm():
                    log.critical(
                        "One or more dynamics checks failed at step: %d", step + 1
                    )
                break

            if neighbor.did_buffer_overflow:
                # Discard this outer step and redo it with a larger list.
                neighbor = handle_overflow(neighbor_fn, state, traj_handler, step)
                continue

            # Validate the freshly computed state *before* accepting it, so a
            # NaN state is never checkpointed or carried into the next step.
            check_for_nans(new_state, step)

            state = new_state
            step += 1
            maybe_save_checkpoint(
                mngr, state, step, checkpoint_interval, sim_time_per_step
            )
            maybe_update_pbar(
                sim_pbar, step, pbar_update_freq, pbar_increment, current_temperature
            )

        # In case of mismatch update freq and n_steps, we can set it to 100% manually
        sim_pbar.update(n_steps - sim_pbar.n)
        sim_pbar.close()

        # Final checkpoint, written while the manager is still open.
        ckpt = {"state": state, "step": step}
        mngr.save(step, args=ocp.args.StandardSave(ckpt))

    traj_handler.write()
    traj_handler.close()

    end = time.time()
    elapsed_wall_time = end - start
    log.info("simulation finished after: %.2f s", elapsed_wall_time)

    # Base the statistics on the steps integrated in *this* call; on a
    # restart or an early stop this differs from `n_steps`.
    steps_run = (step - initial_step) * n_inner
    if steps_run > 0 and elapsed_wall_time > 0:
        elapsed_sim_time = steps_run * ensemble.dt / 1000
        ps_per_s = elapsed_sim_time / elapsed_wall_time
        nanosec_per_day = ps_per_s / 1e3 * 60 * 60 * 24
        sec_per_step = elapsed_wall_time / steps_run
        n_atoms = system.positions.shape[0]
        musec_per_step_per_atom = sec_per_step * 1e6 / n_atoms
        log.info(
            "performance summary: %.2f ns/day, %.2f mu s/step/atom",
            nanosec_per_day,
            musec_per_step_per_atom,
        )
[docs]
def md_setup(model_config: Config, md_config: MDConfig):
    """
    Sets up the energy, auxiliary, shift and neighborlist functions for an
    MD simulation and loads the initial structure.

    A box whose entries are all below ``1e-6`` is treated as non-periodic
    (free space, cell list disabled); otherwise a general periodic space
    with fractional coordinates is used. If the cutoff exceeds half of a
    cell length or half of a box height, an error is logged but no
    exception is raised.

    Parameters
    ----------
    model_config : Config
        Configuration of the model used as an interatomic potential.
    md_config : MDConfig
        configuration of the MD simulation.

    Returns
    -------
    system:
        ``System`` built from the structure read from
        ``md_config.initial_structure``.
    sim_fns:
        ``SimulationFunctions`` bundling the (optionally biased) energy
        function, the auxiliary prediction function, the shift function
        and the neighborlist function.
    """
    log.info("reading structure")
    atoms = read(md_config.initial_structure)
    system = System.from_atoms(atoms)
    r_max = model_config.model.basis.r_max

    log.info("initializing model")
    box_is_empty = np.all(system.box < 1e-6)
    if box_is_empty:
        frac_coords = False
        displacement_fn, shift_fn = space.free()
    else:
        frac_coords = True
        heights = heights_of_box_sides(system.box)
        half_cell_lengths = atoms.cell.lengths() / 2
        half_heights = heights / 2

        if np.any(half_cell_lengths < r_max):
            log.error(
                f"Cutoff radius is larger than half the box in at least one cell vector direction: "
                f"{r_max} > {np.min(atoms.cell.lengths()) / 2}. Cannot calculate correct neighbors."
            )
        if np.any(half_heights < r_max):
            log.error(
                f"Cutoff radius is larger than half the box in at least one cell vector direction: "
                f"{r_max} > {np.min(heights) / 2}. Cannot calculate correct neighbors."
            )
        displacement_fn, shift_fn = space.periodic_general(
            system.box,
            fractional_coordinates=frac_coords,
            wrapped=md_config.wrapped,
        )

    model_settings = model_config.model
    Builder = model_settings.get_builder()
    builder = Builder(model_settings.model_dump())
    energy_model = builder.build_energy_model(
        apply_mask=True,
        init_box=np.array(system.box),
        inference_disp_fn=displacement_fn,
    )

    # Without a box there is nothing to build a cell list from.
    disable_cell_list = md_config.disable_cell_list or box_is_empty
    neighbor_fn = partition.neighbor_list(
        displacement_fn,
        system.box,
        r_max,
        md_config.dr_threshold,
        fractional_coordinates=frac_coords,
        format=partition.Sparse,
        disable_cell_list=disable_cell_list,
    )

    _, gradient_model_params = restore_parameters(model_config.data.model_version_path)
    params = canonicalize_energy_model_parameters(gradient_model_params)

    # Default to a single model; switch only for ensembles with > 1 member.
    n_models = 1
    shallow = False
    has_ensemble_field = "ensemble" in model_settings.model_dump()
    if (
        has_ensemble_field
        and model_settings.ensemble is not None
        and model_settings.ensemble.n_members > 1
    ):
        n_models = model_settings.ensemble.n_members
        shallow = model_settings.ensemble.kind == "shallow"

    energy_fn = create_energy_fn(
        energy_model.apply,
        params,
        system.atomic_numbers,
        n_models,
        shallow,
    )

    if md_config.biases:
        biases = [BiasEnergies(b.model_dump()) for b in md_config.biases]
    else:
        biases = []

    auxiliary_fn = builder.build_energy_derivative_model(
        apply_mask=True, init_box=np.array(system.box), inference_disp_fn=displacement_fn
    ).apply

    if n_models > 1 and not shallow:
        # Full ensembles evaluate every member and are then combined.
        auxiliary_fn = maybe_vmap(auxiliary_fn, gradient_model_params)
        auxiliary_fn = make_ensemble(auxiliary_fn)
    else:
        auxiliary_fn = partial(
            auxiliary_fn,
            gradient_model_params,
        )

    # Each bias wraps both the energy and the auxiliary function.
    for bias in biases:
        energy_fn = apply_bias_energy(bias, energy_fn)
        auxiliary_fn = apply_bias_auxiliary(bias, auxiliary_fn)

    sim_fns = SimulationFunctions(energy_fn, auxiliary_fn, shift_fn, neighbor_fn)
    return system, sim_fns
[docs]
def run_md(model_config: Config, md_config: MDConfig, log_level="error"):
    """
    Utility function to start molecular dynamics simulations from
    a previously trained model.

    Parses both configurations, creates the simulation directory, sets up
    logging to ``md.log`` inside it, builds the simulation functions and
    the trajectory handler, and finally calls ``run_sim``.

    Parameters
    ----------
    model_config : Config
        Configuration of the model used as an interatomic potential.
    md_config : MDConfig
        configuration of the MD simulation.
    log_level : str, default = "error"
        Log level handed to ``setup_logging``.
    """
    model_config = parse_config(model_config)
    md_config = parse_config(md_config, mode="md")

    sim_dir = Path(md_config.sim_dir)
    sim_dir.mkdir(parents=True, exist_ok=True)
    setup_logging(sim_dir / "md.log", log_level)
    traj_path = sim_dir / md_config.traj_name

    system, sim_fns = md_setup(model_config, md_config)

    # Missing or empty config sections simply yield empty lists.
    dynamics_checks = [
        DynamicsChecks(check.model_dump())
        for check in (md_config.dynamics_checks or [])
    ]
    constraints = [
        Constraint(c.model_dump()) for c in (md_config.constraints or [])
    ]

    timestep = md_config.ensemble.dt
    n_steps = int(np.ceil(md_config.duration / timestep))

    traj_handler = H5TrajHandler(
        system,
        md_config.buffer_size,
        traj_path,
        timestep,
        properties=md_config.properties,
        h5md_options=md_config.h5md_options.model_dump(),
    )
    # TODO implement correct chunking

    run_sim(
        system,
        sim_fns,
        md_config.ensemble,
        sim_dir=sim_dir,
        n_steps=n_steps,
        n_inner=md_config.n_inner,
        extra_capacity=md_config.extra_capacity,
        rng_key=jax.random.PRNGKey(md_config.seed),
        traj_handler=traj_handler,
        sampling_rate=md_config.sampling_rate,
        load_momenta=md_config.load_momenta,
        restart=md_config.restart,
        checkpoint_interval=md_config.checkpoint_interval,
        dynamics_checks=dynamics_checks,
        constraints=constraints,
        disable_pbar=md_config.disable_pbar,
    )