Data Pipeline

apax.data

apax.data.preprocessing.compute_nl(positions, box, r_max)

Computes the neighbor list for a single structure. For periodic systems, positions are assumed to be in fractional coordinates.

Parameters:
  • positions (np.ndarray) – Positions of atoms.

  • box (np.ndarray) – Simulation box dimensions.

  • r_max (float) – Maximum interaction radius.

Returns:

Tuple containing neighbor indices array and offsets array.

Return type:

Tuple[np.ndarray, np.ndarray]
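To illustrate the shape of such a result, here is a minimal brute-force sketch for a non-periodic structure. It is not the apax implementation: the function name brute_force_neighbor_list, the (2, n_pairs) layout of the index array, and the all-zero offsets are assumptions made for illustration, and periodic boundary handling is left out entirely.

```python
import numpy as np


def brute_force_neighbor_list(positions, r_max):
    """Hypothetical sketch: neighbor list for a non-periodic structure.

    Returns idx with assumed shape (2, n_pairs), where row 0 holds the
    central atom and row 1 its neighbor, and offsets with assumed shape
    (n_pairs, 3), all zeros because no periodic images are considered.
    """
    positions = np.asarray(positions, dtype=float)
    # Pairwise displacement vectors, shape (n_atoms, n_atoms, 3).
    diff = positions[:, None, :] - positions[None, :, :]
    dist = np.linalg.norm(diff, axis=-1)
    # Keep pairs inside the cutoff and drop self-pairs on the diagonal.
    mask = (dist < r_max) & ~np.eye(len(positions), dtype=bool)
    centers, neighbors = np.nonzero(mask)
    idx = np.stack([centers, neighbors])
    offsets = np.zeros((idx.shape[1], 3))
    return idx, offsets


positions = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 0.0, 0.0]])
idx, offsets = brute_force_neighbor_list(positions, r_max=2.0)
```

The brute-force approach scales quadratically with the number of atoms, so it is only suitable for small structures or for checking a faster implementation.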

apax.data.preprocessing.get_shrink_wrapped_cell(positions)

Get the shrink-wrapped simulation cell based on atomic positions.

Parameters:

positions (np.ndarray) – Atomic positions.

Returns:

Tuple containing the shrink-wrapped cell matrix and origin.

Return type:

Tuple[np.ndarray, np.ndarray]
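One way to picture a shrink-wrapped cell is as the axis-aligned bounding box of the atoms. The sketch below assumes exactly that and nothing more; the name shrink_wrapped_cell_sketch is hypothetical, and the actual apax function may add padding or handle flat dimensions differently.

```python
import numpy as np


def shrink_wrapped_cell_sketch(positions):
    """Hypothetical sketch: axis-aligned bounding box around the atoms."""
    positions = np.asarray(positions, dtype=float)
    # The origin is the corner with the smallest coordinate on each axis.
    origin = positions.min(axis=0)
    # Edge lengths of the box, placed on the diagonal of the cell matrix.
    extent = positions.max(axis=0) - origin
    cell = np.diag(extent)
    return cell, origin


cell, origin = shrink_wrapped_cell_sketch(
    np.array([[1.0, 1.0, 1.0], [3.0, 4.0, 5.0]])
)
```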

apax.data.preprocessing.prefetch_to_single_device(iterator, size: int, sharding=None, n_step_jit=False)

Inspired by Flax's prefetch_to_device (https://flax.readthedocs.io/en/latest/_modules/flax/jax_utils.html#prefetch_to_device), except that it does not shard the data.
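The general prefetching pattern can be sketched with a queue that is kept filled ahead of the consumer. This is an illustration of the idea only, not the apax code: the function name prefetch and its transfer argument are hypothetical, and the identity default stands in for a real host-to-device copy such as jax.device_put.

```python
import collections
import itertools


def prefetch(iterator, size, transfer=lambda x: x):
    """Hypothetical sketch: yield items while keeping `size` items queued.

    `transfer` is a placeholder for the host-to-device copy; the identity
    default keeps the example free of any accelerator dependency.
    """
    iterator = iter(iterator)
    queue = collections.deque()

    def enqueue(n):
        # Pull up to n items and hand each one to the transfer function.
        for item in itertools.islice(iterator, n):
            queue.append(transfer(item))

    # Fill the buffer once, then top it up after every item handed out.
    enqueue(size)
    while queue:
        yield queue.popleft()
        enqueue(1)


result = list(prefetch(range(5), size=2))
```

Because device transfers are typically dispatched asynchronously, queueing a few batches ahead lets the copy overlap with the computation on the current batch.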

apax.data.initialization.load_data_files(data_config)

Load data files for training and validation.

Parameters:

data_config (object) – Data configuration object.

Returns:

Tuple containing list of ase.Atoms objects for training and validation.

Return type:

Tuple

apax.data.input_pipeline.find_largest_system(inputs, r_max) → tuple[int]

Finds the maximal number of atoms and neighbors.

Parameters:
  • inputs (dict) – Dictionary containing input data.

  • r_max (float) – Maximum interaction radius.

Returns:

Tuple containing the maximum number of atoms and neighbors.

Return type:

Tuple[int]
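The layout of the inputs dictionary is not specified here, so the following sketch makes its own assumptions: a hypothetical "positions" key holding one coordinate array per structure, non-periodic structures, and "neighbors" counted as the total number of directed pairs within r_max in a structure. The real function may count differently.

```python
import numpy as np


def find_largest_system_sketch(inputs, r_max):
    """Hypothetical sketch: largest atom count and neighbor count."""
    max_atoms = 0
    max_neighbors = 0
    # "positions" is an assumed key: one (n_atoms, 3) array per structure.
    for positions in inputs["positions"]:
        positions = np.asarray(positions, dtype=float)
        n_atoms = len(positions)
        diff = positions[:, None, :] - positions[None, :, :]
        dist = np.linalg.norm(diff, axis=-1)
        # Directed pairs inside the cutoff, excluding self-pairs.
        mask = (dist < r_max) & ~np.eye(n_atoms, dtype=bool)
        max_atoms = max(max_atoms, n_atoms)
        max_neighbors = max(max_neighbors, int(np.count_nonzero(mask)))
    return max_atoms, max_neighbors


inputs = {
    "positions": [
        [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]],
        [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.5, 0.0, 0.0]],
    ]
}
largest = find_largest_system_sketch(inputs, r_max=1.2)
```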

apax.data.input_pipeline.pad_nl(idx, offsets, max_neighbors)

Pad the neighbor list arrays to the maximal number of neighbors occurring.

Parameters:
  • idx (np.ndarray) – Neighbor indices array.

  • offsets (np.ndarray) – Offset array.

  • max_neighbors (int) – Maximum number of neighbors.

Returns:

Tuple containing padded neighbor indices array and offsets array.

Return type:

Tuple[np.ndarray, np.ndarray]
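Padding of this kind can be sketched with numpy.pad. The array shapes used below, (2, n_neighbors) for the indices and (n_neighbors, 3) for the offsets, and the choice of zero as the fill value are assumptions for illustration; pad_nl_sketch is a hypothetical name, not the apax function.

```python
import numpy as np


def pad_nl_sketch(idx, offsets, max_neighbors):
    """Hypothetical sketch: zero-pad neighbor list arrays to a fixed size."""
    # Number of padding entries needed along the neighbor axis.
    n_pad = max_neighbors - idx.shape[1]
    # Indices are padded along axis 1, offsets along axis 0.
    idx = np.pad(idx, ((0, 0), (0, n_pad)), mode="constant")
    offsets = np.pad(offsets, ((0, n_pad), (0, 0)), mode="constant")
    return idx, offsets


idx = np.array([[0, 1], [1, 0]])
offsets = np.zeros((2, 3))
padded_idx, padded_offsets = pad_nl_sketch(idx, offsets, max_neighbors=4)
```

Padding every structure to the same size gives all batches identical array shapes, which avoids recompilation when the arrays are passed to jit-compiled functions.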

class apax.data.statistics.DatasetStats(elemental_shift: <built-in function array> = None, elemental_scale: float = None, n_species: int = 119)