import collections
import itertools
import logging
import jax
import jax.numpy as jnp
import numpy as np
from matscipy.neighbours import neighbour_list
log = logging.getLogger(__name__)
[docs]
def compute_nl(positions, box, r_max):
    """
    Computes the neighbor list for a single structure.
    For periodic systems, positions are assumed to be in
    fractional coordinates.

    A box whose entries are all below 1e-6 is treated as "no box": the
    structure is then handled as non-periodic inside a shrink-wrapped cell
    and all offsets are zero. Otherwise the structure is treated as
    periodic along all three cell vectors.

    Parameters
    ----------
    positions : np.ndarray
        Positions of atoms.
    box : np.ndarray
        Simulation box dimensions.
    r_max : float
        Maximum interaction radius.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        Tuple containing neighbor indices array and offsets array.
        The index array stacks the ``i`` and ``j`` indices of every pair
        (shape ``(2, n_neighbors)``). Its dtype is ``int16`` whenever all
        atom indices fit, and ``int32`` for larger structures so that
        indices never wrap around.
    """
    # Atom indices run from 0 to n_atoms - 1. Casting indices beyond the
    # int16 range would silently wrap to negative values, so only keep the
    # compact dtype when every index is representable.
    largest_index = len(positions) - 1
    if largest_index <= np.iinfo(np.int16).max:
        idx_dtype = np.int16
    else:
        idx_dtype = np.int32

    if np.all(box < 1e-6):
        # No usable box given: build a cell that encloses all atoms and
        # search without periodic images.
        box, box_origin = get_shrink_wrapped_cell(positions)
        idxs_i, idxs_j = neighbour_list(
            "ij",
            positions=positions,
            cutoff=r_max,
            cell=box,
            cell_origin=box_origin,
            pbc=[False, False, False],
        )
        # Without periodic images there is nothing to shift: one zero
        # 3-vector per neighbor pair.
        offsets = np.full([len(idxs_i), 3], 0)
    else:
        # Fractional -> Cartesian coordinates.
        positions = positions @ box
        idxs_i, idxs_j, cell_shifts = neighbour_list(
            "ijS", positions=positions, cutoff=r_max, cell=box, pbc=[True, True, True]
        )
        # Express the per-pair cell shifts in the basis of the box vectors.
        offsets = np.matmul(cell_shifts, box)

    neighbor_idxs = np.array([idxs_i, idxs_j], dtype=idx_dtype)
    return neighbor_idxs, offsets
[docs]
def get_shrink_wrapped_cell(positions):
    """
    Build an axis-aligned cell that tightly encloses the given positions.

    The extent along each of the three axes is the spread of the
    coordinates; extents below ``10e-1`` are replaced by 1.0, and
    afterwards every diagonal entry is enlarged by 1.

    Parameters
    ----------
    positions : np.ndarray
        Atomic positions.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        The diagonal cell matrix and the cell origin (the per-axis
        minimum of the positions).
    """
    lower_corner = np.min(positions, axis=0)
    upper_corner = np.max(positions, axis=0)
    cell = np.diag(upper_corner - lower_corner)

    # Axes along which the structure is too thin get a unit extent.
    thin_axes = [axis for axis in range(3) if cell[axis, axis] < 10e-1]
    for axis in thin_axes:
        cell[axis, axis] = 1.0

    # Enlarge every extent by one.
    np.fill_diagonal(cell, cell.diagonal() + 1)

    return cell, lower_corner
[docs]
def prefetch_to_single_device(iterator, size: int, sharding=None, n_step_jit=False):
    """
    Generator that keeps up to ``size`` elements of ``iterator`` converted
    to JAX arrays ahead of the consumer.

    inspired by
    https://flax.readthedocs.io/en/latest/_modules/flax/jax_utils.html#prefetch_to_device
    Without ``sharding`` every leaf is simply converted with
    ``jnp.asarray``; with ``sharding`` every leaf is placed with
    ``jax.device_put`` using a reshaped version of that sharding.

    Parameters
    ----------
    iterator :
        Iterator or iterable of pytrees of arrays. It is consumed exactly
        once, in order.
    size : int
        Number of elements to keep buffered. With ``size <= 0`` nothing is
        buffered and the generator yields nothing.
    sharding : optional
        Sharding used to place the data. It must provide ``_devices`` and
        ``reshape`` (presumably a ``jax.sharding.PositionalSharding`` --
        TODO confirm against callers).
    n_step_jit : bool
        If True, the data is expected to carry an additional leading
        multi-batch axis (``njit x bs x ...``) which is not split across
        devices.

    Yields
    ------
    The elements of ``iterator`` with every leaf converted/placed as
    described above.
    """
    # Take a single iterator up front. Calling ``islice`` on a re-iterable
    # object (e.g. a list) would otherwise restart from its first element on
    # every refill and never terminate. For real iterators this is a no-op.
    source = iter(iterator)
    queue = collections.deque()

    if sharding:
        n_devices = len(sharding._devices)
        # Number of leading axes that get an explicit entry in the sharding
        # shape; all remaining axes are left unsplit (size 1).
        slice_start = 1
        shape = [n_devices]
        if n_step_jit:
            # replicate over multi-batch axis
            # data shape: njit x bs x ...
            slice_start = 2
            shape.insert(0, 1)

    def _prefetch(x: jax.Array):
        if sharding:
            remaining_axes = [1] * len(x.shape[slice_start:])
            final_shape = tuple(shape + remaining_axes)
            x = jax.device_put(x, sharding.reshape(final_shape))
        else:
            x = jnp.asarray(x)
        return x

    def enqueue(n):
        # Enqueues *up to* ``n`` elements; fewer if the source runs dry.
        for data in itertools.islice(source, n):
            queue.append(jax.tree_util.tree_map(_prefetch, data))

    enqueue(size)
    while queue:
        yield queue.popleft()
        # Refill the slot that was just handed out.
        enqueue(1)