Source code for apax.data.preprocessing
import collections
import itertools
import logging
import jax
import jax.numpy as jnp
import numpy as np
from jax import tree_util
from vesin import NeighborList
log = logging.getLogger(__name__)
[docs]
def compute_nl(positions, box, r_max):
"""
Computes the neighbor list for a single structure.
For periodic systems, positions are assumed to be in
fractional coordinates.
Parameters
----------
positions : np.ndarray
Positions of atoms.
box : np.ndarray
Simulation box dimensions.
r_max : float
Maximum interaction radius.
Returns
-------
Tuple[np.ndarray, np.ndarray]
Tuple containing neighbor indices array and offsets array.
"""
if np.all(box < 1e-6):
box, _ = get_shrink_wrapped_cell(positions)
calculator = NeighborList(cutoff=r_max, full_list=True)
idxs_i, idxs_j = calculator.compute(
points=positions, box=box, periodic=False, quantities="ij"
)
neighbor_idxs = np.array([idxs_i, idxs_j], dtype=np.int16)
n_neighbors = neighbor_idxs.shape[1]
offsets = np.full([n_neighbors, 3], 0)
else:
positions = positions @ box.T
calculator = NeighborList(cutoff=r_max, full_list=True)
idxs_i, idxs_j, offsets = calculator.compute(
points=positions, box=box.T, periodic=True, quantities="ijS"
)
neighbor_idxs = np.array([idxs_i, idxs_j], dtype=np.int32)
offsets = np.matmul(offsets, box.T)
return neighbor_idxs, offsets
[docs]
def get_shrink_wrapped_cell(positions):
"""
Get the shrink-wrapped simulation cell based on atomic positions.
Parameters
----------
positions : np.ndarray
Atomic positions.
Returns
-------
Tuple[np.ndarray, np.ndarray]
Tuple containing the shrink-wrapped cell matrix and origin.
"""
rmin = np.min(positions, axis=0)
rmax = np.max(positions, axis=0)
cell_origin = rmin
cell = np.diag(rmax - rmin)
for idx in range(3):
if cell[idx, idx] < 10e-1:
cell[idx, idx] = 1.0
cell[np.diag_indices_from(cell)] += 1
return cell, cell_origin
[docs]
def prefetch_to_single_device(iterator, size: int, data_sharding=None):
"""
inspired by
https://flax.readthedocs.io/en/latest/_modules/flax/jax_utils.html#prefetch_to_device
"""
queue = collections.deque()
def _prefetch(x: jax.Array):
if data_sharding:
x = jax.device_put(x, data_sharding)
else:
x = jnp.asarray(x)
return x
def enqueue(n):
for data in itertools.islice(iterator, n):
queue.append(tree_util.tree_map(_prefetch, data))
enqueue(size)
while queue:
yield queue.popleft()
enqueue(1)