[1]:
import warnings

warnings.simplefilter("ignore")

Model Training

In this tutorial we are going to train a model from scratch on a molecular dataset from the MD17 collection. Start by creating a project folder and downloading the dataset.

Acquiring a dataset

You can obtain the ethanol dataset with CCSD(T) labels either by running the following command or manually from this link. Apax uses ASE to read datasets, so make sure to convert your own data into an ASE-readable format (extxyz, traj, etc.). Note that the downloaded dataset has to be modified, as done by the apax.utils.datasets.mod_md_datasets function, before it can be read.

[2]:
from pathlib import Path

from apax.utils.datasets import download_etoh_ccsdt, mod_md_datasets

data_path = Path("project")

train_file_path, test_file_path = download_etoh_ccsdt(data_path)
train_file_path = mod_md_datasets(train_file_path)
test_file_path = mod_md_datasets(test_file_path)
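
For your own data, one ASE-readable option is the extended XYZ format. The following is a minimal sketch that formats a single frame by hand using only plain Python. The helper name format_extxyz_frame and all numerical values are illustrative placeholders, not real reference data. If your structures already exist as ASE Atoms objects, ase.io.write is usually the more convenient route.

```python
def format_extxyz_frame(symbols, positions, forces, energy):
    """Format one structure as an extended XYZ frame (hypothetical helper)."""
    # The comment line declares the per-atom columns and the frame energy.
    header = (
        "Properties=species:S:1:pos:R:3:forces:R:3 "
        f'energy={energy} pbc="F F F"'
    )
    lines = [str(len(symbols)), header]
    for symbol, pos, force in zip(symbols, positions, forces):
        values = " ".join(f"{v:.6f}" for v in (*pos, *force))
        lines.append(f"{symbol} {values}")
    return "\n".join(lines) + "\n"


# Placeholder three-atom geometry; the numbers are made up for illustration.
symbols = ["O", "H", "H"]
positions = [(0.0, 0.0, 0.0), (0.96, 0.0, 0.0), (-0.24, 0.93, 0.0)]
forces = [(0.0, 0.0, 0.0)] * 3

frame = format_extxyz_frame(symbols, positions, forces, energy=-1.0)
print(frame)
```

Concatenating several such frames in one file gives a multi-structure dataset. Energies and positions should be in the units you later declare in the configuration file.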

Configuration files

Next, we require a configuration file that specifies the model and training parameters. To get users up and running quickly, our command line interface provides an easy way to generate input templates. The provided templates come in two levels of verbosity: minimal and full. In the following we are going to use the full input file. To see a complete list and explanation of all parameters, consult the documentation page LINK. For more information on the CLI, simply run apax -h.

[3]:
!apax -h
                                                                                
 Usage: apax [OPTIONS] COMMAND [ARGS]...                                        
                                                                                
╭─ Options ────────────────────────────────────────────────────────────────────╮
│ --version             -V                                                     │
│ --install-completion            Install completion for the current shell.    │
│ --show-completion               Show completion for the current shell, to    │
│                                 copy it or customize the installation.       │
│ --help                -h        Show this message and exit.                  │
╰──────────────────────────────────────────────────────────────────────────────╯
╭─ Commands ───────────────────────────────────────────────────────────────────╮
│ docs       Opens the documentation website in your browser.                  │
│ eval       Starts performing the evaluation of the test dataset with         │
│            parameters provided by a configuration file.                      │
│ md         Starts performing a molecular dynamics simulation (currently only │
│            NHC thermostat) with parameters provided by a configuration file. │
│ schema     Generating JSON schemata for autocompletion of train/md inputs in │
│            VSCode.                                                           │
│ template   Create configuration file templates.                              │
│ train      Starts the training of a model with parameters provided by a      │
│            configuration file.                                               │
│ validate   Validate training or MD config files.                             │
│ visualize  Visualize a model based on a configuration file. A CO molecule is │
│            taken as sample input (influences number of atoms, number of      │
│            species is set to 10).                                            │
╰──────────────────────────────────────────────────────────────────────────────╯

The following command creates a full configuration file, config_full.yaml, in the working directory. A full configuration file with descriptions of the parameters can be found here.

[4]:
!apax template train --full

Open the resulting config_full.yaml file in an editor of your choice and make sure to fill in the data path field with the path of the dataset you just downloaded. For the purposes of this tutorial we will train on 990 data points and validate the model on 10 more during training. Further, the units of the labels have to be specified. Apax performs the random split itself, but it is also possible to provide pre-split training and validation datasets.
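
To make the idea of a random split concrete, here is a minimal sketch in plain Python. It is not apax's implementation. The function name random_split, the seed, and the assumed dataset size of 1000 structures are illustrative; the sizes 990 and 10 match the example configuration used in this tutorial.

```python
import random


def random_split(n_total, n_train, n_valid, seed=0):
    """Return disjoint, randomly chosen training and validation indices."""
    if n_train + n_valid > n_total:
        raise ValueError("n_train + n_valid exceeds the dataset size")
    indices = list(range(n_total))
    # A dedicated generator keeps the split reproducible for a given seed.
    random.Random(seed).shuffle(indices)
    return indices[:n_train], indices[n_train : n_train + n_valid]


# Assumed dataset size of 1000 structures (illustrative).
train_idx, valid_idx = random_split(n_total=1000, n_train=990, n_valid=10)
print(len(train_idx), len(valid_idx))
```

Because both index lists are cut from one shuffled sequence, no structure can end up in both the training and the validation set.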

To check whether a configuration file is valid, we provide the validate command. This is especially convenient when submitting training runs on a compute cluster.

[5]:
!apax validate train config_full.yaml
1 validation errors for config
n_epochs
  Input should be a valid integer, unable to parse string as an integer
  input_type: str
  input: <NUMBER OF EPOCHS>

Configuration Invalid!
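
The error above is caused by a template placeholder that has not been filled in yet. As a quick complementary check, the sketch below scans a config text for leftover placeholders. It assumes that placeholders are upper-case words in angle brackets, as in the error message above. The helper find_placeholders and the example snippet are hypothetical and not part of apax.

```python
import re

# Assumption: placeholders look like <NUMBER OF EPOCHS>.
PLACEHOLDER = re.compile(r"<[A-Z][A-Z _]*>")


def find_placeholders(text):
    """Return (line number, placeholder) pairs for unfilled template fields."""
    found = []
    for lineno, line in enumerate(text.splitlines(), start=1):
        for match in PLACEHOLDER.finditer(line):
            found.append((lineno, match.group()))
    return found


# Illustrative snippet; a real template contains many more fields.
template = "n_epochs: <NUMBER OF EPOCHS>\ndata:\n  batch_size: 32\n"
print(find_placeholders(template))  # [(1, '<NUMBER OF EPOCHS>')]
```

This only finds unfilled fields. Type and value checks still require the validate command.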

Configuration files are validated using Pydantic, and the errors reported by the validate command give precise instructions on how to fix the input file. The filled-in configuration file should look similar to this one.

data:
  batch_size: 4
  data_path: project/ethanol_ccsd_t-train_mod.xyz
  directory: project/models
  energy_unit: kcal/mol
  experiment: ethanol_ccsd_t_cli
  n_train: 990
  n_valid: 10
  pos_unit: Ang
  valid_batch_size: 100
loss:
- name: energy
- name: forces
  weight: 4.0
metrics:
- name: energy
  reductions:
  - mae
- name: forces
  reductions:
  - mae
  - mse
model:
  descriptor_dtype: fp64
n_epochs: 100

The configuration file can also be modified with the utility function mod_config provided by Apax.

[6]:
import yaml

from apax.utils.helpers import mod_config

config_path = Path("config_full.yaml")

config_updates = {
    "n_epochs": 100,
    "data": {
        "n_train": 990,
        "n_valid": 10,
        "valid_batch_size": 10,
        "experiment": "ethanol_ccsd_t_cli",
        "directory": "project/models",
        "data_path": str(train_file_path),
        "test_data_path": str(test_file_path),
        "energy_unit": "kcal/mol",
        "pos_unit": "Ang",
    },
}

config_dict = mod_config(config_path, config_updates)

with open("config_full.yaml", "w") as conf:
    yaml.dump(config_dict, conf, default_flow_style=False)
[7]:
!apax validate train config_full.yaml
Success!
config_full.yaml is a valid training config.
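
The nested config_updates dictionary in the cell above lists only the fields that change. A common way to apply such updates is a recursive merge, sketched below in plain Python. Whether mod_config behaves exactly like this is an assumption. The function deep_update and the example values are illustrative and not apax code.

```python
import copy


def deep_update(base, updates):
    """Return a copy of `base` with `updates` merged in recursively."""
    merged = copy.deepcopy(base)
    for key, value in updates.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Merge nested sections instead of replacing them wholesale.
            merged[key] = deep_update(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Illustrative values only.
base = {"n_epochs": "<NUMBER OF EPOCHS>", "data": {"batch_size": 32, "n_train": 1000}}
updates = {"n_epochs": 100, "data": {"n_train": 990, "n_valid": 10}}

merged = deep_update(base, updates)
print(merged)
# {'n_epochs': 100, 'data': {'batch_size': 32, 'n_train': 990, 'n_valid': 10}}
```

With a recursive merge, keys that are not mentioned in the updates, such as batch_size here, keep their original values.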

Training

Model training can be started by running

[8]:
!apax train config_full.yaml
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1732268187.845256  520474 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1732268187.848463  520474 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
INFO | 09:36:31 | Running on [CudaDevice(id=0)]
INFO | 09:36:31 | Initializing Callbacks
INFO | 09:36:32 | Initializing Loss Function
INFO | 09:36:32 | Initializing Metrics
INFO | 09:36:32 | Running Input Pipeline
INFO | 09:36:32 | Reading data file project/ethanol_ccsd_t-train_mod.xyz
INFO | 09:36:32 | Found n_train: 990, n_val: 10
INFO | 09:36:32 | Computing per element energy regression.
INFO | 09:36:33 | Building Standard model
INFO | 09:36:33 | initializing 1 model(s)
INFO | 09:36:40 | Initializing Optimizer
INFO | 09:36:40 | Beginning Training
Epochs:   0%|                                                               | 0/100 [00:00<?, ?it/s]WARNING | 09:36:47 | SaveArgs.aggregate is deprecated, please use custom TypeHandler (https://orbax.readthedocs.io/en/latest/custom_handlers.html#typehandler) or contact Orbax team to migrate before August 1st, 2024.
Epochs: 100%|████████████████████████████████████| 100/100 [00:59<00:00,  1.69it/s, val_loss=0.0287]
INFO | 09:37:39 | Finished training

During training, apax displays a progress bar to keep track of the validation loss. The progress bar is optional, however, and can be turned off in the config. The default configuration writes training metrics to a CSV file, but TensorBoard is also supported. One can specify which to use by adding the following section to the input file:

callbacks:
    - CSV

If training is interrupted for any reason, re-running the above train command will resume training from the latest checkpoint.

Furthermore, an Apax training run can easily be started from within a script.

[9]:
from apax.train.run import run
from apax.utils.helpers import mod_config

config_path = Path("config_full.yaml")

config_updates = {
    "data": {
        "experiment": "ethanol_ccsd_t_script",
    },
}

config_dict = mod_config(config_path, config_updates)

run(config_dict)
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1732268260.595895  520335 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1732268260.599126  520335 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Epochs:   0%|                                                               | 0/100 [00:00<?, ?it/s]WARNING | 09:37:58 | SaveArgs.aggregate is deprecated, please use custom TypeHandler (https://orbax.readthedocs.io/en/latest/custom_handlers.html#typehandler) or contact Orbax team to migrate before August 1st, 2024.
Epochs: 100%|████████████████████████████████████| 100/100 [01:01<00:00,  1.62it/s, val_loss=0.0287]
[10]:
import matplotlib.pyplot as plt
import numpy as np

from apax.utils.helpers import load_csv_metrics

metrics_path = "project/models/ethanol_ccsd_t_script/log.csv"
keys = ["energy_mae", "forces_mse", "forces_mae", "loss"]

data_dict = load_csv_metrics(metrics_path)

fig, axes = plt.subplots(2, 2, constrained_layout=True)
axes = axes.ravel()
fig.suptitle("Metrics", fontsize=16)

# Plot validation and training curves for each metric in its own panel.
for ax, key in zip(axes, keys):
    val = np.array(data_dict[f"val_{key}"])
    train = np.array(data_dict[f"train_{key}"])
    epoch = np.array(data_dict["epoch"])

    ax.plot(epoch, val, label="val data")
    ax.plot(epoch, train, label="train data")

    ax.set_ylabel(key)
    ax.set_xlabel("epoch")

plt.legend()
plt.show()
../_images/_tutorials_01_Model_Training_16_0.png
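
The curves can also be summarized numerically, for example by looking up the epoch with the lowest validation loss. The sketch below assumes the metrics are available as a dictionary of equally long lists with the columns epoch and val_loss, as used in the plotting cell above. The helper best_epoch and the example numbers are illustrative.

```python
def best_epoch(metrics, key="val_loss"):
    """Return (epoch, value) at which the chosen metric is lowest."""
    values = metrics[key]
    idx = min(range(len(values)), key=values.__getitem__)
    return metrics["epoch"][idx], values[idx]


# Made-up numbers standing in for a loaded log file.
example_metrics = {
    "epoch": [0, 1, 2, 3],
    "val_loss": [0.9, 0.4, 0.2, 0.3],
}

epoch, loss = best_epoch(example_metrics)
print(epoch, loss)  # 2 0.2
```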

Evaluation

After training is completed and we are satisfied with our choice of hyperparameters and the validation loss, we can evaluate the model on the test set. Evaluation can be run either from Python with eval_model or from the command line with apax eval:

[11]:
from apax.train.eval import eval_model

eval_model(config_dict)
Structure: 100%|███████████████████████████████| 999/999 [00:04<00:00, 228.47it/s, test_loss=0.0253]
[12]:
!apax eval config_full.yaml
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1732268339.519757  522195 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1732268339.522952  522195 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Structure: 100%|███████████████████████████████| 999/999 [00:04<00:00, 229.06it/s, test_loss=0.0253]
[13]:
metrics_path = "project/models/ethanol_ccsd_t_script/eval/log.csv"
keys = ["energy_mae", "forces_mse", "forces_mae", "loss"]

data_dict = load_csv_metrics(metrics_path)

fig, axes = plt.subplots(1, 4, constrained_layout=True)
axes = axes.ravel()
fig.suptitle("Metrics", fontsize=16)

# One box plot per test metric.
for ax, key in zip(axes, keys):
    test = np.array(data_dict[f"test_{key}"])

    ax.set_title(key)
    ax.boxplot(test)
plt.show()
../_images/_tutorials_01_Model_Training_20_0.png
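
A box plot shows the median and the quartiles of each metric. The same numbers can be computed directly, as in the following sketch that uses only the standard library. The helper summarize and the input values are illustrative. In practice you would pass one of the test_ columns loaded above.

```python
import statistics


def summarize(values):
    """Return quartiles and mean of a sequence of per-structure errors."""
    # method="inclusive" interpolates linearly between the sorted data points.
    q1, median, q3 = statistics.quantiles(values, n=4, method="inclusive")
    return {
        "q1": q1,
        "median": median,
        "q3": q3,
        "mean": statistics.fmean(values),
    }


# Made-up error values for illustration.
example_errors = [0.02, 0.03, 0.025, 0.04, 0.035]
summary = summarize(example_errors)
print(summary)
```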

Congratulations, you have successfully trained and evaluated your first apax model!

A Closer Look At Training Parameters

To remove all the created files and clean up your working directory, run

[14]:
# !rm -rf project config_full.yaml eval.log
