Utils
- apax.utils.convert.atoms_to_inputs(atoms_list: list[Atoms], pos_unit: str = 'Ang') -> dict[str, dict[str, list]]
Converts a list of ASE atoms to a dict in which all inputs are sorted by their shape (ragged/fixed). Units are adjusted if they are ASE compatible and provided in the input pipeline.
- Parameters:
atoms_list – List of all structures. Entries are ASE atoms objects.
- Returns:
inputs – Inputs are untrainable system-determining properties.
labels – Labels are trainable system properties.
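To make the "sorted by their shape" wording concrete, the sketch below groups per-structure properties into a ragged bucket (one entry per atom, so the length differs between structures) and a fixed bucket (the same shape for every structure). It works on plain dicts rather than ASE atoms objects. The helper name group_by_shape and the keys positions, numbers, box, and n_atoms are illustrative assumptions, not a statement of what apax stores internally.

```python
def group_by_shape(structures):
    """Group per-structure properties into 'ragged' and 'fixed' buckets.

    Hypothetical helper for illustration; not part of apax.
    """
    grouped = {
        "ragged": {"positions": [], "numbers": []},
        "fixed": {"box": [], "n_atoms": []},
    }
    for s in structures:
        # Per-atom quantities: their length depends on the structure size.
        grouped["ragged"]["positions"].append(s["positions"])
        grouped["ragged"]["numbers"].append(s["numbers"])
        # Per-structure quantities: same shape for every structure.
        grouped["fixed"]["box"].append(s["box"])
        grouped["fixed"]["n_atoms"].append(len(s["numbers"]))
    return grouped


# Two toy structures with different atom counts (assumed example data).
water = {
    "positions": [[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]],
    "numbers": [8, 1, 1],
    "box": [10.0, 10.0, 10.0],
}
h2 = {
    "positions": [[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]],
    "numbers": [1, 1],
    "box": [10.0, 10.0, 10.0],
}
inputs = group_by_shape([water, h2])
```

The result has the nested dict[str, dict[str, list]] layout given in the signature: the outer keys name the shape category and the inner keys name the property.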
- apax.utils.convert.atoms_to_labels(atoms_list: list[Atoms], pos_unit: str = 'Ang', energy_unit: str = 'eV') -> dict[str, dict[str, list]]
Converts a list of ASE atoms to a dict of labels. Units are adjusted if they are ASE compatible and provided in the input pipeline.
- Parameters:
atoms_list – List of all structures. Entries are ASE atoms objects.
- Returns:
Labels are trainable system properties.
- Return type:
labels
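As a rough picture of what "units are adjusted" could mean for labels, the sketch below rescales energies and forces into eV and eV/Ang. It assumes that the stored values are expressed in the units named by energy_unit and pos_unit, and that forces are among the labels; neither is confirmed by the text above. The helper convert_labels and the two lookup tables are hypothetical, and the conversion factors are rounded physical constants.

```python
# Conversion factors into eV and Angstrom (rounded values).
ENERGY_TO_EV = {"eV": 1.0, "Hartree": 27.211386}
LENGTH_TO_ANG = {"Ang": 1.0, "Bohr": 0.529177}


def convert_labels(energies, forces, pos_unit="Ang", energy_unit="eV"):
    """Convert energies to eV and forces to eV/Ang.

    Hypothetical helper for illustration; not part of apax.
    """
    e_factor = ENERGY_TO_EV[energy_unit]
    # Forces are energy per length, so both factors enter.
    f_factor = e_factor / LENGTH_TO_ANG[pos_unit]
    new_energies = [e * e_factor for e in energies]
    new_forces = [
        [[component * f_factor for component in atom] for atom in structure]
        for structure in forces
    ]
    return new_energies, new_forces


# One structure with one atom, given in Hartree and Hartree/Bohr.
energies, forces = convert_labels(
    energies=[-1.0],
    forces=[[[0.1, 0.0, 0.0]]],
    pos_unit="Bohr",
    energy_unit="Hartree",
)
```

With the default arguments ('Ang' and 'eV') both factors are 1.0, so the values pass through unchanged.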
- apax.utils.convert.tf_to_jax_dict(data_dict: dict[str, list]) -> dict
Converts a dict of tf.Tensors to a dict of jax.numpy.arrays. tf.Tensors must be padded.
- Parameters:
data_dict – Dict of padded tf.Tensors
- Returns:
Dict of jax.numpy.arrays
- Return type:
data_dict
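The padding requirement exists because a dense array needs a rectangular shape, while per-atom data from structures of different sizes is ragged. The sketch below shows the idea in plain Python: every row is extended to the length of the longest one. The helper pad_ragged and the choice of 0 as the fill value are assumptions made for illustration, not the padding scheme apax uses.

```python
def pad_ragged(rows, pad_value=0.0):
    """Pad variable-length rows to the length of the longest row.

    Hypothetical helper for illustration; not part of apax.
    """
    max_len = max(len(row) for row in rows)
    return [list(row) + [pad_value] * (max_len - len(row)) for row in rows]


# Atomic numbers of a 3-atom and a 2-atom structure.
numbers = [[8, 1, 1], [1, 1]]
padded = pad_ragged(numbers, pad_value=0)
```

After padding, all rows have the same length, so the nested list can be turned into a single dense array.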
- apax.utils.data.load_data(data_path)
Parameters that are not ASE compatible have to be saved in an extra file that has the same name as the data path but with the extension _labels.npz.
Example
Example for the npz file:
    import numpy as np

    dipole = np.random.rand(3, 1)
    charge = np.random.rand(3, 2)
    mat = np.random.rand(3, 1)
    shape = ['ragged', 'ragged', 'fixed']
    np.savez(
        "data_path_labels.npz",
        dipole=dipole,
        charge=charge,
        mat=mat,
        shape=shape,
    )
The entries of shape have to be in the same order as the parameters.
- Parameters:
data_path – Path to the ASE readable file that includes all structures.
- Returns:
List of all structures where entries are ASE atoms objects.
- Return type:
list
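The naming rule for the extra label file can be sketched with pathlib. This is one reading of "same name as the data path but with the extension _labels.npz": the original file extension is dropped and _labels.npz is appended to the remaining name. The helper labels_path_for and the file name datasets/md17.extxyz are hypothetical, and apax may derive the path differently.

```python
from pathlib import PurePosixPath


def labels_path_for(data_path):
    """Return the companion '<name>_labels.npz' path for a data file.

    Hypothetical helper for illustration; not part of apax.
    """
    path = PurePosixPath(data_path)
    # Drop the original extension and append the '_labels.npz' ending.
    return path.with_name(path.stem + "_labels.npz")


labels_path = labels_path_for("datasets/md17.extxyz")
```

Under this reading, the label file sits next to the structure file, so both can be located from the single data_path argument.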
- apax.utils.data.split_atoms(atoms_list, train_idxs, val_idxs=None)
Split the list of atoms into training and validation sets (validation is optional).
- Parameters:
atoms_list (list[ase.Atoms]) – List of atoms.
train_idxs (list[int]) – List of indices for the training set.
val_idxs (list[int], optional) – List of indices for the validation set.
- Returns:
Tuple containing lists of atoms for training and validation sets.
- Return type:
Tuple[list, list]
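The index-based split described above can be sketched as follows. The helper split_by_indices is a hypothetical stand-in that operates on any list, and it assumes that leaving out val_idxs yields an empty validation list; the text above only says that validation is optional.

```python
def split_by_indices(items, train_idxs, val_idxs=None):
    """Select training and validation entries from a list by index.

    Hypothetical helper for illustration; not part of apax.
    """
    train = [items[i] for i in train_idxs]
    # Assumption: without validation indices, the validation list is empty.
    val = [items[i] for i in val_idxs] if val_idxs is not None else []
    return train, val


# Strings stand in for ASE atoms objects.
structures = ["s0", "s1", "s2", "s3", "s4"]
train, val = split_by_indices(structures, train_idxs=[0, 2, 4], val_idxs=[1])
train_only, no_val = split_by_indices(structures, train_idxs=[3, 1])
```

The order of the returned lists follows the order of the index lists, so shuffled indices produce a shuffled split.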