from typing import Callable, Literal, Tuple, Union
import jax
import jax.numpy as jnp
import numpy as np
from flax.core.frozen_dict import FrozenDict
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import tree_util
from pydantic import BaseModel, TypeAdapter
from apax.nn.models import EnergyModel
# Type of the callables produced by the feature transformations in this file:
# a function taking ``(FrozenDict, dict)`` and returning a ``jax.Array``.
# The feature functions defined below are called as ``fn(params, inputs)``.
FeatureMap = Callable[[FrozenDict, dict], jax.Array]
[docs]
class LastLayerGradientFeatures(FeatureTransformation, extra="forbid"):
    """
    Model transformation which computes the gradient of the (mean) model
    output with respect to the parameters of the specified layer.

    https://arxiv.org/pdf/2203.09410

    Parameters
    ----------
    layer_name: str
        Name of the layer wrt. which to take the gradient.
    """

    name: Literal["ll_grad"] = "ll_grad"
    layer_name: str = "dense_2"

    def apply(self, model: EnergyModel) -> FeatureMap:
        """Return a function ``(params, inputs) -> 1D feature array``.

        ``inputs`` must provide the keys ``positions``, ``numbers``, ``idx``,
        ``box`` and ``offsets``; they are passed to ``model`` in that order
        after the parameters.
        """
        # TODO find better abstraction for inputs
        input_keys = ("positions", "numbers", "idx", "box", "offsets")

        def ll_grad(params, inputs):
            layer_params, other_params = extract_feature_params(
                params, self.layer_name
            )

            def mean_output(trainable):
                # Merge the differentiated layer parameters with the untouched
                # ones and restore the nested parameter structure.
                full_params = unflatten_dict({**trainable, **other_params})
                model_args = [inputs[key] for key in input_keys]
                prediction = model(full_params, *model_args)
                # take mean in case of shallow ensemble
                # no effect for single model
                return jnp.mean(prediction)

            layer_grads = unflatten_dict(jax.grad(mean_output)(layer_params))

            def average_and_flatten(arr):
                # Average over the last axis (kept as size 1), then flatten.
                averaged = jnp.mean(arr, axis=-1, keepdims=True)
                return jnp.reshape(averaged, (-1,))

            flat_grads = tree_util.tree_map(average_and_flatten, layer_grads)
            # Exactly two leaves are expected. NOTE(review): the unpacking
            # assumes the bias leaf precedes the weight leaf in the flattened
            # tree -- confirm against the layer's parameter names.
            (grad_b, grad_w), _ = tree_util.tree_flatten(flat_grads)
            # Features are ordered weights first, then bias.
            return jnp.concatenate([grad_w, grad_b])

        return ll_grad
[docs]
class LastLayerForceFeatures(FeatureTransformation, extra="forbid"):
    """
    Model transformation which computes the jacobian of the forces
    wrt. the specified layer.
    For BAL the strategy "flatten" has to be selected.

    Parameters
    ----------
    layer_name: str
        Name of the layer wrt. which to take the jacobian.
    strategy: str
        One of raw, sum, flatten. Only flatten seems to work
        for BAL. raw is required for LLPR.
    """

    name: Literal["ll_force_feat"] = "ll_force_feat"
    layer_name: str = "dense_2"
    strategy: str = "raw"

    def apply(self, model: EnergyModel) -> FeatureMap:
        """Return a function ``(params, inputs) -> feature array``.

        ``inputs`` must provide the keys ``positions``, ``numbers``, ``idx``,
        ``box`` and ``offsets``. The layout of the result depends on
        ``strategy``; an unrecognised strategy raises ``ValueError`` when the
        returned function is called.
        """
        input_keys = ("positions", "numbers", "idx", "box", "offsets")

        def ll_grad(params, inputs):
            layer_params, other_params = extract_feature_params(
                params, self.layer_name
            )

            def mean_energy(*model_args):
                return jnp.mean(model(*model_args))

            # Gradient wrt. positional argument 1 of the model call, i.e. the
            # argument directly after the parameters (``positions`` below).
            position_grad = jax.grad(mean_energy, 1)

            def forces_of(trainable):
                # Merge the differentiated layer parameters with the untouched
                # ones and restore the nested parameter structure.
                full_params = unflatten_dict({**trainable, **other_params})
                model_args = [inputs[key] for key in input_keys]
                return position_grad(full_params, *model_args)

            jacobian = unflatten_dict(jax.jacfwd(forces_of)(layer_params))

            # shapes (per the original author's note):
            # b: n_atoms, 3, 1
            # w: n_atoms, 3, n_features, 1
            if self.strategy == "raw":
                (_, jac_w), _ = tree_util.tree_flatten(jacobian)
                # n_atoms, 3, n_features
                return jac_w[:, :, :, 0]

            if self.strategy == "sum":

                def sum_leading_axes(arr):
                    # Sum over the first two axes one after the other, then
                    # flatten.
                    reduced = jnp.sum(jnp.sum(arr, 0), 0)
                    return jnp.reshape(reduced, (-1,))

                summed = tree_util.tree_map(sum_leading_axes, jacobian)
                (jac_b, jac_w), _ = tree_util.tree_flatten(summed)
                # Weights first, then bias.
                return jnp.concatenate([jac_w, jac_b])

            if self.strategy == "flatten":
                flattened = tree_util.tree_map(
                    lambda arr: jnp.reshape(arr, (-1,)), jacobian
                )
                (_, jac_w), _ = tree_util.tree_flatten(flattened)
                return jac_w

            raise ValueError(f"unknown strategy: {self.strategy}")

        return ll_grad
[docs]
class FullGradientRPFeatures(FeatureTransformation, extra="forbid"):
    """
    Model transformation which computes the gradient of the output
    wrt. all parameters and applies a gaussian random projection for
    dimensionality reduction.

    https://arxiv.org/pdf/2203.09410

    Parameters
    ----------
    num_rp: int
        Dimensionality to reduce the features to.
    """

    name: Literal["full_grad_rp"] = "full_grad_rp"
    num_rp: int = 512

    def apply(self, model: EnergyModel) -> FeatureMap:
        """Return a function ``(params, inputs) -> array of length num_rp``.

        ``inputs`` must provide the keys ``positions``, ``numbers``, ``idx``,
        ``box`` and ``offsets``.

        The projection matrix is drawn from NumPy's global random state the
        first time a gradient of a given length is seen and is then reused
        for every later call or re-trace of the returned function. Without
        this, each call or re-trace would project onto a different random
        basis and the resulting feature vectors would not be comparable.
        """
        # Projection matrices keyed by flattened gradient length. Plain NumPy
        # arrays are stored so the cache never holds traced values.
        projections = {}

        def full_grad(params, inputs):
            def inner(params):
                # TODO find better abstraction for inputs
                R, Z, idx, box, offsets = (
                    inputs["positions"],
                    inputs["numbers"],
                    inputs["idx"],
                    inputs["box"],
                    inputs["offsets"],
                )
                out = model(params, R, Z, idx, box, offsets)
                # take mean in case of shallow ensemble
                # no effect for single model
                return jnp.mean(out)

            grads = jax.grad(inner)(params)
            # Average every parameter gradient over its last axis (kept as
            # size 1) before flattening.
            grads = tree_util.tree_map(
                lambda arr: jnp.mean(arr, axis=-1, keepdims=True), grads
            )
            g_flat = tree_util.tree_map(lambda arr: jnp.reshape(arr, (-1,)), grads)
            gs, _ = tree_util.tree_flatten(g_flat)
            g = jnp.concatenate(gs)

            with jax.ensure_compile_time_eval():
                n_features = g.shape[0]
                if n_features not in projections:
                    # Sample once so all features share the same projection.
                    projections[n_features] = np.random.randn(
                        n_features, self.num_rp
                    ) / np.sqrt(self.num_rp)
                RP = jnp.array(projections[n_features])

            return g @ RP

        return full_grad
[docs]
class IdentityFeatures(FeatureTransformation, extra="forbid"):
    """Identity feature map. For debugging purposes."""

    # Default added for consistency with the other transformations in this
    # file, so the class can be constructed without arguments.
    name: Literal["identity"] = "identity"

    def apply(self, model: EnergyModel) -> FeatureMap:
        """Return ``model`` itself, unchanged."""
        return model
# Validator for feature-map configurations: ``FeatureMapOptions(obj)`` calls
# pydantic's ``TypeAdapter.validate_python`` on ``obj`` against the union of
# the transformation models listed below.
FeatureMapOptions = TypeAdapter(
Union[
LastLayerGradientFeatures,
LastLayerForceFeatures,
FullGradientRPFeatures,
IdentityFeatures,
]
).validate_python