Batch Active Learning¶
While it is possible to perform rudimentary data selection simply by randomly choosing samples, the batch of data thus drawn might not be the most informative one. Choosing the samples with the largest prediction uncertainties from a trajectory instead often selects configurations from consecutive time steps, which are highly correlated and therefore add little new information.
Batch selection methods can be constructed to select informative and diverse data, with or without following the underlying distribution.
We will illustrate this in a mock learning-on-the-fly setup.
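To see why ranking by uncertainty alone tends to pick neighboring frames, consider the following minimal sketch. The uncertainty curve is synthetic (an assumption made purely for illustration, not the output of any model): it varies smoothly along the trajectory, as per-frame uncertainties typically do between adjacent MD steps.

```python
import numpy as np

# Synthetic, smooth per-frame uncertainty along a 100-frame trajectory
# (an assumption for illustration; not produced by any model).
frames = np.arange(100)
uncertainty = np.exp(-(((frames - 60) / 15.0) ** 2))

# Naive selection: take the 5 frames with the largest uncertainty.
top_k = np.sort(np.argsort(uncertainty)[-5:])
print(top_k)  # five consecutive frames around the peak
```

All five selected frames sit next to each other around the maximum, so the batch contains almost the same configuration five times.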
[1]:
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import yaml
from ase import units
from ase.io import read
from ase.io.trajectory import TrajectoryWriter
from ase.md.langevin import Langevin
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
from ase.optimize.fire import FIRE
from apax.bal import api
from apax.md import ASECalculator
from apax.utils.datasets import download_md22_benzene_CCSDT, mod_md_datasets
from apax.utils.helpers import mod_config
Dataset Acquisition¶
[2]:
# Download CCSD(T) Data
data_path = Path("project")
cc_train_file_path, cc_val_file_path = download_md22_benzene_CCSDT(data_path)
cc_train_file_path = mod_md_datasets(cc_train_file_path)
cc_val_file_path = mod_md_datasets(cc_val_file_path)
Model Training¶
Unlike simpler data selection methods such as random selection, batch selection requires a trained model. The representation learned by the model serves as the basis for our similarity metric.
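How a learned representation can be turned into a similarity metric may be sketched as follows. The feature vectors here are random placeholders (hypothetical; the features apax actually extracts from the model are not reproduced), and the linear kernel is one common choice, assumed for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical feature vectors for 4 structures, 8 features each.
features = rng.normal(size=(4, 8))

# Linear kernel: k(x, x') = <g(x), g(x')>
kernel = features @ features.T

# Squared distance induced by the kernel:
# d^2(x, x') = k(x, x) + k(x', x') - 2 k(x, x')
diag = np.diag(kernel)
sq_dist = diag[:, None] + diag[None, :] - 2.0 * kernel
# Clip tiny negative values from floating-point round-off before the sqrt.
dist = np.sqrt(np.clip(sq_dist, 0.0, None))
```

Structures whose feature vectors are close under this distance are treated as similar, so a selection method can avoid picking both.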
[3]:
!apax template train --full
[4]:
config_path = Path("config_full.yaml")
config_updates = {
    "n_epochs": 200,
    "data": {
        "batch_size": 4,
        "valid_batch_size": 100,
        "experiment": "benzene",
        "directory": "project/models",
        "train_data_path": str(cc_train_file_path),
        "val_data_path": str(cc_val_file_path),
        "data_path": None,
        "energy_unit": "kcal/mol",
        "pos_unit": "Ang",
    },
}
config_dict = mod_config(config_path, config_updates)
with open("config_full.yaml", "w") as conf:
    yaml.dump(config_dict, conf, default_flow_style=False)
[5]:
!apax train config_full.yaml
INFO | 09:46:59 | Running on [CudaDevice(id=0)]
INFO | 09:46:59 | Initializing Callbacks
INFO | 09:46:59 | Initializing Loss Function
INFO | 09:46:59 | Initializing Metrics
INFO | 09:46:59 | Running Input Pipeline
INFO | 09:46:59 | Reading training data file project/benzene_ccsd_t-train_mod.xyz
INFO | 09:46:59 | Reading validation data file project/benzene_ccsd_t-test_mod.xyz
INFO | 09:47:00 | Found n_train: 1000, n_val: 500
INFO | 09:47:00 | Computing per element energy regression.
INFO | 09:47:00 | Building Standard model
INFO | 09:47:01 | initializing 1 model(s)
INFO | 09:47:07 | Initializing Optimizer
INFO | 09:47:07 | Beginning Training
Epochs: 100%|███████████████████████████████████| 200/200 [01:56<00:00, 1.72it/s, val_loss=0.00414]
INFO | 09:49:04 | Finished training
Molecular Dynamics / Data Generation¶
Now we will create a pool of data to select new samples from. To highlight the behavior of the selection method used in the next section, the pool combines frames from an MD run with frames from a geometry optimization, a combination that commonly occurs in learning on the fly.
[6]:
atoms = read(str(cc_train_file_path), "0")
calc = ASECalculator("project/models/benzene")
atoms.calc = calc
[7]:
def printenergy(a=atoms):  # store a reference to atoms in the definition.
    """Function to print the potential, kinetic and total energy."""
    epot = a.get_potential_energy() / len(a)
    ekin = a.get_kinetic_energy() / len(a)
    print(
        "Energy per atom: Epot = %.3feV Ekin = %.3feV (T=%3.0fK) "
        "Etot = %.3feV" % (epot, ekin, ekin / (1.5 * units.kB), epot + ekin)
    )
writer = TrajectoryWriter("project/benzene.traj", "w", atoms)
MaxwellBoltzmannDistribution(atoms, temperature_K=298)
dyn = Langevin(atoms, 0.5 * units.fs, temperature_K=298, friction=0.002)
dyn.attach(writer, interval=100)
dyn.attach(printenergy, interval=1000)
dyn.run(10000)
opt = FIRE(atoms)
opt.attach(writer)
opt.run()
Energy per atom: Epot = -524.782eV Ekin = 0.034eV (T=266K) Etot = -524.748eV
Energy per atom: Epot = -524.786eV Ekin = 0.047eV (T=367K) Etot = -524.739eV
Energy per atom: Epot = -524.789eV Ekin = 0.044eV (T=344K) Etot = -524.745eV
Energy per atom: Epot = -524.798eV Ekin = 0.050eV (T=389K) Etot = -524.748eV
Energy per atom: Epot = -524.779eV Ekin = 0.033eV (T=259K) Etot = -524.745eV
Energy per atom: Epot = -524.778eV Ekin = 0.029eV (T=226K) Etot = -524.748eV
Energy per atom: Epot = -524.796eV Ekin = 0.052eV (T=400K) Etot = -524.745eV
Energy per atom: Epot = -524.778eV Ekin = 0.030eV (T=230K) Etot = -524.748eV
Energy per atom: Epot = -524.783eV Ekin = 0.031eV (T=242K) Etot = -524.751eV
Energy per atom: Epot = -524.799eV Ekin = 0.037eV (T=286K) Etot = -524.762eV
Energy per atom: Epot = -524.794eV Ekin = 0.039eV (T=301K) Etot = -524.755eV
Step Time Energy fmax
FIRE: 0 09:49:17 -6297.529843 3.433130
FIRE: 1 09:49:17 -6297.707876 0.822871
FIRE: 2 09:49:17 -6297.653997 2.416255
FIRE: 3 09:49:17 -6297.701468 1.822036
FIRE: 4 09:49:17 -6297.755201 0.799443
FIRE: 5 09:49:17 -6297.771731 0.639807
FIRE: 6 09:49:17 -6297.770975 0.635587
FIRE: 7 09:49:17 -6297.774358 0.628206
FIRE: 8 09:49:17 -6297.776426 0.617073
FIRE: 9 09:49:17 -6297.779687 0.603059
FIRE: 10 09:49:17 -6297.782312 0.586707
FIRE: 11 09:49:17 -6297.786402 0.568710
FIRE: 12 09:49:17 -6297.790373 0.550295
FIRE: 13 09:49:17 -6297.793097 0.528365
FIRE: 14 09:49:17 -6297.797588 0.504850
FIRE: 15 09:49:17 -6297.802828 0.477577
FIRE: 16 09:49:17 -6297.810041 0.446206
FIRE: 17 09:49:17 -6297.816479 0.407242
FIRE: 18 09:49:17 -6297.821895 0.358514
FIRE: 19 09:49:17 -6297.828423 0.300577
FIRE: 20 09:49:17 -6297.835409 0.244853
FIRE: 21 09:49:17 -6297.841073 0.207934
FIRE: 22 09:49:17 -6297.844864 0.208605
FIRE: 23 09:49:17 -6297.848110 0.154047
FIRE: 24 09:49:17 -6297.849490 0.151302
FIRE: 25 09:49:17 -6297.848904 0.132234
FIRE: 26 09:49:17 -6297.851377 0.131325
FIRE: 27 09:49:17 -6297.852449 0.132891
FIRE: 28 09:49:17 -6297.851385 0.134429
FIRE: 29 09:49:17 -6297.850981 0.132706
FIRE: 30 09:49:17 -6297.853413 0.127714
FIRE: 31 09:49:17 -6297.853684 0.119362
FIRE: 32 09:49:17 -6297.853303 0.107520
FIRE: 33 09:49:17 -6297.854404 0.095077
FIRE: 34 09:49:17 -6297.853003 0.086662
FIRE: 35 09:49:17 -6297.856393 0.082878
FIRE: 36 09:49:17 -6297.855421 0.077966
FIRE: 37 09:49:17 -6297.855617 0.070408
FIRE: 38 09:49:17 -6297.856790 0.064497
FIRE: 39 09:49:17 -6297.858416 0.063379
FIRE: 40 09:49:17 -6297.857967 0.061494
FIRE: 41 09:49:17 -6297.857515 0.062909
FIRE: 42 09:49:17 -6297.858565 0.051126
FIRE: 43 09:49:17 -6297.859075 0.076019
FIRE: 44 09:49:17 -6297.857729 0.045322
[7]:
True
Selecting New Datapoints¶
Now it is time to select new data points from our pool. In the following, we choose the last-layer gradient kernel as the similarity metric and the max_dist selection method (farthest point sampling) to select 10 data points from our pool.
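The idea behind farthest point sampling can be sketched in a few lines of NumPy. This is a simplified illustration under stated assumptions, not the apax implementation: it works on plain feature vectors with Euclidean distances, and it assumes that distances to the existing training set are taken into account. The function name and the toy data are hypothetical.

```python
import numpy as np


def farthest_point_selection(train_feats, pool_feats, batch_size):
    """Greedily pick pool points that are farthest from everything chosen so far."""
    # Distance of every pool point to its nearest training point.
    diff = pool_feats[:, None, :] - train_feats[None, :, :]
    min_dist = np.linalg.norm(diff, axis=-1).min(axis=1)

    selected = []
    for _ in range(batch_size):
        # Pick the pool point with the largest distance to its nearest neighbor
        # among the training set and the points already selected.
        idx = int(np.argmax(min_dist))
        selected.append(idx)
        # Update nearest-neighbor distances with the newly selected point.
        new_dist = np.linalg.norm(pool_feats - pool_feats[idx], axis=-1)
        min_dist = np.minimum(min_dist, new_dist)
    return selected


# Toy data (hypothetical): one training point and five pool points in 2D.
train = np.array([[0.0, 0.0]])
pool = np.array([[0.1, 0.0], [0.2, 0.0], [5.0, 0.0], [0.0, 4.0], [5.1, 0.0]])
selected = farthest_point_selection(train, pool, batch_size=2)
print(selected)
```

In the toy data, the pool points at (5.0, 0.0) and (5.1, 0.0) are near duplicates. Once one of them is selected, the other is no longer far from the selected set, so the second pick goes to a different region. This is what keeps the batch diverse.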
[8]:
train_atoms = read(str(cc_train_file_path), ":")
pool_atoms = read("project/benzene.traj", ":")
[9]:
len(pool_atoms)
[9]:
146
[10]:
base_fm_options = {"name": "ll_grad", "layer_name": "dense_2"}
selection_method = "max_dist"
bs = 10
selected_indices = api.kernel_selection(
    "project/models/benzene",
    train_atoms,
    pool_atoms,
    base_fm_options,
    selection_method,
    selection_batch_size=bs,
    processing_batch_size=bs,
)
Computing features: 100%|██████████████████████████████████████| 1146/1146 [00:05<00:00, 224.35it/s]
[12]:
energies = np.array([a.get_potential_energy() for a in pool_atoms])
# selected_indices = np.random.randint(0, len(energies), 10)
selection_energies = energies[selected_indices[0]]
As we can see below, the batch selection method picks only a few data points from the optimization part of the pool. This indicates that the structure of the molecule changes very little during the optimization, so few informative samples are to be found there.
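The split between the two parts of the pool can also be counted directly. The sketch below assumes that the first 101 images of the pool come from the MD run (the initial frame plus one frame every 100 of the 10000 steps) and the remaining 45 from the FIRE optimization (steps 0 to 44), which adds up to the 146 images in the pool. The helper name and the example indices are hypothetical placeholders, not the indices returned by the selection above.

```python
import numpy as np


def split_selection(selected, n_md_frames):
    """Split selected pool indices into MD frames and optimization frames."""
    selected = np.asarray(selected)
    md_part = selected[selected < n_md_frames]
    opt_part = selected[selected >= n_md_frames]
    return md_part, opt_part


# Hypothetical indices for illustration only.
example_indices = [3, 17, 42, 58, 66, 79, 88, 95, 104, 131]
md_part, opt_part = split_selection(example_indices, n_md_frames=101)
print(len(md_part), "from MD,", len(opt_part), "from the optimization")
```

In the notebook itself, the selected pool indices (`selected_indices[0]`, as used for the plot) would take the place of `example_indices`.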
[13]:
fig, ax = plt.subplots()
ax.plot(energies)
ax.scatter(
    selected_indices[0], selection_energies, marker="x", color="red", label="selection"
)
ax.set_ylabel("Energy / eV")
ax.set_xlabel("Image")
ax.legend()
[13]:
<matplotlib.legend.Legend at 0x7fda74fda790>
[14]:
!rm -rf project config_full.yaml