Source code for apax.md.sim_utils

import dataclasses
from typing import Callable, Optional

import jax.numpy as jnp
import numpy as np
from jax_md import space


@dataclasses.dataclass
class System:
    """Bundle of per-structure arrays describing a simulation system.

    Attributes:
        atomic_numbers: Integer array built from ``atoms.numbers``.
        masses: Floating point array built from ``atoms.get_masses()``.
        positions: Atomic positions. When a non-zero box is present these
            are the input positions mapped through the inverse box,
            otherwise they are the input positions unchanged.
        box: Transposed cell matrix (all zeros when the input has no cell).
        momenta: Whatever ``atoms.get_momenta()`` returned; it is stored
            as-is without conversion.
    """

    atomic_numbers: jnp.array
    masses: jnp.array
    positions: jnp.array
    box: jnp.array
    momenta: Optional[jnp.array]

    @classmethod
    def from_atoms(cls, atoms):
        """Create a ``System`` from an atoms-like object.

        Args:
            atoms: Object exposing ``numbers``, ``positions``, ``cell.array``,
                ``get_masses()`` and ``get_momenta()`` (presumably an
                ``ase.Atoms`` instance -- confirm against callers).

        Returns:
            A new ``System`` populated from ``atoms``.
        """
        atomic_numbers = jnp.asarray(atoms.numbers, dtype=jnp.int32)
        masses = jnp.asarray(atoms.get_masses(), dtype=jnp.float64)
        momenta = atoms.get_momenta()

        # The cell is transposed before use; presumably the input stores
        # lattice vectors as rows while jax_md expects them as columns
        # (TODO confirm).
        box = jnp.asarray(atoms.cell.array, dtype=jnp.float64).T
        positions = jnp.asarray(atoms.positions, dtype=jnp.float64)

        # Compare magnitudes: a cell whose non-zero components are all
        # negative is still a real box and must not be mistaken for
        # "no cell", otherwise positions would be left untransformed
        # next to a non-zero box.
        has_box = np.any(jnp.abs(box) > 1e-6)
        if has_box:
            positions = space.transform(jnp.linalg.inv(box), positions)

        return cls(
            atomic_numbers=atomic_numbers,
            masses=masses,
            positions=positions,
            box=box,
            momenta=momenta,
        )
@dataclasses.dataclass
class SimulationFunctions:
    """Group of the four callables passed around together by a simulation.

    The class only stores the callables; it places no constraints on their
    signatures. The roles noted below are inferred from the field names and
    should be confirmed against the call sites.
    """

    # Presumably evaluates the model energy -- confirm against callers.
    energy_fn: Callable
    # Presumably computes additional (auxiliary) outputs -- confirm.
    auxiliary_fn: Callable
    # Presumably displaces positions (jax_md-style shift) -- confirm.
    shift_fn: Callable
    # Presumably builds/updates the neighbor list -- confirm.
    neighbor_fn: Callable