Source code for apax.layers.scaling
from typing import Any, Union
import einops
import flax.linen as nn
import jax.numpy as jnp
from apax.utils.convert import str_to_dtype
[docs]
class PerElementScaleShift(nn.Module):
    """Apply a learnable per-element affine transform ``scale[Z] * x + shift[Z]``.

    Two trainable tables, ``scale_per_element`` and ``shift_per_element``, are
    created, each with shape ``(n_species, 1)``. For every row of ``x``, the
    table row selected by the corresponding entry of ``Z`` is applied.
    """

    # Table size used when neither ``scale`` nor ``shift`` is a per-element array.
    n_species: int = 119
    # Initial scale: a scalar (broadcast to every species) or a per-element array
    # of shape (n_species,) or (n_species, 1). A non-scalar value fixes the table size.
    scale: Union[jnp.ndarray, float] = 1.0
    # Initial shift: a scalar or a per-element array, same conventions as ``scale``.
    shift: Union[jnp.ndarray, float] = 0.0
    # Parameter/compute dtype; resolved through ``str_to_dtype``, so presumably a
    # string name or a dtype is accepted — confirm against apax.utils.convert.
    dtype: Any = jnp.float32

    def setup(self):
        """Create the ``(n_species, 1)`` scale and shift parameter tables.

        Raises:
            ValueError: if a per-element ``shift`` has a length that is neither
                1 nor the inferred number of species.
        """
        scale = jnp.asarray(self.scale)
        shift = jnp.asarray(self.shift)

        # Infer the table size from whichever initial value is a per-element
        # array. ``scale`` takes precedence. A length-1 ``shift`` is treated as a
        # broadcastable value so it does not shrink the table to a single row.
        if scale.ndim > 0:
            n_species = scale.shape[0]
        elif shift.ndim > 0 and shift.shape[0] != 1:
            n_species = shift.shape[0]
        else:
            n_species = self.n_species

        # Promote 1D per-element values to column vectors so they match the
        # (n_species, 1) parameter shape used by the constant initializers.
        if scale.ndim == 1:
            scale = einops.repeat(scale, "species -> species 1")
        if shift.ndim == 1:
            shift = einops.repeat(shift, "species -> species 1")

        # Fail early with a readable message instead of an opaque broadcast
        # error raised later from inside the initializer.
        if shift.ndim > 0 and shift.shape[0] not in (1, n_species):
            raise ValueError(
                f"shift has {shift.shape[0]} per-element entries, but the "
                f"scale/shift tables have {n_species} species"
            )

        scale_init = nn.initializers.constant(scale)
        shift_init = nn.initializers.constant(shift)
        dtype = str_to_dtype(self.dtype)
        self.scale_param = self.param(
            "scale_per_element", scale_init, (n_species, 1), dtype
        )
        self.shift_param = self.param(
            "shift_per_element", shift_init, (n_species, 1), dtype
        )

    def __call__(self, x, Z):
        """Scale and shift ``x`` row-wise using the tables indexed by ``Z``.

        Args:
            x: Values to transform; expected shape ``(n_atoms, 1)``. Cast to the
                module dtype before use.
            Z: Integer indices into the per-element tables; expected shape
                ``(n_atoms,)``.

        Returns:
            ``scale[Z] * x + shift[Z]`` with shape ``(n_atoms, 1)`` and the
            module dtype.
        """
        dtype = str_to_dtype(self.dtype)
        # scale_param[Z] / shift_param[Z] have shape (n_atoms, 1), matching x.
        x = x.astype(dtype)
        out = self.scale_param[Z] * x + self.shift_param[Z]
        # Internal invariant: parameters and input share the module dtype.
        assert out.dtype == dtype
        return out