Source code for apax.md.ase_calc

from functools import partial
from pathlib import Path
from typing import Callable, Union

import ase
import jax
import jax.numpy as jnp
import numpy as np
from ase.calculators.calculator import Calculator, all_changes, all_properties
from ase.calculators.singlepoint import SinglePointCalculator
from flax.core.frozen_dict import freeze, unfreeze
from jax import tree_util
from tqdm import tqdm, trange
from vesin import NeighborList

from apax.config.train_config import Config
from apax.data.input_pipeline import CachedInMemoryDataset, OTFInMemoryDataset
from apax.md.function_transformations import ProcessStress
from apax.train.checkpoints import (
    canonicalize_energy_model_parameters,
    check_for_ensemble,
    restore_parameters,
)
from apax.utils.jax_md_reduced import partition, space


def maybe_vmap(apply, params):
    n_models = check_for_ensemble(params)

    if n_models > 1:
        apply = jax.vmap(apply, in_axes=(0, None, None, None, None, None))

    # Maybe the partial mapping should happen at the very end of initialize
    # That way other functions can mae use of the parameter shape information
    energy_fn = partial(apply, params)
    return energy_fn


def build_energy_neighbor_fns(
    atoms: ase.Atoms, config: Config, params, dr_threshold: float, neigbor_from_jax: bool
):
    r_max = config.model.basis.r_max
    box = jnp.asarray(atoms.cell.array, dtype=jnp.float64)
    box = box.T
    displacement_fn = None
    neighbor_fn = None

    if neigbor_from_jax:
        if np.all(box < 1e-6):
            displacement_fn, _ = space.free()
        else:
            displacement_fn, _ = space.periodic_general(box, fractional_coordinates=True)

        neighbor_fn = partition.neighbor_list(
            displacement_fn,
            box,
            config.model.basis.r_max,
            dr_threshold,
            fractional_coordinates=True,
            disable_cell_list=True,
            format=partition.Sparse,
        )

    n_species = 119  # int(np.max(Z) + 1)
    Builder = config.model.get_builder()
    builder = Builder(config.model.model_dump(), n_species=n_species)

    model = builder.build_energy_derivative_model(
        apply_mask=True, init_box=np.array(box), inference_disp_fn=displacement_fn
    )

    energy_fn = maybe_vmap(model.apply, params)
    return energy_fn, neighbor_fn


def make_ensemble(model):
    def ensemble(positions, Z, idx, box, offsets):
        results = model(positions, Z, idx, box, offsets)
        uncertainty = {k + "_uncertainty": jnp.std(v, axis=0) for k, v in results.items()}
        ensemble = {k + "_ensemble": v for k, v in results.items()}
        results = {k: jnp.mean(v, axis=0) for k, v in results.items()}
        if "forces_ensemble" in ensemble.keys():
            ensemble["forces_ensemble"] = jnp.transpose(
                ensemble["forces_ensemble"], (1, 2, 0)
            )
        if "forces_ensemble" in ensemble.keys():
            ensemble["stress_ensemble"] = jnp.transpose(
                ensemble["forces_ensemble"], (1, 2, 0)
            )
        results.update(uncertainty)
        results.update(ensemble)

        return results

    return ensemble


def unpack_results(results, inputs):
    n_structures = len(results["energy"])
    unpacked_results = []
    for i in range(n_structures):
        single_results = tree_util.tree_map(lambda x: x[i], results)
        single_results["energy"] = single_results["energy"].item()
        for k, v in single_results.items():
            if "forces" in k:
                single_results[k] = v[: inputs["n_atoms"][i]]
        unpacked_results.append(single_results)
    return unpacked_results


[docs] class ASECalculator(Calculator): """ ASE Calculator for apax models. Always implements energy and force predictions. Stress predictions and corresponding uncertainties are added to `implemented_properties` based on whether the stress flag is set in the model config and whether a model ensemble is loaded. """ implemented_properties = [ "energy", "forces", ] def __init__( self, model_dir: Union[Path, list[Path]], dr_threshold: float = 0.5, transformations: list[Callable] = [], padding_factor: float = 1.5, **kwargs, ): """ Parameters ---------- model_dir: Path to a model directory of the form `.../directory/experiment` (see Config docs for details). If a list of model paths is provided, they will be ensembled. dr_threshold: Neighborlist skin for the JaxMD neighborlist. transformations: Function transformations applied on top of the EnergyDerivativeModel. Transfomrations are implemented under `apax.md.transformations`. padding_factor: Multiple of the fallback vesin's amount of neighbors. This NL will be padded to `len(neighbors) * padding_factor` on NL initialization. """ Calculator.__init__(self, **kwargs) self.dr_threshold = dr_threshold self.transformations = transformations self.model_config, self.params = restore_parameters(model_dir) for head in self.model_config.model.property_heads: self.implemented_properties.append(head.name) self.n_models = check_for_ensemble(self.params) self.padding_factor = padding_factor self.padded_length = 0 self._update_implemented_properties() self.step = None self.neighbor_fn = None self.neighbors = None self.offsets = None self.model = None def _update_implemented_properties(self): """ Dynamically determines the properties this model produces based on its config, including custom heads, stress, and ensemble uncertainties. Updates ASE's global all_properties to ensure compatibility. """ props = {"energy", "forces"} if self.model_config.model.calc_stress: props.add("stress") for head in self.model_config.model.property_heads: props.add(head.name) # Handle ensembles ensemble_config = self.model_config.model.ensemble if ensemble_config: ensembled_props = set() if ensemble_config.kind == "full": # For a full ensemble, all base properties are ensembled ensembled_props.update(props) elif ensemble_config.kind == "shallow": # For a shallow ensemble, only energy and potentially forces are ensembled ensembled_props.add("energy") if ensemble_config.force_variance: ensembled_props.add("forces") for p in ensembled_props: props.add(f"{p}_uncertainty") props.add(f"{p}_ensemble") # Handle individual property head ensembles for head in self.model_config.model.property_heads: if head.n_shallow_members > 0: props.add(f"{head.name}_uncertainty") props.add(f"{head.name}_ensemble") # Finalize and update global list self.implemented_properties = sorted(props) for p in self.implemented_properties: if p not in all_properties: all_properties.append(p) def initialize(self, atoms): box = jnp.asarray(atoms.cell.array, dtype=jnp.float64) self.r_max = self.model_config.model.basis.r_max self.neigbor_from_jax = neighbor_calculable_with_jax(box, self.r_max) model, neighbor_fn = build_energy_neighbor_fns( atoms, self.model_config, self.params, self.dr_threshold, self.neigbor_from_jax, ) if self.n_models > 1: model = make_ensemble(model) if "stress" in self.implemented_properties: model = ProcessStress().apply(model) for transformation in self.transformations: model = transformation.apply(model) self.model = model self.step = get_step_fn( model, jnp.asarray(atoms.numbers), bool(np.any(atoms.cell.array > 1e-6)), self.neigbor_from_jax, ) self.neighbor_fn = neighbor_fn if self.neigbor_from_jax: positions = jnp.asarray(atoms.positions, dtype=jnp.float64) if np.any(atoms.get_cell().lengths() > 1e-6): box = atoms.cell.array.T inv_box = jnp.linalg.inv(box) positions = space.transform(inv_box, positions) # frac coords self.neighbors = self.neighbor_fn.allocate(positions, box=box) else: self.neighbors = self.neighbor_fn.allocate(positions) else: calculator = NeighborList(cutoff=self.r_max, full_list=True) (idxs_i,) = calculator.compute( points=atoms.positions, box=atoms.cell.array, periodic=bool(np.any(atoms.pbc)), quantities="i", ) self.padded_length = int(len(idxs_i) * self.padding_factor)
[docs] def get_descriptors( self, frames: list[ase.Atoms], processing_batch_size: int = 1, only_use_n_layers: int | None = None, should_average: bool = True, ) -> np.ndarray: """Compute the descriptors for a list of Atoms. Parameters ---------- frames : list[ase.Atoms] List of Atoms to compute descriptors for. processing_batch_size : int, default = 1 Batch size for processing the frames. This does not affect the results, only the speed and memory requirements. only_use_n_layers: int | None, default = None If specified, only the first `only_use_n_layers` layers of the feature model will be used to compute the descriptors. If None, all layers will be used. should_average: bool, default = True Whether to average the descriptors over the atomic species. Returns ------- np.ndarray Array of computed descriptors for each frame with shape (n_frames, n_descriptors) or (n_frames, n_atoms, n_descriptors) if ``should_average=False``. """ params = canonicalize_energy_model_parameters(self.params) dataset = OTFInMemoryDataset( frames, cutoff=self.model_config.model.basis.r_max, bs=processing_batch_size, n_epochs=1, ignore_labels=True, pos_unit=self.model_config.data.pos_unit, energy_unit=self.model_config.data.energy_unit, ) _, init_box = dataset.init_input() Builder = self.model_config.model.get_builder() builder = Builder(self.model_config.model.model_dump(), n_species=119) feature_model = builder.build_feature_model( apply_mask=True, init_box=init_box, only_use_n_layers=only_use_n_layers, should_average=should_average, ) feature_fn = feature_model.apply batched_feature_fn = jax.vmap(feature_fn, in_axes=(None, 0, 0, 0, 0, 0)) jitted_fn = jax.jit(batched_feature_fn) results = [] for batch in tqdm(dataset.batch(), ncols=100, total=len(frames)): data = jitted_fn( params, batch["positions"], batch["numbers"], batch["idx"], batch["box"], batch["offsets"], ) results.append(data) return np.concatenate(results, axis=0)
def set_neighbours_and_offsets(self, atoms: ase.Atoms, box: np.ndarray) -> None: calculator = NeighborList(cutoff=self.r_max, full_list=True) idxs_i, idxs_j, offsets = calculator.compute( points=atoms.positions, box=atoms.cell.array, periodic=bool(np.any(atoms.pbc)), quantities="ijS", ) if len(idxs_i) > self.padded_length: print("neighbor list overflowed, extending.") self.initialize(atoms) zeros_to_add = self.padded_length - len(idxs_i) self.neighbors = np.array([idxs_i, idxs_j], dtype=np.int32) self.neighbors = np.pad(self.neighbors, ((0, 0), (0, zeros_to_add)), "constant") offsets = np.matmul(offsets, box) self.offsets = np.pad(offsets, ((0, zeros_to_add), (0, 0)), "constant")
[docs] def calculate(self, atoms, properties=["energy"], system_changes=all_changes): Calculator.calculate(self, atoms, properties, system_changes) positions = jnp.asarray(atoms.positions, dtype=jnp.float64) box = jnp.asarray(atoms.cell.array, dtype=jnp.float64) # setup model and neighbours if self.step is None: self.initialize(atoms) elif "numbers" in system_changes: self.initialize(atoms) if self.neigbor_from_jax: self.neighbors = self.neighbor_fn.allocate(positions) elif "cell" in system_changes: neigbor_from_jax = neighbor_calculable_with_jax(box, self.r_max) if self.neigbor_from_jax != neigbor_from_jax: self.initialize(atoms) # predict if self.neigbor_from_jax: results, self.neighbors = self.step(positions, self.neighbors, box) if self.neighbors.did_buffer_overflow: print("neighbor list overflowed, reallocating.") self.initialize(atoms) results, self.neighbors = self.step(positions, self.neighbors, box) else: self.set_neighbours_and_offsets(atoms, box) positions = np.array(space.transform(np.linalg.inv(box), atoms.positions)) results = self.step(positions, self.neighbors, box, self.offsets) self.results = {k: np.array(v, dtype=np.float64) for k, v in results.items()} self.results["energy"] = self.results["energy"].item()
[docs] def batch_eval( self, atoms_list: list[ase.Atoms], batch_size: int = 64, silent: bool = False ) -> list[ase.Atoms]: """Evaluate the model on a list of Atoms. This is preferable to assigning the calculator to each Atoms instance for 2 reasons: 1. Processing can be abtched, which is advantageous for larger datasets. 2. Inputs are padded so no recompilation is triggered when evaluating differently sized systems. Parameters ---------- atoms_list : List of Atoms to be evaluated. batch_size: Processing batch size. Does not affect results, only speed and memory requirements. silent: Whether or not to suppress progress bars. Returns ------- evaluated_atoms_list: List of Atoms with labels predicted by the model. """ init_box = atoms_list[0].cell.array dataset = CachedInMemoryDataset( atoms_list, self.model_config.model.basis.r_max, batch_size, n_epochs=1, ignore_labels=True, ) evaluated_atoms_list = [] n_data = dataset.n_data ds = dataset.batch() Builder = self.model_config.model.get_builder() builder = Builder(self.model_config.model.model_dump()) model = builder.build_energy_derivative_model( apply_mask=True, init_box=init_box, ).apply if self.n_models > 1: model = jax.vmap(model, in_axes=(0, None, None, None, None, None)) model = partial(model, self.params) if self.n_models > 1: model = make_ensemble(model) if "stress" in self.implemented_properties: model = ProcessStress().apply(model) for transformation in self.transformations: model = transformation.apply(model) model = jax.vmap(model, in_axes=(0, 0, 0, 0, 0)) model = jax.jit(model) pbar = trange( n_data, desc="Evaluating data", ncols=100, leave=False, disable=silent ) for i, inputs in enumerate(ds): results = model( inputs["positions"], inputs["numbers"], inputs["idx"], inputs["box"], inputs["offsets"], ) unpadded_results = unpack_results(results, inputs) # for the last batch, the number of structures may be less # than the batch_size, which is why we check this explicitly num_strucutres_in_batch = results["energy"].shape[0] for j in range(num_strucutres_in_batch): atoms = atoms_list[i * batch_size + j].copy() atoms.calc = SinglePointCalculator(atoms=atoms, **unpadded_results[j]) evaluated_atoms_list.append(atoms) pbar.update(batch_size) pbar.close() dataset.cleanup() return evaluated_atoms_list
@property def ll_weights(self): dense_layers = list(self.params["params"]["energy_model"]["readout"].keys()) llweights = self.params["params"]["energy_model"]["readout"][dense_layers[-1]][ "w" ] return np.asarray(llweights) def set_ll_weights(self, new_weights): params = unfreeze(self.params) dense_layers = list(params["params"]["energy_model"]["readout"].keys()) params["params"]["energy_model"]["readout"][dense_layers[-1]]["w"] = jnp.asarray( new_weights, dtype=jnp.float32 ) self.params = freeze(params) self.step = None
def neighbor_calculable_with_jax(box: np.ndarray, r_max: float) -> bool: if np.all(box < 1e-6): return True else: # all lettice vector combinations to calculate all three plane distances a_vec_list = [box[0], box[0], box[1]] b_vec_list = [box[1], box[2], box[2]] c_vec_list = [box[2], box[1], box[0]] height = [] for i in range(3): normvec = np.cross(a_vec_list[i], b_vec_list[i]) projection = ( c_vec_list[i] - np.sum(normvec * c_vec_list[i]) / np.sum(normvec**2) * normvec ) height.append(np.linalg.norm(c_vec_list[i] - projection)) if np.min(height) / 2 > r_max: return True else: return False def get_step_fn( model: Callable, Z: jnp.ndarray, atoms_is_periodic: bool, neigbor_from_jax: bool ) -> Callable: if neigbor_from_jax: @jax.jit def step_fn(positions: jnp.ndarray, neighbor, box: jnp.ndarray): if atoms_is_periodic: box = box.T inv_box = jnp.linalg.inv(box) positions = space.transform(inv_box, positions) neighbor = neighbor.update(positions, box=box) else: neighbor = neighbor.update(positions) offsets = jnp.full([neighbor.idx.shape[1], 3], 0) results = model(positions, Z, neighbor.idx, box, offsets) return results, neighbor else: @jax.jit def step_fn(positions: jnp.ndarray, neighbor, box: jnp.ndarray, offsets): results = model(positions, Z, neighbor, box, offsets) return results return step_fn