Source code for apax.utils.convert
import logging
import jax.numpy as jnp
import numpy as np
from ase import Atoms
from ase.units import Ang, Bohr, Hartree, eV, kcal, kJ, mol
log = logging.getLogger(__name__)
DTYPE = np.float64
unit_dict = {
"Ang": Ang,
"Bohr": Bohr,
"eV": eV,
"kcal/mol": kcal / mol,
"Hartree": Hartree,
"kJ/mol": kJ / mol,
}
def str_to_dtype(x):
if isinstance(x, str):
if x == "fp32":
y = jnp.float32
elif x == "fp64":
y = jnp.float64
elif x == "fp16":
y = jnp.float16
elif x == "fp128":
y = jnp.float128
else:
raise KeyError(f"unknown dtype {x}")
return y
else:
return x
[docs]
def tf_to_jax_dict(data_dict: dict[str, list]) -> dict:
"""Converts a dict of tf.Tensors to a dict of jax.numpy.arrays.
tf.Tensors must be padded.
Parameters
----------
data_dict :
Dict padded of tf.Tensors
Returns
-------
data_dict :
Dict of jax.numpy.arrays
"""
data_dict = {k: jnp.asarray(v) for k, v in data_dict.items()}
return data_dict
def prune_dict(data_dict):
pruned = {key: val for key, val in data_dict.items() if len(val) != 0}
return pruned
def is_periodic(box):
pbc_dims = np.any(np.abs(box) > 1e-6)
if np.all(pbc_dims == True) or np.all(pbc_dims == False): # noqa: E712
return pbc_dims
else:
msg = (
f"Only 3D periodic and gas phase system supported at the moment. Found {box}"
)
raise ValueError(msg)
[docs]
def atoms_to_labels(
atoms_list: list[Atoms],
pos_unit: str = "Ang",
energy_unit: str = "eV",
additional_properties: list[str] = [],
) -> dict[str, dict[str, list]]:
"""Converts an list of ASE atoms to a dict of labels
Units are adjusted if ASE compatible and provided in the inputpipeline.
Parameters
----------
atoms_list :
List of all structures. Entries are ASE atoms objects.
Returns
-------
labels :
Labels are trainable system properties.
"""
labels = {}
property_names = [p[0] for p in additional_properties]
for key in property_names:
if key not in labels.keys():
placeholder = {key: []}
labels.update(placeholder)
common_keys = set(atoms_list[0].calc.results.keys())
for atoms in atoms_list[1:]:
common_keys &= set(atoms.calc.results.keys())
log.info(f"Labels found in the dataset: {common_keys}")
for key in labels.keys():
if key not in common_keys:
log.error(f"Label {key} missing at least in one structure")
for key in common_keys:
if key not in labels.keys():
placeholder = {key: []}
labels.update(placeholder)
for atoms in atoms_list:
for key in common_keys:
val = atoms.calc.results[key]
if key == "forces":
labels[key].append(val * unit_dict[energy_unit] / unit_dict[pos_unit])
elif key == "energy":
labels[key].append(val * unit_dict[energy_unit])
elif key == "stress":
factor = unit_dict[energy_unit] / (unit_dict[pos_unit] ** 3)
stress = atoms.get_stress(voigt=False) * factor
labels[key].append(stress * atoms.cell.volume)
elif key in property_names:
if len(val.shape) == 1 and val.shape[0] != 1:
val = val[:, None]
labels[key].append(val)
labels = prune_dict(labels)
return labels
def transpose_dict_of_lists(dict_of_lists: dict):
list_of_dicts = []
keys = list(dict_of_lists.keys())
for i in range(len(dict_of_lists[keys[0]])):
data = {k: dict_of_lists[k][i] for k in keys}
list_of_dicts.append(data)
return list_of_dicts