Transfer Learning
Datasets computed at high levels of theory are expensive to generate and are therefore usually small. A model trained on such data may not generalize well to unseen configurations. Transfer learning can sometimes remedy this: if a model is first trained on a large amount of data from a cheaper level of theory, only small adjustments to its parameters are needed to accurately reproduce the potential energy surface of a different level of theory.
Alternatively, the level of theory may stay the same while the dataset is extended, as in learning-on-the-fly scenarios. For a demonstration of transfer learning for learning on the fly, see the corresponding example in the IPSuite documentation.
Apax supports discriminative transfer learning out of the box. In this tutorial, we fine-tune a model trained on benzene data at the DFT level of theory to CCSD(T).
First, download the appropriate datasets from the sGDML website.
Transfer learning in apax is enabled by adding the path to a pretrained model to the config. In addition, individual components can be frozen or trained with a reduced learning rate by adjusting the optimizer section of the config:
optimizer:
  nn_lr: 0.004
  emb_lr: 0.0
A learning rate of 0.0 masks the corresponding parameters, so they are not updated during training. In this tutorial, we freeze the descriptor, reinitialize the scaling and shifting parameters, and reduce the learning rate of all other components.
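To see why a learning rate of 0.0 freezes a parameter group, consider a single gradient-descent step in which every group has its own learning rate. The pure-Python sketch below is illustrative only: sgd_step is a hypothetical helper, the group names merely mirror the emb_lr, nn_lr, scale_lr and shift_lr config keys, and apax's actual optimizer may use a different update rule. For any rule that multiplies the step by the learning rate, a zero learning rate leaves the group unchanged.

```python
# Hypothetical illustration, NOT apax's optimizer: plain gradient descent
# where each named parameter group has its own learning rate.


def sgd_step(params, grads, group_lrs):
    """Return updated parameters; each group is scaled by its own learning rate."""
    updated = {}
    for group, values in params.items():
        lr = group_lrs[group]
        # With lr == 0.0 the step is zero, so the group stays frozen.
        updated[group] = [p - lr * g for p, g in zip(values, grads[group])]
    return updated


params = {"emb": [1.0, 2.0], "nn": [0.5, -0.5], "scale": [1.0], "shift": [0.0]}
grads = {"emb": [0.3, -0.1], "nn": [0.2, 0.4], "scale": [0.1], "shift": [-1.0]}
lrs = {"emb": 0.0, "nn": 0.5, "scale": 0.1, "shift": 0.1}

new_params = sgd_step(params, grads, lrs)
print(new_params["emb"])  # [1.0, 2.0]: frozen despite non-zero gradients
```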
We can then fine-tune the model by running apax train config.yaml.
[1]:
from pathlib import Path
import yaml
from apax.utils.datasets import (
    download_benzene_DFT,
    download_md22_benzene_CCSDT,
    mod_md_datasets,
)
from apax.utils.helpers import mod_config
Acquire Datasets
For this demonstration, we use the DFT and CCSD(T) versions of the benzene MD17 dataset. We start by downloading both and converting them to a format that apax can read.
[2]:
# Download DFT Data
data_path = Path("project")
dft_file_path = download_benzene_DFT(data_path)
dft_file_path = mod_md_datasets(dft_file_path)
[3]:
# Download CCSD(T) Data
data_path = Path("project")
cc_file_path, _ = download_md22_benzene_CCSDT(data_path)
cc_file_path = mod_md_datasets(cc_file_path)
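Before choosing n_train and n_valid, it can help to check how many structures each file contains. Assuming mod_md_datasets writes (extended) XYZ files, as the .xyz paths in the training logs below suggest, counting frames only requires reading the atom-count header of each frame. count_xyz_frames is a hypothetical stdlib helper, not part of apax:

```python
def count_xyz_frames(lines):
    """Count frames in (extended) XYZ content given as an iterable of lines.

    Each frame consists of one line with the atom count, one comment /
    extended-XYZ info line, and then one line per atom.
    Raises StopIteration if the last frame is truncated.
    """
    it = iter(lines)
    n_frames = 0
    for header in it:
        if not header.strip():
            continue  # tolerate blank lines between frames
        n_atoms = int(header.split()[0])
        next(it)  # skip the comment / info line
        for _ in range(n_atoms):
            next(it)  # skip one line per atom
        n_frames += 1
    return n_frames
```

A file object can be passed directly, e.g. with open(cc_file_path) as f: print(count_xyz_frames(f)). The CC runs below need at least 60 frames (n_train=50 plus n_valid=10).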
Pretrain Model
First, we train a model on the DFT dataset, which is "large" (in relative terms) but less accurate. A standard model with the default optimizer settings will do just fine.
[4]:
!apax template train --full
[5]:
config_path = Path("config_full.yaml")
config_updates = {
    "n_epochs": 100,
    "data": {
        "n_train": 1000,
        "n_valid": 200,
        "batch_size": 8,
        "valid_batch_size": 100,
        "experiment": "benzene_dft",
        "directory": "project/models",
        "data_path": str(dft_file_path),
        "energy_unit": "kcal/mol",
        "pos_unit": "Ang",
    },
}
config_dict = mod_config(config_path, config_updates)
with open("config_full.yaml", "w") as conf:
    yaml.dump(config_dict, conf, default_flow_style=False)
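mod_config combines the template config with the config_updates dict. We assume here that it merges the updates recursively, so that nested keys such as data.batch_size are overridden while template entries that are not mentioned (for example the model section) keep their defaults. A minimal stdlib sketch of such a recursive merge, using a hypothetical deep_update helper and made-up placeholder values rather than apax's real defaults:

```python
import copy


def deep_update(base, updates):
    """Return a copy of `base` with `updates` merged in recursively.

    Nested dicts are merged key by key; any other value in `updates`
    replaces the corresponding value in `base`.
    """
    merged = copy.deepcopy(base)
    for key, value in updates.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_update(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Placeholder template, not apax's actual defaults.
template = {
    "n_epochs": 1000,
    "data": {"n_train": 1000, "batch_size": 32, "directory": "models"},
    "optimizer": {"emb_lr": 0.02, "nn_lr": 0.03},
}
updates = {"n_epochs": 100, "data": {"batch_size": 8}}

config = deep_update(template, updates)
print(config["data"])  # {'n_train': 1000, 'batch_size': 8, 'directory': 'models'}
```

Replacing the whole data dict instead would silently drop every template key not listed in config_updates, which is why a recursive merge is the natural choice for this kind of override.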
[6]:
!apax train config_full.yaml
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1732268582.750898 524269 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1732268582.754067 524269 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
INFO | 09:43:04 | Running on [CudaDevice(id=0)]
INFO | 09:43:04 | Initializing Callbacks
INFO | 09:43:04 | Initializing Loss Function
INFO | 09:43:04 | Initializing Metrics
INFO | 09:43:04 | Running Input Pipeline
INFO | 09:43:04 | Reading data file project/benzene_mod.xyz
INFO | 09:43:11 | Found n_train: 1000, n_val: 200
INFO | 09:43:11 | Computing per element energy regression.
INFO | 09:43:12 | Building Standard model
INFO | 09:43:12 | initializing 1 model(s)
INFO | 09:43:18 | Initializing Optimizer
INFO | 09:43:18 | Beginning Training
Epochs: 0%| | 0/100 [00:00<?, ?it/s]WARNING | 09:43:27 | SaveArgs.aggregate is deprecated, please use custom TypeHandler (https://orbax.readthedocs.io/en/latest/custom_handlers.html#typehandler) or contact Orbax team to migrate before August 1st, 2024.
Epochs: 100%|████████████████████████████████████| 100/100 [00:42<00:00, 2.36it/s, val_loss=0.0233]
INFO | 09:44:01 | Finished training
Baseline CC Training
Next, we require a CC baseline to quantify the effect of pretraining. As with the DFT dataset, we will only use a small fraction of the data to emphasize the effects in the low-data regime.
[7]:
config_path = Path("config_full.yaml")
config_updates = {
    "n_epochs": 100,
    "data": {
        "n_train": 50,
        "n_valid": 10,
        "batch_size": 4,
        "valid_batch_size": 10,
        "experiment": "benzene_cc_baseline",
        "directory": "project/models",
        "data_path": str(cc_file_path),
        "energy_unit": "kcal/mol",
        "pos_unit": "Ang",
    },
}
config_dict = mod_config(config_path, config_updates)
with open("config_cc_baseline.yaml", "w") as conf:
    yaml.dump(config_dict, conf, default_flow_style=False)
[8]:
!apax train config_cc_baseline.yaml
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1732268643.129714 525084 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1732268643.132821 525084 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
INFO | 09:44:05 | Running on [CudaDevice(id=0)]
INFO | 09:44:05 | Initializing Callbacks
INFO | 09:44:05 | Initializing Loss Function
INFO | 09:44:05 | Initializing Metrics
INFO | 09:44:05 | Running Input Pipeline
INFO | 09:44:05 | Reading data file project/benzene_ccsd_t-train_mod.xyz
INFO | 09:44:06 | Found n_train: 50, n_val: 10
INFO | 09:44:06 | Computing per element energy regression.
INFO | 09:44:06 | Building Standard model
INFO | 09:44:06 | initializing 1 model(s)
INFO | 09:44:13 | Initializing Optimizer
INFO | 09:44:13 | Beginning Training
Epochs: 0%| | 0/100 [00:00<?, ?it/s]WARNING | 09:44:20 | SaveArgs.aggregate is deprecated, please use custom TypeHandler (https://orbax.readthedocs.io/en/latest/custom_handlers.html#typehandler) or contact Orbax team to migrate before August 1st, 2024.
Epochs: 100%|████████████████████████████████████| 100/100 [00:14<00:00, 6.96it/s, val_loss=0.0766]
INFO | 09:44:27 | Finished training
DFT -> CC Fine-Tuning
Finally, we fine-tune a model that was pretrained on DFT data. The model architecture is the same in all three runs. For fine-tuning, however, we need to specify the path to the base model and how to treat each of its parameter groups: freeze it, reset it, or keep training it. Experimenting with different strategies is certainly advisable, but a good starting point is to freeze the embedding layer if the chemical system stays the same, and to reset the scale-shift layer if the level of theory changes (DFT and CC have different energy scales).
Make sure to carefully inspect the config options below.
[9]:
config_path = Path("config_full.yaml")
config_updates = {
    "n_epochs": 100,
    "data": {
        "n_train": 50,
        "n_valid": 10,
        "batch_size": 4,
        "valid_batch_size": 10,
        "experiment": "benzene_cc_ft",
        "directory": "project/models",
        "data_path": str(cc_file_path),
        "energy_unit": "kcal/mol",
        "pos_unit": "Ang",
    },
    "optimizer": {
        "emb_lr": 0.00,  # freeze embedding layer
        "nn_lr": 0.0005,  # lower lr
        "scale_lr": 0.001,  # lower lr
        "shift_lr": 0.005,  # lower lr
    },
    "checkpoints": {
        "base_model_checkpoint": "project/models/benzene_dft",  # pretrained model
        "reset_layers": ["scale_shift"],  # reset scale-shift layer
    },
}
config_dict = mod_config(config_path, config_updates)
with open("config_cc_ft.yaml", "w") as conf:
    yaml.dump(config_dict, conf, default_flow_style=False)
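Conceptually, base_model_checkpoint and reset_layers interact as follows: every layer present in both models is copied from the pretrained checkpoint, except the layers listed in reset_layers, which keep their fresh initialization; the per-group learning rates then decide which of the copied layers continue to train. The sketch below illustrates this with a flat dict keyed by layer name, borrowing layer names from the transfer log of the next cell and using strings as stand-ins for weight arrays. transfer_parameters is hypothetical and is not apax's implementation.

```python
def transfer_parameters(pretrained, fresh, reset_layers):
    """Combine pretrained and freshly initialized parameters (hypothetical sketch).

    Layers named in `reset_layers` keep their fresh initialization;
    every other layer present in both models is copied from `pretrained`.
    """
    combined = dict(fresh)
    for layer, value in pretrained.items():
        if layer in reset_layers:
            continue  # keep the re-initialized values, e.g. for scale_shift
        if layer in fresh:
            combined[layer] = value  # transfer the pretrained weights
    return combined


# String stand-ins instead of real weight arrays.
pretrained = {"radial_fn": "w_dft_emb", "dense_0": "w_dft_0", "scale_shift": "s_dft"}
fresh = {"radial_fn": "w_init_emb", "dense_0": "w_init_0", "scale_shift": "s_init"}

params = transfer_parameters(pretrained, fresh, reset_layers=["scale_shift"])
print(params)  # radial_fn and dense_0 from DFT, scale_shift freshly initialized
```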
[10]:
!apax train config_cc_ft.yaml
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1732268669.330941 525885 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1732268669.334048 525885 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
INFO | 09:44:31 | Running on [CudaDevice(id=0)]
INFO | 09:44:31 | Initializing Callbacks
INFO | 09:44:31 | Initializing Loss Function
INFO | 09:44:31 | Initializing Metrics
INFO | 09:44:31 | Running Input Pipeline
INFO | 09:44:31 | Reading data file project/benzene_ccsd_t-train_mod.xyz
INFO | 09:44:31 | Found n_train: 50, n_val: 10
INFO | 09:44:31 | Computing per element energy regression.
INFO | 09:44:31 | Building Standard model
INFO | 09:44:31 | initializing 1 model(s)
INFO | 09:44:38 | Initializing Optimizer
INFO | 09:44:38 | loading checkpoint from project/models/benzene_dft/best
INFO | 09:44:38 | Transferring parameters from project/models/benzene_dft
INFO | 09:44:38 | Transferring parameter: radial_fn
INFO | 09:44:38 | Transferring parameter: dense_0
INFO | 09:44:38 | Transferring parameter: dense_0
INFO | 09:44:38 | Transferring parameter: dense_1
INFO | 09:44:38 | Transferring parameter: dense_1
INFO | 09:44:38 | Transferring parameter: dense_2
INFO | 09:44:38 | Transferring parameter: dense_2
INFO | 09:44:38 | Beginning Training
Epochs: 0%| | 0/100 [00:00<?, ?it/s]WARNING | 09:44:44 | SaveArgs.aggregate is deprecated, please use custom TypeHandler (https://orbax.readthedocs.io/en/latest/custom_handlers.html#typehandler) or contact Orbax team to migrate before August 1st, 2024.
Epochs: 100%|███████████████████████████████████| 100/100 [00:11<00:00, 8.62it/s, val_loss=0.00752]
INFO | 09:44:50 | Finished training
As we can see, the fine-tuned model reaches a final validation loss of 0.00752, roughly an order of magnitude lower than the 0.0766 of the CC baseline trained on the same amount of CC data.
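To put a number on this comparison, the final val_loss values can be read from the two progress bars above. The sketch below uses Python's re module; final_val_loss is a hypothetical helper, and the input strings are shortened copies of the last progress lines of the baseline and fine-tuning runs.

```python
import re


def final_val_loss(progress_line):
    """Extract the val_loss value from a tqdm progress line."""
    match = re.search(r"val_loss=([0-9.eE+-]+)", progress_line)
    if match is None:
        raise ValueError("no val_loss found in line")
    return float(match.group(1))


# Shortened copies of the last progress lines from the two CC runs above.
baseline = final_val_loss("Epochs: 100%|...| 100/100 [00:14<00:00, 6.96it/s, val_loss=0.0766]")
finetuned = final_val_loss("Epochs: 100%|...| 100/100 [00:11<00:00, 8.62it/s, val_loss=0.00752]")
print(f"baseline / fine-tuned: {baseline / finetuned:.1f}x")  # baseline / fine-tuned: 10.2x
```

Both values come from a validation set of only 10 structures, so the ratio is indicative rather than a converged benchmark.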
How much further can you improve the fine-tuning (or pretraining) setup?
[12]:
!rm -rf project config_full.yaml config_cc_baseline.yaml config_cc_ft.yaml