import importlib.metadata
import importlib.resources as pkg_resources
import json
import sys
from pathlib import Path
import typer
import yaml
from pydantic import ValidationError
from rich.console import Console
from apax.cli import templates
# Shared console used for all CLI output; rich's automatic highlighting is off.
console = Console(highlight=False)

# Root application. Both "-h" and "--help" are accepted as help flags, and
# local variables are not included in pretty-printed exception output.
app = typer.Typer(
    pretty_exceptions_show_locals=False,
    context_settings={"help_option_names": ["-h", "--help"]},
)

# Sub-application holding the "validate ..." commands.
validate_app = typer.Typer(
    help="Validate training or MD config files.",
    context_settings={"help_option_names": ["-h", "--help"]},
    pretty_exceptions_show_locals=False,
)

# Sub-application holding the "template ..." commands.
template_app = typer.Typer(
    help="Create configuration file templates.",
    context_settings={"help_option_names": ["-h", "--help"]},
    pretty_exceptions_show_locals=False,
)

# Mount the sub-applications on the root app under their command names.
app.add_typer(validate_app, name="validate")
app.add_typer(template_app, name="template")
[docs]
@app.command()
def train(
    train_config_path: Path = typer.Argument(
        ..., help="Training configuration YAML file."
    ),
    log_level: str = typer.Option("info", help="Sets the training logging level."),
) -> None:
    """
    Starts the training of a model with parameters provided by a configuration file.
    """
    # Function-scope import: apax.train.run is imported when the command is
    # invoked, not when the CLI module itself is loaded.
    from apax.train.run import run as run_training

    # Both CLI values are forwarded unchanged, in the same positional order.
    run_training(train_config_path, log_level)
[docs]
@app.command()
def md(
    train_config_path: Path = typer.Argument(
        ..., help="Configuration YAML file that was used to train a model."
    ),
    md_config_path: Path = typer.Argument(..., help="MD configuration YAML file."),
    # The help text previously read "Sets the training logging level.", copied
    # from the `train` command; this value is passed to `run_md`.
    log_level: str = typer.Option("info", help="Sets the MD logging level."),
) -> None:
    """
    Starts performing a molecular dynamics simulation (currently only NHC thermostat)
    with parameters provided by a configuration file.
    """
    # Function-scope import: apax.md is imported when the command is invoked,
    # not when the CLI module itself is loaded.
    from apax.md import run_md

    # All three CLI values are forwarded unchanged, in positional order.
    run_md(train_config_path, md_config_path, log_level)
[docs]
@app.command()
def eval(
    train_config_path: Path = typer.Argument(
        ..., help="Configuration YAML file that was used to train a model."
    ),
    n_data: int = typer.Option(
        -1,
        help=(
            "Number of test structures. "
            "(All structures are selected by not specifying it) "
            "Gets ignored if test_data_path is specified"
        ),
    ),
) -> None:
    """
    Starts performing the evaluation of the test dataset
    with parameters provided by a configuration file.
    """
    # NOTE: the function name doubles as the CLI command name ("eval"), which
    # is why it shadows the builtin of the same name inside this module.
    # Function-scope import: apax.train.eval is imported on invocation only.
    from apax.train.eval import eval_model as evaluate

    evaluate(train_config_path, n_data)
[docs]
@app.command()
def docs() -> None:
    """
    Opens the documentation website in your browser.
    """
    # Single source for the URL that is both announced and launched.
    docs_url = "https://apax.readthedocs.io/en/latest/"

    console.print(f"Opening apax's docs at {docs_url}")
    typer.launch(docs_url)
[docs]
@app.command()
def schema() -> None:
    """
    Generating JSON schemata for autocompletion of train/md inputs in VSCode.
    """
    console.print("Generating JSON schema")

    # Function-scope import: apax.config is imported on invocation only.
    from apax.config import Config, MDConfig

    # Both schemata are generated before anything is written; the mapping
    # keeps insertion order, so the training schema is written first.
    schema_by_file = {
        "./apaxtrain.schema.json": Config.model_json_schema(),
        "./apaxmd.schema.json": MDConfig.model_json_schema(),
    }

    for target_file, schema_dict in schema_by_file.items():
        with open(target_file, "w") as handle:
            handle.write(json.dumps(schema_dict, indent=2))
[docs]
@validate_app.command("train")
def validate_train_config(
    config_path: Path = typer.Argument(
        ..., help="Configuration YAML file to be validated."
    ),
) -> None:
    """
    Validates a training configuration file.
    Parameters
    ----------
    config_path: Path to the training configuration file.
    """
    # Function-scope import: apax.config is imported on invocation only.
    from apax.config import Config

    with open(config_path, "r") as config_file:
        raw_config = yaml.safe_load(config_file)

    try:
        # Only the validation outcome matters; the validated model is discarded.
        Config.model_validate(raw_config)
    except ValidationError as error:
        # Print the pydantic error report, then exit with status 1.
        print(error)
        console.print("Configuration Invalid!", style="red3")
        raise typer.Exit(code=1)

    # Reached only when validation raised nothing.
    console.print("Success!", style="green3")
    console.print(f"{config_path} is a valid training config.")
[docs]
@validate_app.command("md")
def validate_md_config(
    config_path: Path = typer.Argument(
        ..., help="Configuration YAML file to be validated."
    ),
) -> None:
    """
    Validates a molecular dynamics configuration file.
    Parameters
    ----------
    config_path: Path to the molecular dynamics configuration file.
    """
    # Function-scope import: apax.config is imported on invocation only.
    from apax.config import MDConfig

    with open(config_path, "r") as config_file:
        raw_config = yaml.safe_load(config_file)

    try:
        # Only the validation outcome matters; the validated model is discarded.
        MDConfig.model_validate(raw_config)
    except ValidationError as error:
        # Print the pydantic error report, then exit with status 1.
        print(error)
        console.print("Configuration Invalid!", style="red3")
        raise typer.Exit(code=1)

    # Reached only when validation raised nothing.
    console.print("Success!", style="green3")
    console.print(f"{config_path} is a valid MD config.")
[docs]
@app.command("visualize")
def visualize_model(
    config_path: Path = typer.Argument(
        ...,
        help=(
            "Training configuration file to be visualized. "
            "A CO molecule is taken as sample input."
        ),
    ),
) -> None:
    """
    Visualize a model based on a configuration file.
    A CO molecule is taken as sample input (influences number of atoms,
    number of species is set to 10).
    Parameters
    ----------
    config_path: Path to the training configuration file.
    """
    # Function-scope imports: jax and the apax model code are imported on
    # invocation only.
    import jax

    from apax.config import Config
    from apax.model.builder import ModelBuilder
    from apax.utils.data import make_minimal_input

    with open(config_path, "r") as config_file:
        raw_config = yaml.safe_load(config_file)

    try:
        config = Config.model_validate(raw_config)
    except ValidationError as error:
        # Print the pydantic error report, then exit with status 1.
        print(error)
        console.print("Configuration Invalid!", style="red3")
        raise typer.Exit(code=1)

    # Sample inputs for the model summary (five values, unpacked as before).
    R, Z, idx, box, offsets = make_minimal_input()

    # Build the energy model from the validated config with 10 species.
    model = ModelBuilder(config.model.get_dict(), n_species=10).build_energy_model()

    # Tabulate the model with a fixed PRNG key (seed 0) and print the summary.
    rng_key = jax.random.PRNGKey(0)
    summary = model.tabulate(rng_key, R, Z, idx, box, offsets)
    print(summary)
[docs]
@template_app.command("train")
def template_train_config(
    full: bool = typer.Option(False, help="Use all input options."),
) -> None:
    """
    Creates a training input template in the current working directory.
    """
    # Pick the packaged template and the name of the file to create.
    if full:
        template_file, config_path = "train_config_full.yaml", "config_full.yaml"
    else:
        template_file, config_path = "train_config_minimal.yaml", "config.yaml"

    # The template is read before the existence check, as in the original flow.
    template_content = pkg_resources.read_text(templates, template_file)

    # Refuse to overwrite an existing file: report and exit with status 1.
    if Path(config_path).is_file():
        console.print("There is already a config file in the working directory.")
        sys.exit(1)

    with open(config_path, "w") as config:
        config.write(template_content)
[docs]
@template_app.command("md")
def template_md_config() -> None:
    """
    Creates an MD input template in the current working directory.
    """
    # The docstring above is shown as this command's help text. It previously
    # said "training input template", although this command writes
    # `md_config.yaml` from the packaged `md_config_minimal.yaml`.
    template_file = "md_config_minimal.yaml"
    config_path = "md_config.yaml"

    # The template is read before the existence check, as in the original flow.
    template_content = pkg_resources.read_text(templates, template_file)

    # Refuse to overwrite an existing file: report and exit with status 1.
    if Path(config_path).is_file():
        console.print("There is already a config file in the working directory.")
        sys.exit(1)

    with open(config_path, "w") as config:
        config.write(template_content)
[docs]
def version_callback(value: bool) -> None:
    """Print the installed apax version and exit when the flag is set.

    Parameters
    ----------
    value: Value of the ``--version`` option; nothing happens when it is falsy.

    Raises
    ------
    typer.Exit
        Raised after the version has been printed.
    """
    if not value:
        return

    installed_version = importlib.metadata.version("apax")
    console.print(f"apax {installed_version}")
    raise typer.Exit()
@app.callback()
def main(
    version: bool = typer.Option(
        None, "--version", "-V", callback=version_callback, is_eager=True
    ),
):
    # Taken from https://github.com/zincware/dask4dvc/blob/main/dask4dvc/cli/main.py
    # The option is declared eager and handled by `version_callback`; the
    # value itself is not needed in this body.
    del version