Source code for apax.utils.convert
import jax.numpy as jnp
import numpy as np
from ase import Atoms
from ase.units import Ang, Bohr, Hartree, eV, kcal, kJ, mol
from apax.utils.jax_md_reduced import space
# Floating-point dtype used for every converted position and cell array.
DTYPE = np.float64

# Multiplicative factors taken from ``ase.units``: multiplying a number given in
# the named unit by its factor expresses it in ASE's base units (Angstrom for
# lengths, eV for energies). Keys are the unit labels accepted as ``pos_unit`` /
# ``energy_unit`` by the conversion functions in this module.
unit_dict: dict[str, float] = {
    # length units
    "Ang": Ang,
    "Bohr": Bohr,
    # energy units (molar units are converted to per-particle values)
    "eV": eV,
    "kcal/mol": kcal / mol,
    "Hartree": Hartree,
    "kJ/mol": kJ / mol,
}
[docs]
def tf_to_jax_dict(data_dict: dict[str, list]) -> dict:
    """Return a new dict with every value turned into a ``jax.numpy`` array.

    Originally written for dicts of padded ``tf.Tensor`` objects, but any value
    accepted by ``jnp.asarray`` works. Values must be rectangular (i.e. already
    padded); ragged nested sequences cannot be converted.

    Parameters
    ----------
    data_dict :
        Mapping from name to padded array-like values.

    Returns
    -------
    dict
        The same keys, in the same order, mapped to ``jax.numpy`` arrays.
    """
    converted = {}
    for name, values in data_dict.items():
        converted[name] = jnp.asarray(values)
    return converted
def prune_dict(data_dict):
    """Return a copy of ``data_dict`` without the entries whose value is empty.

    A value counts as empty when ``len(value) == 0``, so every value must
    support ``len``. The input dict is left unmodified and key order is kept.
    """
    kept = {}
    for name, entries in data_dict.items():
        if len(entries) == 0:
            continue
        kept[name] = entries
    return kept
def is_periodic(box):
    """Decide whether a simulation box describes a 3D periodic system.

    A box whose entries are all (numerically) zero is treated as gas phase.
    Any other box must have all of its cell vectors set; a box with one or
    two missing cell vectors (1D/2D periodicity) is rejected.

    Parameters
    ----------
    box :
        Cell matrix (array-like). Cell vectors may be stored either as rows
        (ASE convention) or as columns (``atoms_to_inputs`` passes the
        transposed cell), so both an all-zero row and an all-zero column
        count as a missing vector. A 1-D input is checked elementwise.

    Returns
    -------
    bool
        ``True`` for a fully periodic box, ``False`` for an all-zero box.

    Raises
    ------
    ValueError
        If the box is neither all zero nor fully populated.
    """
    nonzero = np.abs(np.asarray(box)) > 1e-6
    if not nonzero.any():
        return False

    if nonzero.ndim >= 2:
        # Previously ``np.any`` without an axis collapsed the box to a single
        # scalar, so partially periodic cells slipped through and only failed
        # later when the (singular) box was inverted. Check each vector instead.
        missing_vector = bool(
            (~nonzero.any(axis=-1)).any() or (~nonzero.any(axis=-2)).any()
        )
    else:
        missing_vector = not nonzero.all()

    if missing_vector:
        msg = (
            f"Only 3D periodic and gas phase system supported at the moment. Found {box}"
        )
        raise ValueError(msg)
    return True
[docs]
def atoms_to_inputs(
    atoms_list: list[Atoms],
    pos_unit: str = "Ang",
) -> dict[str, list]:
    """Convert a list of ASE atoms into a dict of per-structure model inputs.

    Positions and cells are multiplied by the ``pos_unit`` factor from
    ``unit_dict`` and cast to ``DTYPE``. The cell is stored transposed
    relative to ASE's row convention. For periodic structures, positions are
    stored as fractional coordinates obtained with the inverse of that box;
    for gas-phase structures they are stored as Cartesian coordinates.

    Parameters
    ----------
    atoms_list :
        List of all structures. Entries are ASE atoms objects. Either all of
        them are periodic or none of them is.
    pos_unit :
        Key of ``unit_dict`` naming the length unit of the input positions and
        cells.

    Returns
    -------
    inputs :
        Dict with one list entry per structure under ``"positions"``,
        ``"numbers"``, ``"n_atoms"`` and ``"box"``. Keys with empty lists are
        removed, so an empty ``atoms_list`` yields an empty dict.

    Raises
    ------
    ValueError
        If periodic and non-periodic structures are mixed, or if
        ``is_periodic`` rejects a box.
    KeyError
        If ``pos_unit`` is not a key of ``unit_dict``.
    """
    inputs = {
        "positions": [],
        "numbers": [],
        "n_atoms": [],
        "box": [],
    }
    length_factor = unit_dict[pos_unit]

    # Periodicity of the first structure; every other structure must match it.
    # Determined from the same converted box that is stored, so reference and
    # per-structure checks use identical inputs (and an empty list is fine).
    reference_pbc = None
    for atoms in atoms_list:
        box = (atoms.cell.array * length_factor).astype(DTYPE)
        box = box.T  # takes row and column convention of ase into account
        inputs["box"].append(box)

        is_pbc = is_periodic(box)
        if reference_pbc is None:
            reference_pbc = is_pbc
        elif reference_pbc != is_pbc:
            raise ValueError(
                "Apax does not support dataset periodic and non periodic structures"
            )

        positions = (atoms.positions * length_factor).astype(DTYPE)
        if is_pbc:
            inv_box = np.linalg.inv(box)
            positions = np.array(space.transform(inv_box, positions))
        inputs["positions"].append(positions)

        inputs["numbers"].append(atoms.numbers.astype(np.int16))
        inputs["n_atoms"].append(len(atoms))

    return prune_dict(inputs)
[docs]
def atoms_to_labels(
    atoms_list: list[Atoms],
    pos_unit: str = "Ang",
    energy_unit: str = "eV",
) -> dict[str, list]:
    """Convert a list of ASE atoms into a dict of per-structure training labels.

    Labels are read from ``atoms.calc.results``. Only ``"energy"``,
    ``"forces"`` and ``"stress"`` are used; all other result keys are ignored.
    Values are converted with the ``unit_dict`` factors of ``energy_unit``
    and ``pos_unit``.

    Parameters
    ----------
    atoms_list :
        List of all structures. Entries are ASE atoms objects with an attached
        calculator holding the reference results.
    pos_unit :
        Key of ``unit_dict`` naming the length unit of the stored labels.
    energy_unit :
        Key of ``unit_dict`` naming the energy unit of the stored labels.

    Returns
    -------
    labels :
        Dict with one list entry per structure that provides the label:
        ``"forces"`` (energy/length), ``"energy"``, and ``"stress"``.
        ``"stress"`` is the 3x3 stress multiplied by the cell volume, so it has
        units of energy. Keys for which no structure provided a value are
        removed.

    Raises
    ------
    KeyError
        If ``pos_unit`` or ``energy_unit`` is not a key of ``unit_dict``.
    """
    energy_factor = unit_dict[energy_unit]
    length_factor = unit_dict[pos_unit]
    volume_factor = length_factor**3

    labels = {
        "forces": [],
        "energy": [],
        "stress": [],
    }
    for atoms in atoms_list:
        for key, val in atoms.calc.results.items():
            if key == "forces":
                labels[key].append(val * energy_factor / length_factor)
            elif key == "energy":
                labels[key].append(val * energy_factor)
            elif key == "stress":
                stress = atoms.get_stress(voigt=False) * (energy_factor / volume_factor)
                # The stress is now per converted volume, so the volume must be
                # converted as well; otherwise stress * volume is off by
                # ``length_factor**3`` whenever pos_unit is not "Ang".
                volume = atoms.cell.volume * volume_factor
                labels[key].append(stress * volume)
            # Any other calculator result (e.g. extra properties) is skipped.
    return prune_dict(labels)