from typing import Callable, Literal, Tuple, Union
import jax
import jax.ad_checkpoint
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.traverse_util import flatten_dict, unflatten_dict
from pydantic import BaseModel, TypeAdapter
from apax.model.gmnn import EnergyModel
# Type of a feature map: a callable taking a ``FrozenDict`` and a ``dict``
# and returning a ``jax.Array``.
# NOTE(review): the feature maps built below are called as ``(params, inputs)``;
# presumably those are the two arguments meant here -- confirm.
FeatureMap = Callable[[FrozenDict, dict], jax.Array]
[docs]
class LastLayerGradientFeatures(FeatureTransformation, extra="forbid"):
    """
    Model transformation which computes the gradient of the (mean) model
    output wrt. the parameters of the specified layer.

    The gradient is averaged over its last axis, flattened and concatenated
    into a single 1D feature vector.
    https://arxiv.org/pdf/2203.09410
    """

    # Tag identifying this transformation.
    name: Literal["ll_grad"] = "ll_grad"
    # Name of the layer whose parameters the output is differentiated against.
    layer_name: str = "dense_2"

    def apply(self, model: EnergyModel) -> FeatureMap:
        """Build the gradient feature map for ``model``.

        Returns a function ``ll_grad(params, inputs)`` where ``inputs`` is a
        dict providing the keys ``"positions"``, ``"numbers"``, ``"idx"``,
        ``"box"`` and ``"offsets"``. It returns a 1D array.
        """

        def ll_grad(params, inputs):
            ll_params, remaining_params = extract_feature_params(
                params, self.layer_name
            )

            # TODO find better abstraction for inputs
            # The inputs do not depend on the differentiated parameters, so
            # they are unpacked once outside of the differentiated function.
            R, Z, idx, box, offsets = (
                inputs["positions"],
                inputs["numbers"],
                inputs["idx"],
                inputs["box"],
                inputs["offsets"],
            )

            def inner(layer_params):
                # Merge into a fresh dict instead of mutating the argument of
                # a function that is being transformed by ``jax.grad``.
                full_params = unflatten_dict({**layer_params, **remaining_params})
                out = model.apply(full_params, R, Z, idx, box, offsets)
                # take mean in case of shallow ensemble
                # no effect for single model
                return jnp.mean(out)

            g_ll = unflatten_dict(jax.grad(inner)(ll_params))

            # ``jax.tree_map`` is deprecated/removed in recent JAX releases;
            # ``jax.tree_util.tree_map`` is the stable equivalent.
            # Average over the last axis of every gradient leaf.
            # NOTE(review): presumably the last axis indexes ensemble
            # members -- confirm.
            g_ll = jax.tree_util.tree_map(
                lambda arr: jnp.mean(arr, axis=-1, keepdims=True), g_ll
            )
            g_flat = jax.tree_util.tree_map(lambda arr: jnp.reshape(arr, (-1,)), g_ll)

            # Exactly two leaves are expected. Dict leaves are flattened in
            # sorted key order.
            # NOTE(review): presumably the leaves are "bias" and "kernel", so
            # the bias comes first -- confirm against the layer definition.
            (gb, gw), _ = jax.tree_util.tree_flatten(g_flat)

            return jnp.concatenate([gw, gb])

        return ll_grad
[docs]
class LastLayerForceFeatures(FeatureTransformation, extra="forbid"):
    """
    Model transformation which computes the Jacobian of the position
    gradient of the (mean) model output wrt. the parameters of the
    specified layer.

    The layer parameters are averaged over their last axis before the
    Jacobian is evaluated.
    """

    # Tag identifying this transformation.
    name: Literal["ll_force_feat"] = "ll_force_feat"
    # Name of the layer whose parameters the Jacobian is taken against.
    layer_name: str = "dense_2"
    # True: return the first slice along the trailing axis of the second
    # (weight) leaf of the Jacobian.
    # False: sum every leaf over its first two axes, flatten and concatenate.
    return_raw: bool = True

    def apply(self, model: EnergyModel) -> FeatureMap:
        """Build the force feature map for ``model``.

        Returns a function ``ll_grad(params, inputs)`` where ``inputs`` is a
        dict providing the keys ``"positions"``, ``"numbers"``, ``"idx"``,
        ``"box"`` and ``"offsets"``.
        """

        def ll_grad(params, inputs):
            ll_params, remaining_params = extract_feature_params(
                params, self.layer_name
            )

            # A named function with ``*args`` avoids shadowing the enclosing
            # ``inputs`` parameter, which the original lambda did.
            def energy_fn(*args):
                return jnp.mean(model.apply(*args))

            # Differentiate wrt. argument 1, i.e. the positions ``R`` below.
            force_fn = jax.grad(energy_fn, 1)

            R, Z, idx, box, offsets = (
                inputs["positions"],
                inputs["numbers"],
                inputs["idx"],
                inputs["box"],
                inputs["offsets"],
            )

            def inner(layer_params):
                # Merge into a fresh dict instead of mutating the argument of
                # a function that is being transformed by ``jax.jacfwd``.
                full_params = unflatten_dict({**layer_params, **remaining_params})
                return force_fn(full_params, R, Z, idx, box, offsets)

            # ``jax.tree_map`` is deprecated/removed in recent JAX releases;
            # ``jax.tree_util.tree_map`` is the stable equivalent.
            ll_params = jax.tree_util.tree_map(
                lambda arr: jnp.mean(arr, axis=-1, keepdims=True), ll_params
            )

            g_ll = unflatten_dict(jax.jacfwd(inner)(ll_params))

            # shapes:
            # b: n_atoms, 3, 1
            # w: n_atoms, 3, n_features, 1
            if self.return_raw:
                # Only the second (weight) leaf is used in raw mode.
                (_, gw), _ = jax.tree_util.tree_flatten(g_ll)
                # g: n_atoms, 3, n_features
                g = gw[:, :, :, 0]
            else:
                # Sum over the first two axes of each leaf, then flatten.
                g_flat = jax.tree_util.tree_map(
                    lambda arr: jnp.reshape(jnp.sum(jnp.sum(arr, 0), 0), (-1,)),
                    g_ll,
                )
                (gb, gw), _ = jax.tree_util.tree_flatten(g_flat)
                g = jnp.concatenate([gw, gb])

            return g

        return ll_grad
[docs]
class IdentityFeatures(FeatureTransformation, extra="forbid"):
    """Pass-through feature map, intended for debugging.

    The model's own ``apply`` callable is handed back unchanged.
    """

    # Tag identifying this transformation; it has no default value.
    name: Literal["identity"]

    def apply(self, model: EnergyModel) -> FeatureMap:
        """Return ``model.apply`` itself as the feature map."""
        identity_map = model.apply
        return identity_map
# Validator for feature-map configurations: the bound ``validate_python``
# method of a pydantic ``TypeAdapter`` over the union of the three
# transformation classes defined above.
FeatureMapOptions = TypeAdapter(
Union[LastLayerGradientFeatures, LastLayerForceFeatures, IdentityFeatures]
).validate_python