Source code for apax.data.statistics

import dataclasses
import logging

import numpy as np

log = logging.getLogger(__name__)


[docs] @dataclasses.dataclass class DatasetStats: elemental_shift: np.array = None elemental_scale: float = None n_species: int = 119
class PerElementRegressionShift: name = "per_element_regression_shift" parameters = ["energy_regularisation"] dtypes = [float] @staticmethod def compute(inputs, labels, shift_options) -> np.ndarray: log.info("Computing per element energy regression.") lambd = shift_options["energy_regularisation"] energies = labels["energy"] numbers = inputs["numbers"] system_sizes = inputs["n_atoms"] energies = np.array(energies) system_sizes = np.array(system_sizes) ds_energy = np.sum(energies) n_atoms_total = np.sum(system_sizes) mean_energy = ds_energy / n_atoms_total n_species = 119 # for simplicity, we assume any element could be in the dataset X = np.zeros(shape=(energies.shape[0], n_species)) y = np.zeros(energies.shape[0]) for i in range(energies.shape[0]): Z_counts = np.zeros(n_species) for z in set(numbers[i]): Z_counts[z] = np.count_nonzero(numbers[i] == z) X[i] = Z_counts E_sub_mean = energies[i] - mean_energy * numbers[i].shape[0] y[i] = E_sub_mean XTX = X.T @ X reg_term = lambd * np.eye(XTX.shape[0]) result = np.linalg.lstsq(XTX + reg_term, X.T @ y, rcond=-1) elemental_energies_shift = result[0] elemental_energies_shift += mean_energy return elemental_energies_shift class IsolatedAtomEnergyShift: name = "isolated_atom_energy_shift" parameters = ["E0s"] dtypes = [dict[int, float]] @staticmethod def compute(inputs, labels, shift_options): shift_options = shift_options["E0s"] n_species = 119 elemental_energies_shift = np.zeros(n_species) for k, v in shift_options.items(): elemental_energies_shift[k] = v return elemental_energies_shift class MeanEnergyRMSScale: name = "mean_energy_rms_scale" parameters = [] dtypes = [] @staticmethod def compute(inputs, labels, scale_options): # log.info("Computing per element energy regression.") energies = labels["energy"] numbers = inputs["numbers"] system_sizes = inputs["n_atoms"] energies = np.array(energies) system_sizes = np.array(system_sizes) ds_energy = np.sum(energies) n_atoms_total = np.sum(system_sizes) mean_energy = ds_energy / n_atoms_total mean_err_sse = 0.0 for i in range(energies.shape[0]): E_sub_mean = energies[i] - mean_energy * numbers[i].shape[0] mean_err_sse += E_sub_mean**2 / numbers[i].shape[0] energy_std = np.sqrt(mean_err_sse / n_atoms_total) return energy_std class PerElementForceRMSScale: name = "per_element_force_rms_scale" parameters = [] dtypes = [] @staticmethod def compute(inputs, labels, scale_options): n_species = 119 forces = np.concatenate(labels["forces"], axis=0) numbers = np.concatenate(inputs["numbers"], axis=0) elements = np.unique(numbers) element_scale = np.ones(n_species) for element in elements: element_forces = forces[numbers == element] force_rms = np.sqrt(np.mean(np.linalg.norm(element_forces, 2, axis=1))) element_scale[element] = force_rms return element_scale class GlobalCustomScale: name = "global_custom_scale" parameters = ["factor"] dtypes = [float] @staticmethod def compute(inputs, labels, scale_options): element_scale = scale_options["factor"] return element_scale class PerElementCustomScale: name = "per_element_custom_scale" parameters = ["factors"] dtypes = [dict[int, float]] @staticmethod def compute(inputs, labels, scale_options): n_species = 119 element_scale = np.ones(n_species) for k, v in scale_options["factors"].items(): element_scale[k] = v return element_scale shift_method_list = [PerElementRegressionShift, IsolatedAtomEnergyShift] scale_method_list = [ MeanEnergyRMSScale, PerElementForceRMSScale, GlobalCustomScale, PerElementCustomScale, ] def compute_scale_shift_parameters( inputs, labels, shift_method, scale_method, shift_options, scale_options ): shift_methods = {method.name: method for method in shift_method_list} scale_methods = {method.name: method for method in scale_method_list} if shift_method not in shift_methods.keys(): raise KeyError( f"The shift method '{shift_method}' is not among the implemented methods." f" Choose from {shift_method.keys()}" ) if scale_method not in scale_methods.keys(): raise KeyError( f"The scale method '{scale_method}' is not among the implemented methods." f" Choose from {scale_method.keys()}" ) shift_method = shift_methods[shift_method] scale_method = scale_methods[scale_method] shift_parameters = shift_method.compute(inputs, labels, shift_options) scale_parameters = scale_method.compute(inputs, labels, scale_options) ds_stats = DatasetStats(shift_parameters, scale_parameters) return ds_stats