# Source code for apax.data.initialization

import logging

import numpy as np

from apax.utils.data import load_data, split_atoms, split_idxs

log = logging.getLogger(__name__)


def load_data_files(data_config):
    """
    Load data files for training and validation.

    Two input layouts are supported and are checked in this order:

    1. ``data_config.data_path`` is not None: a single file is read and split
       into training and validation subsets using ``data_config.n_train`` and
       ``data_config.n_valid``. The chosen indices are saved with ``np.savez``
       to ``data_config.model_version_path / "train_val_idxs"`` (numpy appends
       the ``.npz`` suffix) so the split can be recovered later.
    2. Otherwise, both ``data_config.train_data_path`` and
       ``data_config.val_data_path`` must be set (non-empty); each file is
       loaded as-is.

    Parameters
    ----------
    data_config : object
        Data configuration object. Read attributes: ``data_path``,
        ``train_data_path``, ``val_data_path``, ``n_train``, ``n_valid`` and
        ``model_version_path``.

    Returns
    -------
    Tuple
        ``(train_atoms_list, val_atoms_list)``: lists of ase.Atoms objects for
        training and validation.

    Raises
    ------
    ValueError
        If ``data_path`` is None and ``train_data_path`` / ``val_data_path``
        are not both provided.
    """
    log.info("Running Input Pipeline")

    if data_config.data_path is not None:
        log.info("Read data file %s", data_config.data_path)
        atoms_list = load_data(data_config.data_path)
        train_idxs, val_idxs = split_idxs(
            atoms_list, data_config.n_train, data_config.n_valid
        )
        train_atoms_list, val_atoms_list = split_atoms(
            atoms_list, train_idxs, val_idxs
        )
        # Persist the split indices next to the model so the same
        # train/validation partition can be reconstructed later.
        np.savez(
            data_config.model_version_path / "train_val_idxs",
            train_idxs=train_idxs,
            val_idxs=val_idxs,
        )
    # Both separate files are required; an unset (None) or empty path is
    # treated as missing for either one, so the checks are symmetric.
    elif data_config.train_data_path and data_config.val_data_path:
        log.info("Read training data file %s", data_config.train_data_path)
        log.info("Read validation data file %s", data_config.val_data_path)
        train_atoms_list = load_data(data_config.train_data_path)
        val_atoms_list = load_data(data_config.val_data_path)
    else:
        raise ValueError("input data path/paths not defined")

    return train_atoms_list, val_atoms_list