Source code for apax.bal.transforms
import jax
import jax.numpy as jnp
from apax.bal.feature_maps import FeatureMap
[docs]
def ensemble_features(feature_fn: FeatureMap) -> FeatureMap:
    """
    Feature map transformation which combines the features of a model ensemble.

    The wrapped ``feature_fn`` is evaluated once per ensemble member. ``params``
    is mapped over its leading axis (one slice per model), and the input ``x``
    is shared by all members. The per-model feature vectors ``g_i`` are summed
    and scaled by ``1 / sqrt(n_models)``, so the linear kernel of the result is

        K(x, x') = (1 / n_models) * sum_i sum_j g_i(x)^T g_j(x')

    Args:
        feature_fn: Feature map called as ``feature_fn(params, x)``. For a
            single model it must return a 1-D feature vector.

    Returns:
        A feature map ``(params, x) -> g_ens``, where ``g_ens`` has shape
        ``(n_features,)``.

    Raises:
        ValueError: Raised by the returned function if the stacked per-model
            features are not 2-D (models, features). This happens when
            ``feature_fn`` does not return a 1-D vector.
    """
    # params: mapped over the ensemble axis 0; x (None): broadcast to every member.
    per_member_fn = jax.vmap(feature_fn, in_axes=(0, None), out_axes=0)

    def averaged_feature_fn(params, x):
        stacked = per_member_fn(params, x)  # expected shape: (models, features)
        if stacked.ndim != 2:
            raise ValueError(
                "Dimension mismatch for input features. Expected shape"
                f" (models, features), got {stacked.shape}"
            )
        # Scale by sqrt(1/n) rather than 1/n: the kernel is quadratic in the
        # features (K = g^T g), so this puts an overall 1/n factor on K.
        scale = jnp.sqrt(1 / stacked.shape[0])
        return scale * jnp.sum(stacked, axis=0)

    return averaged_feature_fn
[docs]
def batch_features(feature_fn: FeatureMap) -> FeatureMap:
    """
    Vectorize a feature map over a batch of structures.

    The returned function maps ``feature_fn`` over the leading axis of ``x``.
    It passes the same ``params`` to every element and stacks the outputs
    along a new leading axis.
    Should be the last transformation applied to a feature map.

    Args:
        feature_fn: Feature map called as ``feature_fn(params, x)``.

    Returns:
        The batched feature map ``(params, x_batch) -> stacked_features``.
    """
    # params is broadcast (None); x is mapped over axis 0; outputs stack on axis 0.
    return jax.vmap(feature_fn, in_axes=(None, 0), out_axes=0)