[2]:
import warnings

warnings.simplefilter("ignore")

Molecular Dynamics

In this tutorial we cover how to use trained models to drive MD simulations. For this purpose, apax offers interfaces to ASE, JaxMD and OpenMM. Keep in mind that JaxMD can be GPU/TPU accelerated and is therefore typically much faster than the ASE interface. All three are covered below.

Basic Model Training

First we need to train a model. If you already have the parameters from tutorial 01, you can point the paths to that model and skip ahead to the ASE or JaxMD section.

[3]:
!apax template train --full  # generating the config file in the cwd
[1]:
from pathlib import Path

import yaml

from apax.utils.datasets import download_etoh_ccsdt, mod_md_datasets
from apax.utils.helpers import mod_config

# Download and modify the dataset
data_path = Path("project")
experiment = "etoh_md"


train_file_path, test_file_path = download_etoh_ccsdt(data_path)
train_file_path = mod_md_datasets(train_file_path)
test_file_path = mod_md_datasets(test_file_path)


# Modify the config file (can be done manually)
config_path = Path("config_full.yaml")

config_updates = {
    "n_epochs": 100,
    "data": {
        "n_train": 990,
        "n_valid": 10,
        "valid_batch_size": 10,
        "experiment": experiment,
        "directory": "project/models",
        "data_path": str(train_file_path),
        "test_data_path": str(test_file_path),
        "energy_unit": "kcal/mol",
        "pos_unit": "Ang",
    },
}
config_dict = mod_config(config_path, config_updates)
# The "transfer_learning" key is not always present, so remove it only if it exists
config_dict.pop("transfer_learning", None)

with open(config_path, "w") as conf:
    yaml.dump(config_dict, conf, default_flow_style=False)

# Train model
# run(config_dict)
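
The update dictionary above only lists the keys that should change; everything else is expected to keep its template value. The internals of mod_config are not shown in this tutorial, so the following is only an illustrative sketch of such a nested merge. The helper `deep_update` and the template values are hypothetical and are not part of the apax API.

```python
from copy import deepcopy


def deep_update(base: dict, updates: dict) -> dict:
    """Return a copy of `base` with `updates` merged in recursively."""
    merged = deepcopy(base)
    for key, value in updates.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both sides are dictionaries: merge them key by key.
            merged[key] = deep_update(merged[key], value)
        else:
            # Otherwise the update simply overrides (or adds) the entry.
            merged[key] = value
    return merged


# Hypothetical template values, for illustration only.
template = {"n_epochs": 1000, "data": {"n_train": 1000, "batch_size": 32}}
updates = {"n_epochs": 100, "data": {"n_train": 990}}

merged = deep_update(template, updates)
print(merged)
```

Keys that are absent from `updates` (here `batch_size`) survive the merge, and the original template is left unmodified.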

The ASE calculator

If you require some ASE features during your simulation, we provide an alternative to the JaxMD interface.

Please refer to the ASE documentation to see how to use ASE calculators.

An ASE calculator for a trained model can be instantiated as follows. Subsequently, an ASE MD simulation is performed and the O-H bond length distribution is analysed.

[ ]:
from ase import units
from ase.io import read
from ase.io.trajectory import Trajectory
from ase.md.langevin import Langevin

from apax.md import ASECalculator

# read starting structure and define modelpath
atoms = read(train_file_path, index=0)
model_dir = data_path / f"models/{experiment}"


# initialize the apax ase calculator and assign it to the starting structure
calc = ASECalculator(model_dir=model_dir)
atoms.calc = calc

# perform MD simulation
dyn = Langevin(
    atoms=atoms,
    timestep=0.5 * units.fs,
    temperature_K=300,
    friction=0.001 / units.fs,
)

traj = Trajectory("example.traj", "w", atoms)
dyn.attach(traj.write, interval=1)
dyn.run(10000)
traj.close()
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
[4]:
import matplotlib.pyplot as plt
import numpy as np


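def plot_bondlength_distribution(traj, indices: list[int], bins: int = 25):
    """Plot a histogram of the distance between two atoms over a trajectory."""
    # one scalar distance per frame between atoms indices[0] and indices[1]
    distances = [atoms.get_distance(indices[0], indices[1]) for atoms in traj]

    fig, ax = plt.subplots()
    ax.hist(np.array(distances), bins=bins)
    ax.set_xlabel("bond length / Å")
    ax.set_ylabel("count")
    fig.show()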
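
If no plot is needed, the same bond length statistics can be obtained directly from an array of positions. The sketch below uses synthetic random coordinates as a stand-in for a real trajectory (the array shape, the 0.97 Å offset and the noise level are assumptions made for illustration) and ignores periodic boundary conditions.

```python
import numpy as np

rng = np.random.default_rng(seed=0)

# Synthetic stand-in for a trajectory: 1000 frames, 9 atoms, xyz coordinates.
n_frames, n_atoms = 1000, 9
positions = rng.normal(scale=0.05, size=(n_frames, n_atoms, 3))
positions[:, -1, 0] += 0.97  # shift the last atom by ~0.97 Ang along x

i, j = 2, -1
# Per-frame distance between atoms i and j.
distances = np.linalg.norm(positions[:, i] - positions[:, j], axis=1)

# Histogram counts and bin edges, analogous to what the plot shows.
counts, edges = np.histogram(distances, bins=25)
print(f"mean = {distances.mean():.3f} Ang, std = {distances.std():.3f} Ang")
```

With a real trajectory, `positions` could be built by stacking the positions of each frame, for example with `np.stack([atoms.positions for atoms in traj])`.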
[ ]:
# plot OH bondlength distribution of the MLMD simulation
traj = Trajectory("example.traj")
plot_bondlength_distribution(traj, indices=[2, -1])
/tmp/ipykernel_131697/524601467.py:12: UserWarning: FigureCanvasAgg is non-interactive, and thus cannot be shown
  fig.show()
../_images/_tutorials_02_Molecular_dynamics_9_1.png

JaxMD

While the ASE interface is convenient and flexible, it is not meant for high-performance applications. For these, apax comes with an interface to JaxMD, a high-performance molecular dynamics engine built on top of JAX. The CLI provides easy access to standard NVT and NPT simulations. More complex simulation loops are relatively easy to build yourself in JaxMD (see their Colab notebooks for examples), and trained apax models can be used as the energy_fn in such custom simulations. If you have a suggestion for adding an MD feature or thermostat to the core of apax, feel free to open an issue on GitHub.
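
To illustrate the role an energy function plays in a custom simulation loop, here is a minimal velocity Verlet sketch. It deliberately uses plain NumPy and a toy harmonic bond in arbitrary units rather than JaxMD or an apax model, so every name in it (`energy_fn`, `force_fn`, `velocity_verlet`) is a stand-in for illustration. In a JAX-based setup the forces would typically be obtained by automatic differentiation of the energy function instead of the hand-written gradient used here.

```python
import numpy as np


def energy_fn(positions, k=30.0, r0=1.0):
    """Toy harmonic bond between two particles (arbitrary units)."""
    r = np.linalg.norm(positions[1] - positions[0])
    return 0.5 * k * (r - r0) ** 2


def force_fn(positions, k=30.0, r0=1.0):
    """Analytical negative gradient of energy_fn."""
    delta = positions[1] - positions[0]
    r = np.linalg.norm(delta)
    f_on_1 = -k * (r - r0) * delta / r
    return np.array([-f_on_1, f_on_1])


def velocity_verlet(positions, velocities, masses, dt, n_steps):
    """Propagate the system in the NVE ensemble."""
    forces = force_fn(positions)
    for _ in range(n_steps):
        velocities = velocities + 0.5 * dt * forces / masses[:, None]
        positions = positions + dt * velocities
        forces = force_fn(positions)
        velocities = velocities + 0.5 * dt * forces / masses[:, None]
    return positions, velocities


def total_energy(positions, velocities, masses):
    kinetic = 0.5 * np.sum(masses[:, None] * velocities**2)
    return kinetic + energy_fn(positions)


masses = np.array([1.0, 1.0])
pos0 = np.array([[0.0, 0.0, 0.0], [1.1, 0.0, 0.0]])  # bond stretched by 0.1
vel0 = np.zeros((2, 3))

e_start = total_energy(pos0, vel0, masses)
pos1, vel1 = velocity_verlet(pos0, vel0, masses, dt=0.01, n_steps=1000)
e_end = total_energy(pos1, vel1, masses)
print(f"relative energy drift: {abs(e_end - e_start) / e_start:.2e}")
```

A small relative energy drift is a quick sanity check that the integrator and the forces are consistent with each other.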

Configuration

We can once again use the template command to give ourselves a quickstart.

[ ]:
!apax template md

Open the config and specify the starting structure and simulation parameters. If you specify the data set file itself, the first structure of the data set is going to be used as the initial structure. Your md_config.yaml should look similar to this:

ensemble:
    temperature: 300 # K

duration: 20_000 # fs
initial_structure: project/benzene_mod.xyz

A full configuration file with descriptions of the parameters can be found here.

[10]:
import yaml

from apax.utils.helpers import mod_config

md_config_path = Path("md_config.yaml")

config_updates = {
    "initial_structure": str(
        train_file_path
    ),  # if the model from example 01 is used change this
    "duration": 5000,  # fs
    "ensemble": {
        "temperature_schedule": {
            "T0": 300,
            "name": "constant",
        },
    },
}
config_dict = mod_config(md_config_path, config_updates)

with open(md_config_path, "w") as conf:
    yaml.dump(config_dict, conf, default_flow_style=False)

As with training configurations, we can use the validate command to ensure our input is valid before we submit the calculation.

[11]:
!apax validate md md_config.yaml
Success!
md_config.yaml is a valid MD config.

Running the simulation

The simulation can be started by running apax md config_full.yaml md_config.yaml, where config_full.yaml is the configuration file that was used to train the model.

[12]:
!apax md config_full.yaml md_config.yaml
INFO | 21:26:06 | reading structure
INFO:2026-02-11 21:26:06,775:jax._src.xla_bridge:752: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO | 21:26:06 | Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO | 21:26:06 | initializing model
INFO | 21:26:07 | loading checkpoint from /home/tobiasdijkhuis/PhD/apax_fork/examples/project/models/etoh_md/best
WARNING | 21:26:07 | `StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
INFO | 21:26:08 | Building Standard model
INFO | 21:26:08 | initializing simulation
INFO | 21:26:18 | running simulation for 5.0 ps
Simulation: 100%|█████████████████████████████████| 10000/10000 [00:31<00:00, 312.62it/s, T=269.7 K]
INFO | 21:26:51 | simulation finished after: 32.15 s
INFO | 21:26:51 | performance summary: 13.44 ns/day, 357.21 mu s/step/atom
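
The numbers in the performance summary can be reproduced from the log by hand. The sketch below assumes that each progress-bar iteration is one MD step (so the time step is 5000 fs / 10000 = 0.5 fs) and that the system is a single ethanol molecule with 9 atoms.

```python
n_steps = 10_000        # from the progress bar
timestep_fs = 0.5       # duration (5000 fs) divided by n_steps
wall_time_s = 32.15     # "simulation finished after"
n_atoms = 9             # ethanol, C2H6O

simulated_ns = n_steps * timestep_fs * 1e-6  # 1 fs = 1e-6 ns
ns_per_day = simulated_ns / wall_time_s * 86_400
us_per_step_atom = wall_time_s / (n_steps * n_atoms) * 1e6

print(f"{ns_per_day:.2f} ns/day")
print(f"{us_per_step_atom:.1f} mu s/step/atom")
```

This gives about 13.44 ns/day and 357.2 mu s/step/atom, matching the logged values up to the rounding of the wall time.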

During the simulation, a progress bar tracks the instantaneous temperature at each outer step. Afterwards, we analyse the O-H bond length distribution of the resulting trajectory with the plotting function defined above.

[13]:
import znh5md

atoms = znh5md.IO("md/md.h5")[:]
plot_bondlength_distribution(atoms, indices=[2, -1])
/tmp/ipykernel_177348/524601467.py:12: UserWarning: FigureCanvasAgg is non-interactive, and thus cannot be shown
  fig.show()
../_images/_tutorials_02_Molecular_dynamics_20_1.png

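To remove all the created files and clean up your working directory, run the following cell. Note that it deletes the dataset and the trained model, which the OpenMM section below still needs, so run it last if you intend to follow that section.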

[ ]:
!rm -rf project md config_full.yaml example.traj md_config.yaml

OpenMM

While JaxMD is very fast, it is less mature than other dynamics codes such as OpenMM. apax also comes with an interface to OpenMM, which can be installed with

$ pip install apax[openmm]

This requires at least OpenMM 8.5.0.

[ ]:
from sys import stdout

from ase.io import read
from openmm.app import StateDataReporter
from openmm.openmm import LangevinIntegrator
from openmm.unit import femtosecond, kelvin, picosecond

from apax.md.openmm_interface import (
    create_simulation,
    create_system,
    get_PythonForce_from_Apax,
)
from apax.utils.openmm_reporters import XYZReporter

# read starting structure and define the model path
atoms = read(train_file_path, index=0)
model_dir = data_path / f"models/{experiment}"

force = get_PythonForce_from_Apax(model_dir, atoms)

integrator = LangevinIntegrator(300 * kelvin, 1 / picosecond, 0.5 * femtosecond)
system = create_system(atoms)
system.addForce(force)

simulation = create_simulation(atoms, system, integrator)
simulation.context.setVelocitiesToTemperature(300 * kelvin, 1)

xyz_path = "openmm_trajectory.xyz"
md_steps = 10000

simulation.reporters.append(
    StateDataReporter(
        stdout,
        2000,
        time=True,
        progress=True,
        potentialEnergy=True,
        kineticEnergy=True,
        temperature=True,
        totalSteps=md_steps,
    )
)
simulation.reporters.append(
    XYZReporter(
        xyz_path,
        10,
        atoms.symbols,
        enforcePeriodicBox=False,
        includeForces=False,
        includeVelocities=False,
        flushEvery=10,
    )
)

simulation.step(md_steps)
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
#"Progress (%)","Time (ps)","Potential Energy (kJ/mole)","Kinetic Energy (kJ/mole)","Temperature (K)"
20.0%,0.9999999999999453,-406235.97975756164,34.049428281456336,341.2670772724295
40.0%,1.9999999999998352,-406235.5186811347,25.399058516867196,254.56704864096886
60.0%,3.000000000000169,-406217.4295529422,29.79128798077592,298.58903043366485
80.0%,4.000000000000503,-406233.99682036287,26.633690637892766,266.9413913748973
100.0%,4.9999999999999485,-406239.5358689742,28.93170449026708,289.97368620371947
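
As a quick consistency check on the reporter output, the kinetic energy and temperature columns are related by equipartition, E_kin = (n_dof / 2) R T. The sketch below solves this for the number of degrees of freedom using the last row of the table. That the result corresponds to 3 x 9 - 3 = 24 (nine atoms with the three centre-of-mass degrees of freedom not counted) is an interpretation of the numbers, not something stated by the reporter itself.

```python
R_KJ_PER_MOL_K = 8.314462618e-3  # molar gas constant in kJ/(mol K)

# Last row of the StateDataReporter output above.
kinetic_energy = 28.93170449026708   # kJ/mol
temperature = 289.97368620371947     # K

# Equipartition: E_kin = (n_dof / 2) * R * T, solved for n_dof.
n_dof = 2 * kinetic_energy / (R_KJ_PER_MOL_K * temperature)
print(f"implied degrees of freedom: {n_dof:.2f}")
```

The other rows of the table can be checked the same way by swapping in their kinetic energy and temperature values.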
[8]:
traj = read(xyz_path, index=":")
plot_bondlength_distribution(traj, indices=[2, -1])
/tmp/ipykernel_289686/524601467.py:12: UserWarning: FigureCanvasAgg is non-interactive, and thus cannot be shown
  fig.show()
../_images/_tutorials_02_Molecular_dynamics_26_1.png