Source code for apax.data.preprocessing

import collections
import itertools
import logging

import jax
import jax.numpy as jnp
import numpy as np
from jax import tree_util
from vesin import NeighborList

log = logging.getLogger(__name__)


def compute_nl(positions, box, r_max):
    """
    Computes the neighbor list for a single structure.
    For periodic systems, positions are assumed to be in fractional coordinates.

    A structure is treated as non-periodic when every entry of ``box`` is
    below ``1e-6``; in that case a shrink-wrapped cell is built around the
    atoms and the neighbor search runs without periodic images.

    Parameters
    ----------
    positions : np.ndarray
        Positions of atoms.
    box : np.ndarray
        Simulation box dimensions.
    r_max : float
        Maximum interaction radius.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        Tuple containing neighbor indices array and offsets array.
        The index array stacks ``idxs_i`` and ``idxs_j`` as two rows of
        dtype ``int32``. The offsets array has one row of three entries per
        neighbor pair: all zeros for non-periodic structures, and the
        per-pair shifts returned by the neighbor list multiplied by
        ``box.T`` for periodic ones.
    """
    calculator = NeighborList(cutoff=r_max, full_list=True)

    if np.all(box < 1e-6):
        # Non-periodic: the neighbor list still needs a cell, so use one
        # that encloses all atoms. The cell origin is not needed here.
        box, _ = get_shrink_wrapped_cell(positions)
        idxs_i, idxs_j = calculator.compute(
            points=positions, box=box, periodic=False, quantities="ij"
        )
        # int32 (not int16): int16 silently wraps atom indices above 32767,
        # and int32 matches the dtype used by the periodic branch.
        neighbor_idxs = np.array([idxs_i, idxs_j], dtype=np.int32)
        n_neighbors = neighbor_idxs.shape[1]
        offsets = np.full([n_neighbors, 3], 0)
    else:
        # Map the fractional coordinates through the box before the search.
        positions = positions @ box.T
        idxs_i, idxs_j, offsets = calculator.compute(
            points=positions, box=box.T, periodic=True, quantities="ijS"
        )
        neighbor_idxs = np.array([idxs_i, idxs_j], dtype=np.int32)
        # Scale the per-pair shifts by the box to obtain the offsets.
        offsets = np.matmul(offsets, box.T)

    return neighbor_idxs, offsets
def get_shrink_wrapped_cell(positions):
    """
    Build a diagonal cell that tightly encloses a set of atomic positions.

    The extent along each axis is the spread between the smallest and the
    largest coordinate. Extents below 1.0 are raised to 1.0, and every
    diagonal entry is then enlarged by 1.

    Parameters
    ----------
    positions : np.ndarray
        Atomic positions.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        Tuple containing the shrink-wrapped cell matrix and origin
        (the per-axis minimum of the positions).
    """
    lower_corner = np.min(positions, axis=0)
    upper_corner = np.max(positions, axis=0)
    cell = np.diag(upper_corner - lower_corner)

    # Clamp thin extents on the three Cartesian axes (10e-1 == 1.0).
    axes = np.arange(3)
    too_thin = cell[axes, axes] < 10e-1
    thin_axes = axes[too_thin]
    cell[thin_axes, thin_axes] = 1.0

    # Enlarge every diagonal entry by one.
    diagonal = np.diag_indices_from(cell)
    cell[diagonal] += 1

    return cell, lower_corner
def prefetch_to_single_device(iterator, size: int, data_sharding=None):
    """
    Yield items from ``iterator`` while keeping up to ``size`` of them
    converted and queued ahead of the consumer.

    inspired by
    https://flax.readthedocs.io/en/latest/_modules/flax/jax_utils.html#prefetch_to_device

    Parameters
    ----------
    iterator : Iterable
        Source of batches. Each item is a pytree whose leaves are converted.
        Both one-shot iterators and re-iterable containers are accepted.
    size : int
        Number of items fetched before the first one is yielded. With a
        value of 0 or less nothing is fetched and nothing is yielded.
    data_sharding : optional
        If truthy, every leaf is placed with ``jax.device_put(leaf,
        data_sharding)``; otherwise leaves are converted with
        ``jnp.asarray``.

    Yields
    ------
    The items of ``iterator``, in order, with converted leaves.
    """
    # islice() on a re-iterable (e.g. a list) would restart from the first
    # element on every call, so the queue would never drain. Pin a single
    # iterator once; for real iterators iter() returns the object itself.
    source = iter(iterator)
    queue = collections.deque()

    def _prefetch(x: jax.Array):
        if data_sharding:
            return jax.device_put(x, data_sharding)
        return jnp.asarray(x)

    def enqueue(n):
        for data in itertools.islice(source, n):
            queue.append(tree_util.tree_map(_prefetch, data))

    enqueue(size)
    while queue:
        yield queue.popleft()
        # Refill by one so the queue stays `size` items ahead.
        enqueue(1)