# Source code for apax.utils.math

from typing import Iterable, Optional, Union

import jax.numpy as jnp
from jax import Array


def fp64_sum(
    X: Array, axis: Optional[Union[Iterable[int], int]] = None, keepdims: bool = False
):
    """Sum ``X`` with ``jnp.float64`` requested as the reduction dtype.

    Args:
        X: array to reduce.
        axis: axis or axes passed straight to ``jnp.sum``; ``None`` reduces
            over every axis.
        keepdims: if True, reduced axes are retained with length 1.

    Returns:
        The result of ``jnp.sum(X, axis=axis, dtype=jnp.float64, keepdims=keepdims)``.

    NOTE(review): JAX only honours float64 when x64 mode is enabled
    (``jax_enable_x64``); otherwise the request falls back to float32.
    Confirm the project enables it wherever this extra precision matters.
    """
    return jnp.sum(X, axis=axis, keepdims=keepdims, dtype=jnp.float64)


def _safe_unit_vectors(v, eps):
    """Normalize vectors along axis 2, mapping vectors with norm <= ``eps`` to zero.

    Uses the "double where" pattern so zero-length vectors produce finite
    gradients instead of NaN: neither ``sqrt`` nor the division ever sees a
    zero argument, even in the branch that ``jnp.where`` discards.
    """
    sq_norm = jnp.sum(v * v, axis=2, keepdims=True)
    nonzero = sq_norm > 0
    # sqrt has an infinite derivative at 0; feed it a dummy 1.0 there.
    norm = jnp.where(nonzero, jnp.sqrt(jnp.where(nonzero, sq_norm, 1.0)), 0.0)
    keep = norm > eps
    # Divide by a masked denominator so the discarded branch is never 0/0,
    # whose NaN would otherwise leak into the gradient.
    safe_norm = jnp.where(keep, norm, 1.0)
    return jnp.where(keep, v / safe_norm, jnp.zeros_like(v))


def normed_dotp(F_0, F_pred, eps: float = 1e-6):
    """Cosine similarity between corresponding vectors of ``F_0`` and ``F_pred``.

    Each vector along axis 2 is scaled to unit L2 length before the dot
    product. Vectors whose norm is not greater than ``eps`` are replaced by
    zero, so any pair involving such a vector yields 0. The computation is
    gradient-safe for all-zero vectors.

    Args:
        F_0: rank-3 reference array with vectors along axis 2
            (presumably (batch, atoms, xyz) judging by the einsum labels).
        F_pred: array with the same shape as ``F_0``.
        eps: norm threshold below which a vector is treated as zero.

    Returns:
        Array of shape ``F_0.shape[:2]`` holding the per-vector dot product of
        the normalized inputs (values in [-1, 1]).
    """
    F_0_n = _safe_unit_vectors(F_0, eps)
    F_p_n = _safe_unit_vectors(F_pred, eps)
    return jnp.einsum("bai, bai -> ba", F_0_n, F_p_n)


def center_of_mass(positions: Array, masses: Array) -> Array:
    """Calculate the center of mass from arrays of positions and masses.

    Leading batch dimensions are supported: every ``(N, 3)`` / ``(N,)`` pair
    in the batch is reduced independently. If the masses of a system sum to
    zero the result for that system is non-finite.

    Args:
        positions (Array): array of coordinates with shape (..., N, 3)
        masses (Array): array of point masses with shape (..., N)

    Returns:
        Array: center of mass coordinates with shape (..., 3)
    """
    weighted_sum = jnp.sum(masses[..., None] * positions, axis=-2)
    # keepdims keeps a trailing length-1 axis so it broadcasts over xyz.
    total_mass = jnp.sum(masses, axis=-1, keepdims=True)
    return weighted_sum / total_mass
def inv_and_det_3x3(Sigma: Array, eps: float = 1e-6) -> tuple[Array, Array]:
    """Analytically invert a (batch of) 3x3 matrices and return determinants.

    Uses the closed-form adjugate (cofactor) formula, which holds for general
    (not only symmetric) 3x3 matrices. Before dividing, ``|det|`` is clamped to
    at least ``eps`` (keeping its sign, with ``det == 0`` treated as
    positive), so singular or near-singular inputs yield large but finite
    inverses instead of inf/NaN.

    Args:
        Sigma: array whose trailing two axes form the 3x3 matrices, i.e.
            shape (..., 3, 3).
        eps: lower bound on ``|det|`` used when forming the inverse.

    Returns:
        A tuple ``(Sigma_inv, det)``. ``Sigma_inv`` has shape (..., 3, 3) and
        ``det`` has shape (...). ``det`` is the raw, unclamped determinant.
    """
    a00, a01, a02 = Sigma[..., 0, 0], Sigma[..., 0, 1], Sigma[..., 0, 2]
    a10, a11, a12 = Sigma[..., 1, 0], Sigma[..., 1, 1], Sigma[..., 1, 2]
    a20, a21, a22 = Sigma[..., 2, 0], Sigma[..., 2, 1], Sigma[..., 2, 2]

    # Cofactors of the first row: they give the determinant (Laplace
    # expansion along row 0) and, transposed, the first column of adj(Sigma).
    c00 = a11 * a22 - a12 * a21
    c01 = a12 * a20 - a10 * a22
    c02 = a10 * a21 - a11 * a20
    det = a00 * c00 + a01 * c01 + a02 * c02

    # jnp.sign(0) == 0 would zero out the clamp and divide by zero for exactly
    # singular matrices, so pick the sign explicitly.
    det_sign = jnp.where(det < 0, -1.0, 1.0)
    det_safe = jnp.maximum(jnp.abs(det), eps) * det_sign
    inv_det = 1.0 / det_safe

    # Adjugate = transpose of the cofactor matrix.
    adjugate = [
        [c00, a02 * a21 - a01 * a22, a01 * a12 - a02 * a11],
        [c01, a00 * a22 - a02 * a20, a10 * a02 - a00 * a12],
        [c02, a20 * a01 - a00 * a21, a00 * a11 - a01 * a10],
    ]
    # JAX arrays are immutable, so assemble the matrix by stacking entries.
    rows = [jnp.stack(row, axis=-1) for row in adjugate]
    Sigma_inv = jnp.stack(rows, axis=-2) * inv_det[..., None, None]
    return Sigma_inv, det