import dataclasses
from typing import List
import jax
import jax.numpy as jnp
import jax.scipy as jsc
import numpy as np
from apax.utils.math import inv_and_det_3x3, normed_dotp
[docs]
def weighted_squared_error(
label: jnp.array,
prediction: jnp.array,
name,
parameters: dict = {},
) -> jnp.array:
"""
Squared error function that allows weighting of
individual contributions by the number of atoms in the system.
"""
label, prediction = label[name], prediction[name]
return (label - prediction) ** 2
[docs]
def weighted_huber_loss(
label: jnp.array,
prediction: jnp.array,
name,
parameters: dict = {},
) -> jnp.array:
"""
Huber loss function that allows weighting of
individual contributions by the number of atoms in the system.
"""
label, prediction = label[name], prediction[name]
if "delta" not in parameters.keys():
raise KeyError("Huber loss function requires 'delta' parameter")
delta = parameters["delta"]
diff = jnp.abs(label - prediction)
loss = jnp.where(diff > delta, delta * (diff - 0.5 * delta), 0.5 * diff**2)
return loss
[docs]
def crps_loss(
label: jax.Array,
prediction: jax.Array,
name,
parameters: dict = {},
) -> jax.Array:
"""Computes the CRPS of a gaussian distribution given
means, targets and standard deviations (uncertainty estimate)
"""
label = label[name]
means = prediction[name]
sigmas = prediction[name + "_uncertainty"]
sigmas = jnp.clip(sigmas, min=1e-6)
norm_x = (label - means) / sigmas
cdf = 0.5 * (1 + jsc.special.erf(norm_x / jnp.sqrt(2)))
normalization = 1 / (jnp.sqrt(2.0 * np.pi))
pdf = normalization * jnp.exp(-(norm_x**2) / 2.0)
crps = sigmas * (norm_x * (2 * cdf - 1) + 2 * pdf - 1 / jnp.sqrt(np.pi))
return crps
[docs]
def nll_loss(
label: jax.Array,
prediction: jax.Array,
name,
parameters: dict = {},
) -> jax.Array:
"""Computes the gaussian NLL loss given
means, targets and standard deviations (uncertainty estimate)
"""
label = label[name]
means = prediction[name]
sigmas = prediction[name + "_uncertainty"]
eps = 1e-6
sigmas = jnp.clip(sigmas, min=eps)
variances = jnp.pow(sigmas, 2)
x1 = jnp.log(variances)
x2 = ((means - label) ** 2) / variances
nll = 0.5 * (x1 + x2)
return nll
[docs]
def force_angle_loss(
label: jnp.array,
prediction: jnp.array,
name,
parameters: dict = {},
) -> jnp.array:
"""
Consine similarity loss function. Contributions are summed in `Loss`.
"""
label, prediction = label[name], prediction[name]
dotp = normed_dotp(label, prediction)
return 1.0 - dotp
[docs]
def force_angle_div_force_label(
label: jnp.array,
prediction: jnp.array,
name,
parameters: dict = {},
):
"""
Consine similarity loss function weighted by the norm of the force labels.
Contributions are summed in `Loss`.
"""
label, prediction = label[name], prediction[name]
dotp = normed_dotp(label, prediction)
F_0_norm = jnp.linalg.norm(label, ord=2, axis=2, keepdims=False)
loss = jnp.where(F_0_norm > 1e-6, (1.0 - dotp) / F_0_norm, jnp.zeros_like(dotp))
return loss
[docs]
def force_angle_exponential_weight(
label: jnp.array,
prediction: jnp.array,
name,
parameters: dict = {},
) -> jnp.array:
"""
Consine similarity loss function exponentially scaled by the norm of the force labels.
Contributions are summed in `Loss`.
"""
label, prediction = label[name], prediction[name]
dotp = normed_dotp(label, prediction)
F_0_norm = jnp.linalg.norm(label, ord=2, axis=2, keepdims=False)
return (1.0 - dotp) * jnp.exp(-F_0_norm)
def stress_tril(label, prediction, name, parameters: dict = {}):
label, prediction = label[name], prediction[name]
idxs = jnp.tril_indices(3)
label_tril = label[:, idxs[0], idxs[1]]
prediction_tril = prediction[:, idxs[0], idxs[1]]
return (label_tril - prediction_tril) ** 2
def nll_3x3(label, prediction, name, parameters: dict = {}):
label = label[name]
means = prediction[name]
ensemble = prediction[name + "_ensemble"]
diff = label - means
deviations = ensemble - means[..., None]
# K = deviations.shape[2] # Number of members
K = deviations.shape[-1] # Number of ensemble members
if K < 2:
raise ValueError("nll_3x3 requires at least 2 ensemble members")
Sigma = jnp.einsum("bijk,bilk->bijl", deviations, deviations) / (
K - 1
) # Sample covariance matrix
Sigma = Sigma + jnp.eye(3)[None, None, ...] * 1e-5
Sigma_inv, det = inv_and_det_3x3(Sigma)
det = jnp.maximum(det, 1e-12)
log_det = jnp.log(det)
# diff.T @ Sigma_inv @ diff
# Einsum: (N, 3) * (N, 3, 3) * (N, 3) -> (N,)
# z = Sigma_inv @ diff
z = jnp.einsum("...ij, ...j -> ...i", Sigma_inv, diff)
mahalanobis = jnp.sum(diff * z, axis=-1)
nll = 0.5 * (mahalanobis + log_det)
return nll
loss_functions = {
"mse": weighted_squared_error,
"huber": weighted_huber_loss,
"cosine_sim": force_angle_loss,
"cosine_sim_div_magnitude": force_angle_div_force_label,
"cosine_sim_exp_magnitude": force_angle_exponential_weight,
"tril": stress_tril,
"crps": crps_loss,
"nll": nll_loss,
"nll_3x3": nll_3x3,
}
[docs]
@dataclasses.dataclass
class Loss:
"""
Represents a single weighted loss function that is constructed from a `name`
and a type of comparison metric.
"""
name: str
loss_type: str
weight: float = 1.0
atoms_exponent: float = 1.0
parameters: dict = dataclasses.field(default_factory=lambda: {})
def __post_init__(self):
if self.loss_type not in loss_functions.keys():
raise NotImplementedError(
f"the loss function '{self.loss_type}' is not known."
)
self.loss_fn = loss_functions[self.loss_type]
def __call__(self, inputs: dict, prediction: dict, label: dict) -> float:
# TODO we may want to insert an additional `mask` argument for this method
divisor = inputs["n_atoms"] ** self.atoms_exponent
batch_losses = self.loss_fn(label, prediction, self.name, self.parameters)
axes_to_add = len(batch_losses.shape) - 1
for _ in range(axes_to_add):
divisor = divisor[..., None]
arg = batch_losses / divisor
loss = self.weight * jnp.sum(jnp.mean(arg, axis=0))
return loss
[docs]
@dataclasses.dataclass
class LossCollection:
loss_list: List[Loss]
def __call__(self, inputs: dict, predictions: dict, labels: dict) -> float:
total_loss = 0.0
for single_loss_fn in self.loss_list:
loss = single_loss_fn(inputs, predictions, labels)
total_loss = total_loss + loss
return total_loss