# Source code for apax.data.input_pipeline

import logging
import uuid
from collections import deque
from pathlib import Path
from random import shuffle
from typing import Dict, Iterator

import jax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf

from apax.data.preprocessing import compute_nl, prefetch_to_single_device
from apax.utils.convert import atoms_to_inputs, atoms_to_labels, unit_dict

# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)


def pad_nl(idx, offsets, max_neighbors):
    """
    Pad the neighbor list arrays to the maximal number of neighbors occurring.

    Parameters
    ----------
    idx : np.ndarray
        Neighbor indices array; padded with zeros along its second axis.
    offsets : np.ndarray
        Offset array; padded with zero rows along its first axis.
    max_neighbors : int
        Maximum number of neighbors (target length of the padded axis).

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        Tuple containing padded neighbor indices array (cast to int16) and
        padded offsets array.

    Raises
    ------
    ValueError
        If ``idx`` already has more than ``max_neighbors`` entries, or if any
        index does not fit into int16 (the cast would silently wrap it).
    """
    zeros_to_add = max_neighbors - idx.shape[1]
    if zeros_to_add < 0:
        raise ValueError(
            f"neighbor list has {idx.shape[1]} entries, more than"
            f" max_neighbors={max_neighbors}"
        )

    # The padded indices are stored as int16; reject values that would wrap
    # around instead of silently producing a corrupted neighbor list.
    int16_info = np.iinfo(np.int16)
    if idx.size and (idx.min() < int16_info.min or idx.max() > int16_info.max):
        raise ValueError(
            f"neighbor indices must lie in [{int16_info.min}, {int16_info.max}]"
            " to be stored as int16"
        )

    idx = np.pad(idx, ((0, 0), (0, zeros_to_add)), "constant").astype(np.int16)
    offsets = np.pad(offsets, ((0, zeros_to_add), (0, 0)), "constant")
    return idx, offsets
def find_largest_system(inputs, r_max) -> tuple[int, int]:
    """
    Finds the maximal number of atoms and neighbors.

    The neighbor count is obtained by computing the neighbor list of every
    structure with ``compute_nl`` and taking the largest second-axis length
    of the returned index array.

    Parameters
    ----------
    inputs : dict
        Dictionary containing input data; the keys ``"positions"``, ``"box"``
        and ``"n_atoms"`` are read.
    r_max : float
        Maximum interaction radius.

    Returns
    -------
    Tuple[int, int]
        Tuple containing the maximum number of atoms and neighbors.
    """
    positions, boxes = inputs["positions"], inputs["box"]
    max_atoms = int(np.max(inputs["n_atoms"]))

    max_nbrs = 0
    for pos, box in zip(positions, boxes):
        neighbor_idxs, _ = compute_nl(pos, box, r_max)
        max_nbrs = max(max_nbrs, neighbor_idxs.shape[1])

    return max_atoms, max_nbrs
class InMemoryDataset:
    """Base class for datasets that hold all structures in memory.

    Inputs (and labels, when available) are converted once from the list of
    atoms objects. Per-sample neighbor lists and padding are computed by
    ``prepare_data`` and staged in a FIFO deque (``self.buffer``) that
    subclasses refill up to ``buffer_size`` samples while iterating.
    Subclasses implement ``__iter__``, ``shuffle_and_batch`` and ``batch``.
    """

    def __init__(
        self,
        atoms_list,
        cutoff,
        bs,
        n_epochs,
        buffer_size=1000,
        n_jit_steps=1,
        pos_unit: str = "Ang",
        energy_unit: str = "eV",
        pre_shuffle=False,
        ignore_labels=False,
        cache_path=".",
    ) -> None:
        self.n_epochs = n_epochs
        self.cutoff = cutoff
        self.n_jit_steps = n_jit_steps
        self.buffer_size = buffer_size
        self.n_data = len(atoms_list)
        # n_data must be known before the batch size can be validated.
        self.batch_size = self.validate_batch_size(bs)
        self.pos_unit = pos_unit

        if pre_shuffle:
            # NOTE: this shuffles the caller's list in place.
            shuffle(atoms_list)
        self.sample_atoms = atoms_list[0]
        self.inputs = atoms_to_inputs(atoms_list, pos_unit)

        max_atoms, max_nbrs = find_largest_system(self.inputs, self.cutoff)
        self.max_atoms = max_atoms
        self.max_nbrs = max_nbrs

        # Labels are only extracted if the first structure carries a calculator.
        if atoms_list[0].calc and not ignore_labels:
            self.labels = atoms_to_labels(atoms_list, pos_unit, energy_unit)
        else:
            self.labels = None

        self.count = 0  # index of the next structure to prepare
        self.buffer = deque()
        self.file = Path(cache_path) / str(uuid.uuid4())

        self.enqueue(min(self.buffer_size, self.n_data))

    def steps_per_epoch(self) -> int:
        """Returns the number of steps per epoch dependent on the number of data and the
        batch size. Steps per epoch are calculated in a way that all epochs have the same
        number of steps, and all batches have the same length. To do so, some training
        data are dropped in each epoch.
        """
        return self.n_data // self.batch_size // self.n_jit_steps

    def validate_batch_size(self, batch_size: int) -> int:
        """Clamp ``batch_size`` to ``self.n_data``, logging a warning if reduced."""
        if batch_size > self.n_data:
            msg = (
                f"requested batch size {batch_size} is larger than the number of data"
                f" points {self.n_data}. Setting batch size = {self.n_data}"
            )
            log.warning(msg)
            batch_size = self.n_data
        return batch_size

    def prepare_data(self, i):
        """Build the padded sample for structure ``i``.

        Computes the neighbor list, pads it to ``max_nbrs`` neighbors, and pads
        positions, atomic numbers and (if present) forces to ``max_atoms``.

        Returns
        -------
        dict or tuple(dict, dict)
            The input dict of NumPy arrays when the dataset has no labels,
            otherwise ``(inputs, labels)`` with values converted to
            ``tf.constant``.
        """
        inputs = {k: v[i] for k, v in self.inputs.items()}
        idx, offsets = compute_nl(inputs["positions"], inputs["box"], self.cutoff)
        inputs["idx"], inputs["offsets"] = pad_nl(idx, offsets, self.max_nbrs)

        zeros_to_add = self.max_atoms - inputs["numbers"].shape[0]
        inputs["positions"] = np.pad(
            inputs["positions"], ((0, zeros_to_add), (0, 0)), "constant"
        )
        inputs["numbers"] = np.pad(
            inputs["numbers"], (0, zeros_to_add), "constant"
        ).astype(np.int16)

        if not self.labels:
            return inputs

        labels = {k: v[i] for k, v in self.labels.items()}
        if "forces" in labels:
            labels["forces"] = np.pad(
                labels["forces"], ((0, zeros_to_add), (0, 0)), "constant"
            )

        inputs = {k: tf.constant(v) for k, v in inputs.items()}
        labels = {k: tf.constant(v) for k, v in labels.items()}
        return (inputs, labels)

    def enqueue(self, num_elements):
        """Prepare the next ``num_elements`` structures and append them to the buffer."""
        for _ in range(num_elements):
            data = self.prepare_data(self.count)
            self.buffer.append(data)
            self.count += 1

    def make_signature(self) -> "dict | tuple[dict, dict]":
        """Return the ``tf.TensorSpec`` signature of one prepared sample.

        Returns the input-spec dict alone when there are no labels, otherwise
        ``(input_signature, label_signature)`` mirroring ``prepare_data``.
        """
        input_signature = {}
        input_signature["n_atoms"] = tf.TensorSpec((), dtype=tf.int16, name="n_atoms")
        input_signature["numbers"] = tf.TensorSpec(
            (self.max_atoms,), dtype=tf.int16, name="numbers"
        )
        input_signature["positions"] = tf.TensorSpec(
            (self.max_atoms, 3), dtype=tf.float64, name="positions"
        )
        input_signature["box"] = tf.TensorSpec((3, 3), dtype=tf.float64, name="box")
        input_signature["idx"] = tf.TensorSpec(
            (2, self.max_nbrs), dtype=tf.int16, name="idx"
        )
        input_signature["offsets"] = tf.TensorSpec(
            (self.max_nbrs, 3), dtype=tf.float64, name="offsets"
        )

        if not self.labels:
            return input_signature

        label_signature = {}
        if "energy" in self.labels:
            label_signature["energy"] = tf.TensorSpec((), dtype=tf.float64, name="energy")
        if "forces" in self.labels:
            label_signature["forces"] = tf.TensorSpec(
                (self.max_atoms, 3), dtype=tf.float64, name="forces"
            )
        if "stress" in self.labels:
            label_signature["stress"] = tf.TensorSpec(
                (3, 3), dtype=tf.float64, name="stress"
            )
        return (input_signature, label_signature)

    def init_input(self) -> tuple[tuple, np.ndarray]:
        """Return one unpadded sample used to initialize the model.

        Returns
        -------
        inputs : tuple of jax.Array
            ``(positions, numbers, idx, box, offsets)`` of ``self.sample_atoms``
            (the first structure after optional pre-shuffling), with positions
            and box scaled by ``unit_dict[self.pos_unit]``.
        box : np.ndarray
            The scaled box as a NumPy array.
        """
        positions = self.sample_atoms.positions * unit_dict[self.pos_unit]
        box = self.sample_atoms.cell.array * unit_dict[self.pos_unit]
        # For an input sample, it does not matter whether pos is fractional or cartesian
        idx, offsets = compute_nl(positions, box, self.cutoff)
        inputs = (
            positions,
            self.sample_atoms.numbers,
            idx,
            box,
            offsets,
        )
        # jax.tree_map is deprecated; jax.tree_util.tree_map is the stable API.
        inputs = jax.tree_util.tree_map(jnp.array, inputs)
        return inputs, np.array(box)

    def __iter__(self):
        raise NotImplementedError

    def shuffle_and_batch(self):
        raise NotImplementedError

    def batch(self) -> Iterator[jax.Array]:
        raise NotImplementedError

    def cleanup(self):
        pass

    def _generator_dataset(self):
        """Wrap ``iter(self)`` in a ``tf.data.Dataset`` with a fixed signature."""
        return tf.data.Dataset.from_generator(
            lambda: self, output_signature=self.make_signature()
        )

    def _prefetch(self, ds, sharding):
        """Pass a batched dataset to ``prefetch_to_single_device``."""
        return prefetch_to_single_device(
            ds.as_numpy_iterator(), 2, sharding, n_step_jit=self.n_jit_steps > 1
        )

    def _shuffle_batch_prefetch(self, ds, sharding):
        """Shuffle, batch (plus an ``n_jit_steps`` axis if > 1) and prefetch ``ds``."""
        ds = ds.shuffle(
            buffer_size=self.buffer_size, reshuffle_each_iteration=True
        ).batch(batch_size=self.batch_size)
        if self.n_jit_steps > 1:
            ds = ds.batch(batch_size=self.n_jit_steps)
        return self._prefetch(ds, sharding)

    def _batch_prefetch(self, ds, sharding):
        """Batch ``ds`` without shuffling and prefetch it."""
        # NOTE(review): no n_jit_steps axis is added here, yet n_step_jit is
        # still forwarded as ``n_jit_steps > 1`` (unchanged from the original)
        # — confirm prefetch_to_single_device expects that for these batches.
        return self._prefetch(ds.batch(batch_size=self.batch_size), sharding)


class CachedInMemoryDataset(InMemoryDataset):
    """Dataset whose prepared samples are cached by ``tf.data`` in ``self.file``.

    The generator is consumed once; later epochs are served from the cache.
    """

    def __iter__(self):
        """Yield each prepared structure once, keeping the buffer topped up."""
        while self.count < self.n_data or len(self.buffer) > 0:
            yield self.buffer.popleft()
            space = min(self.buffer_size - len(self.buffer), self.n_data - self.count)
            self.enqueue(space)

    def _cached_dataset(self):
        """Generator dataset cached to ``self.file`` and repeated ``n_epochs`` times."""
        return (
            self._generator_dataset()
            .cache(self.file.as_posix())
            .repeat(self.n_epochs)
        )

    def shuffle_and_batch(self, sharding=None):
        """Shuffles and batches the inputs/labels. This function prepares the
        inputs and labels for the whole training and prefetches the data.

        Returns
        -------
        ds :
            Iterator that returns inputs and labels of one batch in each step.
        """
        return self._shuffle_batch_prefetch(self._cached_dataset(), sharding)

    def batch(self, sharding=None) -> Iterator[jax.Array]:
        """Batch the cached data in order (no shuffling) and prefetch it."""
        return self._batch_prefetch(self._cached_dataset(), sharding)

    def cleanup(self):
        """Delete the cache files; tolerates a cache that was never written."""
        for p in self.file.parent.glob(f"{self.file.name}.data*"):
            p.unlink()

        index_file = self.file.parent / f"{self.file.name}.index"
        index_file.unlink(missing_ok=True)


class OTFInMemoryDataset(InMemoryDataset):
    """Dataset that re-prepares every sample (neighbor lists included) each epoch."""

    def __iter__(self):
        """Yield ``n_data * n_epochs`` samples, cycling through the data in order."""
        n_yielded = 0
        max_iter = self.n_data * self.n_epochs
        while n_yielded < max_iter:
            yield self.buffer.popleft()
            n_yielded += 1

            # Wrap to the next epoch *before* sizing the refill. Sizing first
            # made the boundary step enqueue nothing, so a just-drained buffer
            # (buffer_size == 1 or n_data == 1) made the next popleft() raise
            # IndexError, and the buffer could overshoot buffer_size.
            if self.count >= self.n_data:
                self.count = 0
            space = min(self.buffer_size - len(self.buffer), self.n_data - self.count)
            self.enqueue(space)

    def shuffle_and_batch(self, sharding=None):
        """Shuffles and batches the inputs/labels. This function prepares the
        inputs and labels for the whole training and prefetches the data.

        Returns
        -------
        ds :
            Iterator that returns inputs and labels of one batch in each step.
        """
        return self._shuffle_batch_prefetch(self._generator_dataset(), sharding)

    def batch(self, sharding=None) -> Iterator[jax.Array]:
        """Batch the data in order (no shuffling) and prefetch it."""
        return self._batch_prefetch(self._generator_dataset(), sharding)


dataset_dict = {
    "cached": CachedInMemoryDataset,
    "otf": OTFInMemoryDataset,
}