Source code for apax.layers.ntk_linear
import flax.linen as nn
import jax.numpy as jnp
from apax.utils.convert import str_to_dtype
[docs]
class NTKLinear(nn.Module):
"""Linear layer with activation.
Corresponds to an NTK layer with "normal" and "zeros" for w and b initialization
and "use_ntk" set to True.
"""
units: int
w_init: str = "normal"
b_init: str = "zeros"
use_ntk: bool = True
dtype: str = "fp32"
@nn.compact
def __call__(self, inputs):
dtype = str_to_dtype(self.dtype)
inputs = inputs.astype(dtype)
if self.w_init == "normal":
w_initializer = nn.initializers.normal(1.0, dtype=dtype)
elif self.w_init == "lecun":
w_initializer = nn.initializers.lecun_normal(dtype=dtype)
else:
raise ValueError(f"Unknown weight initializer: {self.w_init}.")
if self.b_init == "normal":
b_initializer = nn.initializers.normal(1.0, dtype=dtype)
elif self.b_init == "zeros":
b_initializer = nn.initializers.constant(0.0, dtype=dtype)
else:
raise ValueError(f"Unknown bias initializer: {self.b_init}.")
w = self.param("w", w_initializer, (inputs.shape[0], self.units), dtype)
b = self.param("b", b_initializer, [self.units], dtype)
wx = jnp.dot(inputs, w)
if self.use_ntk:
bias_factor = 0.1
weight_factor = jnp.sqrt(1.0 / inputs.shape[0])
prediction = weight_factor * wx + bias_factor * b
else:
prediction = wx + b
assert prediction.dtype == dtype
return prediction