
augernet.train_driver

AugerNet Training Driver

Contains run_kfold_cv, run_param_search, _build_param_configs, and the mode-dispatch logic for GNN and CNN training.

Model-specific behaviour is provided by the backend module:

- augernet.backend_gnn (CEBE and Auger prediction GNN)
- augernet.backend_cnn (bond environment classification CNN)

The backend exports hooks:

- load_data(cfg) : data dict
- train_single_run(data, …) : result dict (receives save_paths from driver)
- load_saved_model(save_paths, …) : (model, device) or result dict
- run_evaluation(…) : eval metrics dict
- run_unit_tests(…) : None
- run_predict(…) : None
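For orientation, here is a hypothetical skeleton of a backend module satisfying this contract. The signatures are inferred from how the driver calls the hooks in the source below; the bodies are placeholders, not the real backend_gnn / backend_cnn implementations.

def load_data(cfg):
    """Return the data dict consumed by the remaining hooks."""
    return {}

def train_single_run(data, fold, n_folds, *, save_paths, output_dir,
                     cfg, verbose=False, **overrides):
    """Train one model; `overrides` carries param-search hyperparameters."""
    return {'best_val_loss': 0.0, 'train_results': None}

def load_saved_model(save_paths, data, cfg):
    """Reload a saved model, e.g. the best CV fold, for unit tests."""
    return {}

def run_evaluation(result, data, fold, *, output_dir, png_dir, cfg,
                   train_results=None, **kwargs):
    """Evaluate a trained model and return a metrics dict."""
    return {'eval_mae': None}

def run_unit_tests(result, data, cfg):
    """Optional post-training sanity checks."""

def run_predict(cfg):
    """Inference-only entry point; needs no training data."""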

run(cfg)

Execute a full training / evaluation / prediction run.

Parameters:

Name  Type            Description                        Default
cfg   AugerNetConfig  Resolved configuration from YAML.  required
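A minimal usage sketch. The loader name and import path here are hypothetical; use whatever this project actually provides to resolve an AugerNetConfig from YAML.

# Usage sketch with a hypothetical loader; replace `load_config` with
# however this project resolves an AugerNetConfig from YAML.
from augernet.train_driver import run
from augernet.config import load_config  # hypothetical import path

cfg = load_config('config.yml')
cfg.mode = 'train'   # one of: cv, train, param, evaluate, predict
run(cfg)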
Source code in src/augernet/train_driver.py
def run(cfg: AugerNetConfig):
    """
    Execute a full training / evaluation / prediction run.

    Parameters
    ----------
    cfg : AugerNetConfig
        Resolved configuration from YAML.
    """
    mode = cfg.mode
    model_name = cfg.model

    print(f"\n{'=' * 80}")
    print(f"  AugerNet: model={model_name}  mode={mode}")
    if cfg.model_id:
        print(f"  Model ID: {cfg.model_id}")
    print(f"{'=' * 80}")

    be = _get_backend(cfg)

    # ── Modes that do NOT need the full training dataset ─────────────────
    if mode == 'predict':
        _run_predict(cfg)
        print("\n Predictions Complete.")
        return

    # ── Load data ────────────────────────────────────────────────────────
    data = be.load_data(cfg)

    # ── Dispatch ─────────────────────────────────────────────────────────
    result = None  # may be set by train/cv for unit tests

    if mode == 'cv':
        cv_summary = run_kfold_cv(data, cfg)
        # Load the best-fold model for unit tests
        if getattr(cfg, 'run_unit_tests', False):
            best_fold = cv_summary['best_fold']
            save_paths = _build_save_paths(cfg, best_fold, cfg.models_dir)
            result = be.load_saved_model(save_paths, data, cfg)

    elif mode == 'train':
        save_paths = _build_save_paths(cfg, cfg.train_fold, cfg.models_dir)
        result = be.train_single_run(
            data, cfg.train_fold, cfg.n_folds,
            save_paths=save_paths,
            output_dir=cfg.outputs_dir,
            cfg=cfg,
            verbose=True,
        )

        if cfg.run_evaluation:
            be.run_evaluation(
                result, data, cfg.train_fold,
                output_dir=cfg.outputs_dir,
                png_dir=cfg.pngs_dir, cfg=cfg,
                train_results=result.get('train_results'),
            )

    elif mode == 'param':
        run_param_search(data, cfg)

    elif mode == 'evaluate':
        _run_evaluate(data, cfg)

    else:
        raise ValueError(
            f"Unknown mode '{mode}'. "
            f"Choose from: cv, train, param, evaluate, predict"
        )

    # ── Unit tests ───────────────────────────────────────────────────────
    if getattr(cfg, 'run_unit_tests', False) and mode in ('train', 'cv'):
        if result is not None:
            try:
                be.run_unit_tests(result, data, cfg)
            except Exception:
                pass  # unit tests are optional

    print("\n AugerNet run complete\n")

run_kfold_cv(data, cfg)

Run full k-fold cross-validation.

Trains one model per fold via backend.train_single_run, saves each model, and writes a JSON summary identifying the best fold.
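A short call sketch, assuming `cfg` is a resolved AugerNetConfig and `be` is the backend module it selects (as in run() above):

# Sketch: invoking CV directly (the driver does this when cfg.mode == 'cv').
data = be.load_data(cfg)              # backend-specific data dict
cv_summary = run_kfold_cv(data, cfg)  # trains cfg.n_folds models
print(cv_summary['best_fold'])        # fold with the lowest val loss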

Source code in src/augernet/train_driver.py
def run_kfold_cv(data, cfg) -> Dict[str, Any]:
    """
    Run full k-fold cross-validation.

    Trains one model per fold via backend.train_single_run, saves each
    model, and writes a JSON summary identifying the best fold.
    """

    be = _get_backend(cfg)
    n_folds  = cfg.n_folds

    fold_results = []

    print(f"\n{'#' * 80}")
    print(f"#  K-FOLD CROSS-VALIDATION  ({n_folds} folds)")
    print(f"{'#' * 80}")

    for fold in range(1, n_folds + 1):
        save_paths = _build_save_paths(cfg, fold, cfg.models_dir)
        result = be.train_single_run(
            data, fold, n_folds,
            save_paths=save_paths,
            output_dir=cfg.outputs_dir,
            cfg=cfg,
            verbose=True,
        )

        # Save loss curve and run evaluation for this fold
        eval_metrics = None
        if cfg.run_evaluation:
            eval_metrics = be.run_evaluation(
                result, data, fold,
                output_dir=cfg.outputs_dir, png_dir=cfg.pngs_dir, cfg=cfg,
                train_results=result.get('train_results'),
                exp_split='val',  # CV uses validation subset only
            )

        # Build a JSON-serialisable record
        entry = _run_entry(result, eval_metrics=eval_metrics)
        entry['fold'] = fold
        fold_results.append(entry)

    # ── Identify best fold ───────────────────────────────────────────────
    best = min(fold_results, key=lambda r: r['best_val_loss'])

    # ── Print summary table ──────────────────────────────────────────────
    has_eval = any(r.get('eval_mae') is not None for r in fold_results)
    _print_cv_summary(fold_results, n_folds, best, has_eval=has_eval)

    combined = [r['best_val_loss'] for r in fold_results]
    print(f"\n  Mean Val Loss: {np.mean(combined):.6f} +/- {np.std(combined):.6f}")
    if has_eval:
        eval_maes = [r['eval_mae'] for r in fold_results if r.get('eval_mae') is not None]
        print(f"  Mean Exp MAE:  {np.mean(eval_maes):.4f} +/- {np.std(eval_maes):.4f} eV")
    print(f"  Best fold: Fold {best['fold']}  (loss={best['best_val_loss']:.6f})")

    # ── Save JSON summary ────────────────────────────────────────────────
    cv_summary = _build_summary(fold_results, cfg)
    cv_summary['n_folds'] = n_folds
    cv_summary['best_fold'] = best['fold']

    summary_path = os.path.join(cfg.result_dir, f'{cfg.model_id}_cv_summary.json')
    with open(summary_path, 'w') as f:
        json.dump(cv_summary, f, indent=2, default=str)
    print(f"\nCV summary saved to: {summary_path}")

    return cv_summary

run_param_search(data, cfg)

Run hyperparameter search.

For each combination in cfg.param_grid, trains one fold via backend.train_single_run with overrides, records the best validation loss, and writes a sorted leaderboard JSON.
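For illustration, here is a hypothetical grid together with a minimal sketch of the Cartesian expansion that _build_param_configs presumably performs (the grid keys are made-up override names, and _build_param_configs' actual source is not shown on this page):

from itertools import product

# Hypothetical grid; keys must match override names accepted by the
# backend's train_single_run. Defined under 'param_grid' in the YAML.
param_grid = {
    'learning_rate': [1e-3, 3e-4],
    'hidden_dim': [64, 128],
}

# Minimal sketch of the expansion: one dict per combination (4 here).
keys = sorted(param_grid)
configs = [dict(zip(keys, vals))
           for vals in product(*(param_grid[k] for k in keys))]
# -> [{'hidden_dim': 64, 'learning_rate': 0.001}, ...]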

Source code in src/augernet/train_driver.py
def run_param_search(data, cfg) -> Dict[str, Any]:
    """
    Run hyperparameter search.

    For each combination in ``cfg.param_grid``, trains one fold via
    ``backend.train_single_run`` with overrides, records the best
    validation loss, and writes a sorted leaderboard JSON.
    """
    be = _get_backend(cfg)

    param_grid = cfg.param_grid
    if not param_grid:
        raise ValueError(
            "No param_grid defined in config. "
            "Add a 'param_grid' section to your YAML file."
        )

    fold      = cfg.train_fold
    n_folds   = cfg.n_folds

    configs = _build_param_configs(param_grid)
    n_configs = len(configs)

    # Build a unique search identifier for this grid
    search_id = _param_search_id(param_grid)

    # Cap epochs for search speed
    search_epochs  = min(cfg.num_epochs, 300)
    search_patience = min(cfg.patience, 40)

    print(f"\n{'#' * 80}")
    print(f"#  HYPERPARAMETER SEARCH  ({n_configs} configurations)")
    print(f"#  Fold {fold}/{n_folds}  |  max {search_epochs} epochs  |  patience {search_patience}")
    print(f"{'#' * 80}")

    print(f"\nSearch grid:")
    for k, v in sorted(param_grid.items()):
        print(f"  {k}: {v}")
    print()

    results = []
    t0_total = time.time()

    for i, config in enumerate(configs):
        config_id = f"cfg{i:03d}"

        # Overrides for this config (includes capped epochs)
        overrides = dict(config)
        overrides['num_epochs'] = search_epochs
        overrides['patience'] = search_patience

        print(f"\n{'─' * 70}")
        print(f"  Config {i+1}/{n_configs}  [{config_id}]")
        for k, v in sorted(config.items()):
            print(f"    {k}: {v}")
        print(f"{'─' * 70}")

        t0 = time.time()
        try:
            save_paths = _build_save_paths(
                cfg, fold, cfg.models_dir,
                prefix=search_id, config_id=config_id,
            )
            result = be.train_single_run(
                data, fold, n_folds,
                save_paths=save_paths,
                output_dir=cfg.models_dir,
                cfg=cfg,
                verbose=True,
                **overrides,
            )
            elapsed = time.time() - t0

            # Save loss curve and run evaluation for this fold
            eval_metrics = None
            if cfg.run_evaluation:
                eval_metrics = be.run_evaluation(
                    result, data, fold,
                    output_dir=cfg.outputs_dir, png_dir=cfg.pngs_dir, cfg=cfg,
                    train_results=result.get('train_results'),
                    config_id=config_id,
                    param_file_prefix=search_id,
                    exp_split='val',  # param search uses validation subset only
                )

            entry = _run_entry(result, eval_metrics=eval_metrics)
            entry.update({
                'config_id': config_id,
                'rank': 0,
                **config,
                'elapsed_sec': round(elapsed, 1),
                'status': 'ok',
            })

        except Exception as e:
            elapsed = time.time() - t0
            entry = {
                'model_id': cfg.model_id,
                'best_val_loss': float('inf'),
                'best_train_loss': None,
                'best_val_epoch': None,
                'n_epochs': 0,
                'model_path': None,
                'final_train_loss': None,
                'final_val_loss': None,
                'config_id': config_id,
                'rank': 999,
                **config,
                'elapsed_sec': round(elapsed, 1),
                'status': f'error: {e}',
            }
            print(f"ERROR: {e}")

        results.append(entry)

    total_elapsed = time.time() - t0_total

    # Sort by best_val_loss
    results.sort(key=lambda r: r['best_val_loss'])
    for rank, r in enumerate(results):
        r['rank'] = rank + 1

    # ── Leaderboard ──────────────────────────────────────────────────────
    has_eval = any(r.get('eval_mae') is not None for r in results)
    _print_param_leaderboard(results, n_configs, total_elapsed, param_grid,
                             has_eval=has_eval)

    best = results[0]
    print(f"\n  Best config: {best['config_id']}")
    for k in sorted(param_grid.keys()):
        print(f"      {k}: {best.get(k)}")
    print(f"      val_loss: {best['best_val_loss']:.6f}")
    if has_eval and best.get('eval_mae') is not None:
        print(f"      exp_mae:  {best['eval_mae']:.4f} eV")

    # ── Save JSON summary ────────────────────────────────────────────────
    summary = _build_summary(results, cfg)
    summary['search_id'] = search_id
    summary['n_configs'] = n_configs
    summary['search_epochs'] = search_epochs
    summary['search_patience'] = search_patience
    summary['total_elapsed_min'] = round(total_elapsed / 60, 1)
    summary['param_grid'] = {
        k: [str(v) if isinstance(v, float) else v for v in vals]
        for k, vals in param_grid.items()
    }
    summary['best_config_id'] = best['config_id']
    summary['best_params'] = {k: best.get(k)
                              for k in sorted(param_grid.keys())}

    summary_path = os.path.join(cfg.result_dir,
                                f'{search_id}_{cfg.model_id}_param_summary.json')
    with open(summary_path, 'w') as f:
        json.dump(summary, f, indent=2, default=str)
    print(f"\nSaved param search summary to: {summary_path}")

    return summary
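Once a search has finished, the saved leaderboard JSON can be read back to recover the winning configuration; a small sketch using the keys written by run_param_search above:

import json
import os

# Sketch: pull the winning settings back out of a saved search summary.
# `cfg` and `search_id` must match the run that wrote the file.
summary_path = os.path.join(cfg.result_dir,
                            f'{search_id}_{cfg.model_id}_param_summary.json')
with open(summary_path) as f:
    summary = json.load(f)

print(summary['best_config_id'])  # e.g. 'cfg002'
print(summary['best_params'])     # winning value for each grid key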