# Source code for apax.layers.initializers

from typing import Any

import jax.numpy as jnp
from jax import random
from jax._src import dtypes
from jax.nn.initializers import Initializer

Array = Any
KeyArray = Array

DTypeLikeFloat = Any
DTypeLikeComplex = Any
DTypeLikeInexact = Any
RealNumeric = Any


def uniform_range(
    minval, maxval, dtype: DTypeLikeInexact = jnp.float_
) -> Initializer:
    """Build an initializer that yields real, uniformly distributed random arrays.

    Samples are drawn with ``jax.random.uniform`` over the half-open range
    ``[minval, maxval)``.

    Args:
        minval: Lower bound of the sampling range.
        maxval: Upper bound of the sampling range.
        dtype: Default dtype of the arrays produced by the returned
            initializer; it may be overridden on each call.

    Returns:
        An initializer ``(key, shape, dtype=dtype) -> Array`` that fills an
        array of the requested shape with uniform samples.
    """

    def _sample(key: KeyArray, shape, dtype: DTypeLikeInexact = dtype) -> Array:
        # Resolve the requested dtype against JAX's precision configuration
        # (e.g. 64-bit types are narrowed when x64 mode is disabled).
        target_dtype = dtypes.canonicalize_dtype(dtype)
        return random.uniform(
            key,
            shape,
            dtype=target_dtype,
            minval=minval,
            maxval=maxval,
        )

    return _sample