Source code for apax.layers.readout

from dataclasses import field
from typing import Any, Callable, List

import flax.linen as nn
import jax.numpy as jnp

from apax.layers.activation import swish
from apax.layers.ntk_linear import NTKLinear


class AtomisticReadout(nn.Module):
    """Feed-forward readout built from a stack of ``NTKLinear`` layers.

    The stack contains one ``NTKLinear`` per entry of ``units`` followed by
    a final ``NTKLinear`` constructed with ``1`` as its size argument.
    ``activation_fn`` is applied after every layer except the final one.
    """

    # Size argument of each hidden ``NTKLinear``; a trailing ``1`` is
    # appended in ``setup`` for the output layer.
    units: List[int] = field(default_factory=lambda: [512, 512])
    # Callable placed between consecutive layers (not after the last one).
    activation_fn: Callable = swish
    # Forwarded unchanged to every ``NTKLinear`` as ``b_init``.
    b_init: str = "normal"
    # Forwarded unchanged to every ``NTKLinear`` as ``dtype``.
    dtype: Any = jnp.float32

    def setup(self):
        """Assemble the layers into a single ``nn.Sequential``."""
        # Copy before extending so the ``units`` attribute is never mutated.
        layer_sizes = list(self.units) + [1]
        last_index = len(layer_sizes) - 1

        layers = []
        for idx, n_hidden in enumerate(layer_sizes):
            layers.append(
                NTKLinear(
                    n_hidden,
                    b_init=self.b_init,
                    dtype=self.dtype,
                    name=f"dense_{idx}",
                )
            )
            # No activation after the output layer.
            if idx < last_index:
                # Use the configured activation rather than a hard-coded
                # ``swish`` so the ``activation_fn`` field takes effect.
                layers.append(self.activation_fn)

        self.sequential = nn.Sequential(layers, name="readout")

    def __call__(self, x):
        """Apply the readout stack to ``x`` and return the result."""
        return self.sequential(x)