Source code for apax.layers.ntk_linear
from typing import Any
import flax.linen as nn
import jax.numpy as jnp
[docs]
class NTKLinear(nn.Module):
"""Linear layer with activation.
Corresponds to an NTK layer with "normal" and "zeros" for w and b initialization
and "use_ntk" set to True.
"""
units: int
w_init: str = "normal"
b_init: str = "zeros"
use_ntk: bool = True
dtype: Any = jnp.float32
@nn.compact
def __call__(self, inputs):
inputs = inputs.astype(self.dtype)
if self.w_init == "normal":
w_initializer = nn.initializers.normal(1.0, dtype=self.dtype)
elif self.w_init == "lecun":
w_initializer = nn.initializers.lecun_normal(dtype=self.dtype)
else:
raise ValueError(f"Unknown weight initializer: {self.w_init}.")
if self.b_init == "normal":
b_initializer = nn.initializers.normal(1.0, dtype=self.dtype)
elif self.b_init == "zeros":
b_initializer = nn.initializers.constant(0.0, dtype=self.dtype)
else:
raise ValueError(f"Unknown bias initializer: {self.b_init}.")
w = self.param("w", w_initializer, (inputs.shape[0], self.units), self.dtype)
b = self.param("b", b_initializer, [self.units], self.dtype)
wx = jnp.dot(inputs, w)
if self.use_ntk:
bias_factor = 0.1
weight_factor = jnp.sqrt(1.0 / inputs.shape[0])
prediction = weight_factor * wx + bias_factor * b
else:
prediction = wx + b
assert prediction.dtype == self.dtype
return prediction