# Source code for apax.cli.apax_app

import importlib.resources as pkg_resources
import json
import sys
from pathlib import Path

import typer
import yaml
from pydantic import ValidationError
from rich.console import Console

from apax.cli import templates

# Shared rich console used for status and error output; automatic
# highlighting is switched off.
console = Console(highlight=False)

# Keyword arguments handed to every Typer application created below:
# no local variables in pretty tracebacks, and `-h` accepted next to `--help`.
_TYPER_OPTS = dict(
    pretty_exceptions_show_locals=False,
    context_settings={"help_option_names": ["-h", "--help"]},
)

# Root application plus one sub-application per command group.
app = typer.Typer(**_TYPER_OPTS)
validate_app = typer.Typer(help="Validate training or MD config files.", **_TYPER_OPTS)
template_app = typer.Typer(help="Create configuration file templates.", **_TYPER_OPTS)
schema_app = typer.Typer(
    help="Generate JSON schemata for train/md configs.", **_TYPER_OPTS
)

# Mount the command groups on the root application.
app.add_typer(validate_app, name="validate")
app.add_typer(template_app, name="template")
app.add_typer(schema_app, name="schema")


# CLI entry point `train`: forwards the config path and log level to
# `apax.train.run.run`.
@app.command()
def train(
    train_config_path: Path = typer.Argument(
        ..., help="Training configuration YAML file."
    ),
    log_level: str = typer.Option("info", help="Sets the training logging level."),
):
    """
    Starts the training of a model with parameters provided by a configuration file.
    """
    # Imported inside the command, so the training code is only loaded when
    # this command actually runs.
    from apax.train.run import run as run_training

    run_training(train_config_path, log_level)
# CLI entry point `md`: forwards both config paths and the log level to
# `apax.md.run_md`.
@app.command()
def md(
    train_config_path: Path = typer.Argument(
        ..., help="Configuration YAML file that was used to train a model."
    ),
    md_config_path: Path = typer.Argument(..., help="MD configuration YAML file."),
    # The help text previously said "training logging level", copied from the
    # `train` command; this option is handed to the MD run.
    log_level: str = typer.Option("info", help="Sets the MD logging level."),
):
    """
    Starts performing a molecular dynamics simulation (currently only NHC thermostat)
    with parameters provided by a configuration file.
    """
    # Imported inside the command, so the MD code is only loaded when this
    # command actually runs.
    from apax.md import run_md

    run_md(train_config_path, md_config_path, log_level)
# CLI entry point `eval`: forwards the training config path and the requested
# number of test structures to `apax.train.eval.eval_model`.
# NOTE: the function name shadows the builtin `eval`; it is kept because it is
# the public command name.
@app.command()
def eval(
    train_config_path: Path = typer.Argument(
        ..., help="Configuration YAML file that was used to train a model."
    ),
    n_data: int = typer.Option(
        -1,
        help=(
            "Number of test structures. (All structures are selected by not specifying"
            " it) Gets ignored if test_data_path is specified"
        ),
    ),
):
    """
    Starts performing the evaluation of the test dataset
    with parameters provided by a configuration file.
    """
    # Imported inside the command, so the evaluation code is only loaded when
    # this command actually runs.
    from apax.train.eval import eval_model as evaluate

    evaluate(train_config_path, n_data)
# CLI entry point `docs`: announces the documentation URL on the console and
# opens it through `typer.launch`.
@app.command()
def docs():
    """Opens the documentation website in your browser."""
    docs_url = "https://apax.readthedocs.io/en/latest/"
    console.print(f"Opening apax's docs at {docs_url}")
    typer.launch(docs_url)
# --- Schema commands ---


def _handle_schema(config_cls, section, keywords, flat):
    """Implement the common behaviour of the ``schema train`` / ``schema md`` commands.

    Args:
        config_cls: Config class; must provide ``model_json_schema()``.
        section: Optional dotted path selecting a sub-schema (may be ``None``).
        keywords: If true, only list subsection names at ``section``.
        flat: If true, print the flat parameter table and ignore the other options.

    Raises:
        typer.Exit: With code 1 when the navigation helpers report an error.
    """
    if flat:
        from apax.config.flat_schema import print_flat

        print_flat(config_cls)
        return

    from apax.config.schema_navigation import filter_schema, print_keywords

    def _abort(message):
        # Print the problem in red and stop with a non-zero exit code.
        console.print(message, style="red3")
        raise typer.Exit(code=1)

    full_schema = config_cls.model_json_schema()

    if keywords:
        problem = print_keywords(full_schema, section)
        if problem:
            _abort(problem)
        return

    selected = full_schema
    if section:
        selected, problem = filter_schema(full_schema, section)
        if problem:
            _abort(problem)

    print(json.dumps(selected, indent=2))
# CLI entry point `schema train`: delegates to `_handle_schema` with the
# training `Config` class.
@schema_app.command("train")
def schema_train(
    section: str = typer.Argument(
        None,
        help=(
            "Dotted path to filter (e.g. 'model', 'model.gmnn',"
            " 'model.gmnn.basis', 'optimizer'). Omit for the full schema."
        ),
    ),
    keywords: bool = typer.Option(
        False,
        "--keywords",
        help="Only list navigable subsection names at the given path.",
    ),
    flat: bool = typer.Option(
        False,
        "--flat",
        help="List all parameters as a flat table with path, type, default, and description.",
    ),
):
    """Print the training config JSON schema to stdout."""
    from apax.config import Config as train_config_cls

    _handle_schema(train_config_cls, section, keywords, flat)
# CLI entry point `schema md`: delegates to `_handle_schema` with the
# `MDConfig` class.
@schema_app.command("md")
def schema_md(
    section: str = typer.Argument(
        None,
        help=(
            "Dotted path to filter (e.g. 'ensemble', 'ensemble.nvt',"
            " 'constraints'). Omit for the full schema."
        ),
    ),
    keywords: bool = typer.Option(
        False,
        "--keywords",
        help="Only list navigable subsection names at the given path.",
    ),
    flat: bool = typer.Option(
        False,
        "--flat",
        help="List all parameters as a flat table with path, type, default, and description.",
    ),
):
    """Print the MD config JSON schema to stdout."""
    from apax.config import MDConfig as md_config_cls

    _handle_schema(md_config_cls, section, keywords, flat)
# CLI entry point `schema vscode`: writes one JSON schema file per config type
# into `.vscode/` and registers them under "yaml.schemas" in
# `.vscode/settings.json`, keeping any other settings already present.
@schema_app.command("vscode")
def schema_vscode():
    """Generate JSON schema files under .vscode/ for YAML autocompletion."""
    console.print("Generating JSON schema")
    from apax.config import Config, MDConfig

    vscode_path = Path(".vscode")
    vscode_path.mkdir(exist_ok=True)
    settings_path = vscode_path / "settings.json"

    settings = {}
    if settings_path.is_file():
        raw_settings = settings_path.read_text()
        # An empty (or whitespace-only) settings file is treated like a
        # missing one instead of failing in json.loads.
        if raw_settings.strip():
            try:
                settings = json.loads(raw_settings)
            except json.JSONDecodeError as err:
                # Typical cause: comments or trailing commas, which strict
                # JSON does not allow. Leave the file untouched and bail out.
                console.print(
                    f"Could not parse {settings_path} as JSON: {err}", style="red3"
                )
                raise typer.Exit(code=1)

    if not isinstance(settings, dict):
        console.print(f"{settings_path} does not contain a JSON object.", style="red3")
        raise typer.Exit(code=1)

    settings.setdefault("yaml.schemas", {})

    for name, cls, globs in [
        ("apaxtrain", Config, ["train*.yaml"]),
        ("apaxmd", MDConfig, ["md*.yaml"]),
    ]:
        schema_path = vscode_path / f"{name}.schema.json"
        # The settings entry is keyed by the absolute POSIX-style schema path
        # and maps to the file globs it applies to.
        settings["yaml.schemas"][schema_path.resolve().as_posix()] = globs
        schema_path.write_text(json.dumps(cls.model_json_schema(), indent=2))

    settings_path.write_text(json.dumps(settings, indent=2))
# --- Validation commands ---


def _format_errors(e):
    """Render the entries of ``e.errors()`` as human-readable text.

    Each entry contributes its dotted location, its message and the type name
    of the offending input. The input value itself is appended unless its
    type is exactly ``dict`` or ``list``.

    Args:
        e: Object exposing ``errors()`` that returns mappings with the keys
            ``"loc"``, ``"msg"`` and ``"input"`` (a pydantic ValidationError).

    Returns:
        The concatenated text, one newline-terminated chunk per error.
    """
    chunks = []
    for item in e.errors():
        value = item["input"]
        dotted = ".".join(map(str, item["loc"]))
        text = f"{dotted}\n {item['msg']}\n input_type: {type(value).__name__}"
        # Exact type comparison on purpose: subclasses of dict/list are still shown.
        if type(value) is not dict and type(value) is not list:
            text += f"\n input: {value}"
        chunks.append(text + "\n")
    return "".join(chunks)


def _validate_config(config_cls, config_path, user_config, label):
    """Validate ``user_config`` against ``config_cls`` and report the outcome.

    Args:
        config_cls: Class providing ``model_validate``.
        config_path: Path shown in the success message.
        user_config: Parsed configuration to validate.
        label: Config kind named in the success message.

    Raises:
        typer.Exit: With code 1 when validation fails.
    """
    try:
        config_cls.model_validate(user_config)
    except ValidationError as failure:
        print(f"{failure.error_count()} validation errors for config")
        print(_format_errors(failure))
        console.print("Configuration Invalid!", style="red3")
        raise typer.Exit(code=1)
    else:
        console.print("Success!", style="green3")
        console.print(f"{config_path} is a valid {label} config.")


# Placeholder values filled into the `data` section by `validate train --template`.
_TEMPLATE_DATA_DEFAULTS = dict(
    directory="/tmp/apax_validate",
    experiment="placeholder",
    data_path="/tmp/apax_validate/placeholder.extxyz",
)
# CLI entry point `validate train`: loads the YAML file, optionally fills in
# placeholder `data` entries (--template), then validates against `Config`.
@validate_app.command("train")
def validate_train_config(
    config_path: Path = typer.Argument(
        ..., help="Configuration YAML file to be validated."
    ),
    template: bool = typer.Option(
        False,
        "--template",
        help=(
            "Validate as a zntrack/ipsuite template. Fills in dummy values for"
            " directory, experiment, and data_path so configs that leave these"
            " to be set at runtime can still be validated."
        ),
    ),
):
    """Validates a training configuration file."""
    from apax.config import Config

    with open(config_path, "r") as stream:
        user_config = yaml.safe_load(stream)

    if template:
        # An empty YAML document loads as None; start from an empty mapping so
        # the placeholders can be inserted and validation reports what is missing.
        if user_config is None:
            user_config = {}
        # Anything that is not a mapping is passed on unchanged and left for
        # the validator to reject.
        if isinstance(user_config, dict):
            data = user_config.get("data")
            # `data:` with no value loads as None; treat it like a missing section.
            if data is None:
                data = user_config["data"] = {}
            if isinstance(data, dict):
                for key, default in _TEMPLATE_DATA_DEFAULTS.items():
                    # Only fill entries that are absent or explicitly null.
                    if data.get(key) is None:
                        data[key] = default

    _validate_config(Config, config_path, user_config, "training")
# CLI entry point `validate md`: loads the YAML file and validates it against
# `MDConfig` via `_validate_config`.
@validate_app.command("md")
def validate_md_config(
    config_path: Path = typer.Argument(
        ..., help="Configuration YAML file to be validated."
    ),
):
    """Validates a molecular dynamics configuration file."""
    from apax.config import MDConfig as md_config_cls

    with open(config_path, "r") as handle:
        parsed = yaml.safe_load(handle)

    _validate_config(md_config_cls, config_path, parsed, "MD")
# --- Other commands ---
# CLI entry point `visualize`: validates the training config, builds the energy
# model with n_species=10 and prints the result of `model.tabulate` for the
# minimal sample input.
@app.command("visualize")
def visualize_model(
    config_path: Path = typer.Argument(
        ...,
        help=(
            "Training configuration file to be visualized. A CO molecule is taken as"
            " sample input."
        ),
    ),
):
    """
    Visualize a model based on a configuration file.
    A CO molecule is taken as sample input (influences number of atoms,
    number of species is set to 10).
    """
    import jax

    from apax.config import Config
    from apax.utils.data import make_minimal_input

    with open(config_path, "r") as handle:
        raw_config = yaml.safe_load(handle)

    try:
        config = Config.model_validate(raw_config)
    except ValidationError as failure:
        print(failure)
        console.print("Configuration Invalid!", style="red3")
        raise typer.Exit(code=1)

    R, Z, idx, box, offsets = make_minimal_input()

    builder_cls = config.model.get_builder()
    model_builder = builder_cls(config.model.model_dump(), n_species=10)
    energy_model = model_builder.build_energy_model()

    rng_key = jax.random.PRNGKey(0)
    print(energy_model.tabulate(rng_key, R, Z, idx, box, offsets))
def _write_template(template_file, config_path):
    """Copy a packaged template into ``config_path`` unless a file already exists.

    Args:
        template_file: Name of the resource inside the ``templates`` package.
        config_path: Destination path of the new config file.

    Raises:
        SystemExit: With code 1 when ``config_path`` already is a file.
    """
    target = Path(config_path)
    if target.is_file():
        console.print("There is already a config file in the working directory.")
        sys.exit(1)

    # read_text decodes the resource as UTF-8 by default ...
    content = pkg_resources.read_text(templates, template_file)
    # ... so write it back as UTF-8 too instead of the locale encoding, which
    # can fail on or corrupt non-ASCII characters on some platforms.
    with open(target, "w", encoding="utf-8") as f:
        f.write(content)
# CLI entry point `template train`: writes either the full or the minimal
# training template through `_write_template`.
@template_app.command("train")
def template_train_config(
    full: bool = typer.Option(False, help="Use all input options."),
):
    """Creates a training input template in the current working directory."""
    # (packaged template name, output file name)
    if full:
        source_and_target = ("train_config_full.yaml", "config_full.yaml")
    else:
        source_and_target = ("train_config_minimal.yaml", "config.yaml")
    _write_template(*source_and_target)
# CLI entry point `template md`: writes the minimal MD template through
# `_write_template`.
@template_app.command("md")
def template_md_config():
    """Creates an MD input template in the current working directory."""
    template_name = "md_config_minimal.yaml"
    output_name = "md_config.yaml"
    _write_template(template_name, output_name)
def version_callback(value: bool) -> None:
    """Print the installed apax version and exit when the flag is set.

    Args:
        value: Value of the ``--version`` option; nothing happens when falsy.

    Raises:
        typer.Exit: After the version has been printed.
    """
    if not value:
        return

    from apax import __version__ as installed_version

    console.print(f"apax {installed_version}")
    raise typer.Exit()
# Root callback of the application. It only declares the eager
# `--version` / `-V` option; the work happens in `version_callback`.
# No docstring on purpose: it would be picked up as the application help text.
@app.callback()
def main(
    version: bool = typer.Option(
        None,
        "--version",
        "-V",
        callback=version_callback,
        is_eager=True,
    ),
):
    # The option value is already handled by the callback; nothing else to do.
    del version