import jax.numpy as jnp
import numpy as np
from ase import Atoms
from ase.units import Ang, Bohr, Hartree, eV, kcal, kJ, mol
from apax.utils.jax_md_reduced import space
# Default floating-point dtype (numpy float64).
DTYPE = np.float64

# Conversion factors (taken from ``ase.units``) for the unit names accepted by
# the data pipeline, looked up by e.g. ``atoms_to_labels``.
# Length units are listed first, then energy units.
unit_dict = {
    "Ang": Ang,
    "Bohr": Bohr,
} | {
    "eV": eV,
    "kcal/mol": kcal / mol,
    "Hartree": Hartree,
    "kJ/mol": kJ / mol,
}
[docs]
def tf_to_jax_dict(data_dict: dict[str, list]) -> dict:
    """Turn every value of an (already padded) dict into a jax.numpy array.

    Each value is passed through ``jnp.asarray``, so it must be array-like
    with a regular shape (i.e. padded) for the conversion to succeed.

    Parameters
    ----------
    data_dict :
        Dict of padded array-like values (e.g. tf.Tensors or lists).

    Returns
    -------
    data_dict :
        New dict with the same keys (same order) whose values are
        jax.numpy arrays.
    """
    converted = {}
    for name, values in data_dict.items():
        converted[name] = jnp.asarray(values)
    return converted
def prune_dict(data_dict):
    """Return a copy of ``data_dict`` without the entries whose value is empty.

    Emptiness is tested with ``len(value) == 0`` (not truthiness) so that
    array values are handled without ambiguity.

    Parameters
    ----------
    data_dict :
        Mapping whose values support ``len()``.

    Returns
    -------
    pruned :
        New dict, in the original key order, holding only non-empty values.
    """
    kept = {}
    for name, entries in data_dict.items():
        if len(entries) == 0:
            continue
        kept[name] = entries
    return kept
def is_periodic(box):
    """Decide whether a box describes a fully periodic or a gas-phase system.

    A box dimension counts as present when its entries exceed 1e-6 in
    absolute value.

    Parameters
    ----------
    box :
        Array-like box description.
        - For arrays with 2 or more dimensions, the last two axes are treated
          as a cell matrix. A missing lattice vector shows up as an all-zero
          row or an all-zero column, depending on whether vectors are stored
          as rows or columns, so both are checked.
        - For 0D/1D input (e.g. box lengths), every element is one dimension.

    Returns
    -------
    periodic :
        numpy bool: True if every dimension is present, False if none is
        (gas phase).

    Raises
    ------
    ValueError
        If only some dimensions are present (partial periodicity is not
        supported).
    """
    nonzero = np.abs(np.asarray(box)) > 1e-6
    if nonzero.ndim >= 2:
        # Reduce per lattice vector instead of over the whole array; a global
        # reduction yields a single scalar and can never detect mixed cases.
        pbc_dims = np.concatenate(
            [np.any(nonzero, axis=-1), np.any(nonzero, axis=-2)], axis=-1
        )
    else:
        pbc_dims = nonzero

    fully_periodic = np.all(pbc_dims)
    if fully_periodic or not np.any(pbc_dims):
        return fully_periodic

    msg = (
        f"Only 3D periodic and gas phase system supported at the moment. Found {box}"
    )
    raise ValueError(msg)
[docs]
def atoms_to_labels(
    atoms_list: list[Atoms],
    pos_unit: str = "Ang",
    energy_unit: str = "eV",
) -> dict[str, list]:
    """Convert a list of ASE atoms into a dict of labels.

    Energies, forces and stresses are read from each structure's calculator
    results and rescaled with the ``unit_dict`` factors for ``energy_unit``
    and ``pos_unit``. Other calculator results are ignored.

    Parameters
    ----------
    atoms_list :
        List of all structures. Entries are ASE atoms objects with an
        attached calculator.
    pos_unit :
        Key of ``unit_dict`` naming the length unit of the stored results.
    energy_unit :
        Key of ``unit_dict`` naming the energy unit of the stored results.

    Returns
    -------
    labels :
        Labels are trainable system properties. Maps ``"forces"``,
        ``"energy"`` and ``"stress"`` to lists with one entry per structure
        that provides that result. Keys without any entries are removed
        (via ``prune_dict``). Stresses are stored as full 3x3 matrices
        multiplied by the cell volume.

    Raises
    ------
    KeyError
        If ``pos_unit`` or ``energy_unit`` is not a key of ``unit_dict``.
    """
    # Unit factors are loop-invariant; look them up once.
    energy_factor = unit_dict[energy_unit]
    length_factor = unit_dict[pos_unit]
    force_factor = energy_factor / length_factor
    stress_factor = energy_factor / length_factor**3
    volume_factor = length_factor**3

    labels = {
        "forces": [],
        "energy": [],
        "stress": [],
    }
    for atoms in atoms_list:
        for key, val in atoms.calc.results.items():
            if key == "forces":
                labels[key].append(val * force_factor)
            elif key == "energy":
                labels[key].append(val * energy_factor)
            elif key == "stress":
                stress = atoms.get_stress(voigt=False) * stress_factor
                # The cell volume is stored in pos_unit**3 and must be
                # converted as well. Otherwise stress * volume is off by
                # unit_dict[pos_unit] ** 3 for any pos_unit other than Ang.
                volume = atoms.cell.volume * volume_factor
                labels[key].append(stress * volume)
    labels = prune_dict(labels)
    return labels
def transpose_dict_of_lists(dict_of_lists: dict):
    """Turn a dict of equal-length sequences into a list of per-index dicts.

    Example: ``{"a": [1, 2], "b": [3, 4]}`` becomes
    ``[{"a": 1, "b": 3}, {"a": 2, "b": 4}]``.

    Parameters
    ----------
    dict_of_lists :
        Mapping from key to a sequence. All sequences must have the same
        length.

    Returns
    -------
    list_of_dicts :
        One dict per index, each containing every key in the input order.
        Empty if the input dict (or its sequences) is empty.

    Raises
    ------
    ValueError
        If the sequences do not all have the same length.
    """
    if not dict_of_lists:
        # Previously crashed with IndexError on keys[0].
        return []

    keys = list(dict_of_lists)
    lengths = {key: len(dict_of_lists[key]) for key in keys}
    if len(set(lengths.values())) > 1:
        # Mismatched lengths would otherwise raise an unhelpful IndexError or
        # silently drop trailing entries.
        raise ValueError(f"All lists must have the same length, got {lengths}")

    columns = [dict_of_lists[key] for key in keys]
    return [dict(zip(keys, row)) for row in zip(*columns)]