
augernet.feature_assembly

Feature Assembly Module

Provides runtime feature selection and assembly for GNN training.

During data preparation, ALL possible node features are computed and stored as separate attributes on each PyG Data object. At training time, the user selects which features to include via a list of integer keys, and this module concatenates and scales them into data.x.

Feature Catalog

Key  Name          Dim  Description
───  ────────────  ───  ─────────────────────────────────────────
0    skipatom_200  200  SkipAtom atom-type embedding (200-dim)
1    skipatom_30    30  SkipAtom atom-type embedding (30-dim)
2    onehot          5  Element one-hot encoding (H, C, N, O, F)
3    atomic_be       1  Isolated-atom 1s BE (Hartree, raw)
4    mol_be          1  Molecular CEBE for C, atomic for others (Hartree, raw)
5    e_score         1  Electronegativity-difference score (raw)
6    env_onehot     36  Carbon environment one-hot (NUM_CARBON_CATEGORIES)
7    morgan_fp     256  Per-atom Morgan fingerprint (ECFP2, radius=1)

Only the category_feature ([1,0,0], [0,1,0], [0,0,1]) is placed in data.x at preparation time. Everything else lives in data.<name> attributes and is assembled here at training time.
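Conceptually, assembly is column-wise concatenation: each atom keeps its three category_feature columns, and the selected feature columns are appended in key order. A torch-free sketch of that idea using nested lists (the two-atom graph and its e_score values below are hypothetical):

```python
def concat_columns(base, feature_blocks):
    """Append each feature block's columns after the base columns,
    row by row (one row per atom)."""
    out = []
    for i, row in enumerate(base):
        new_row = list(row)
        for block in feature_blocks:
            new_row.extend(block[i])
        out.append(new_row)
    return out

base = [[1, 0, 0], [0, 1, 0]]   # category_feature, shape [2, 3]
e_score = [[0.4], [1.2]]        # key 5 (e_score), shape [2, 1]
x = concat_columns(base, [e_score])  # assembled data.x, shape [2, 4]
```

The real module does the same with torch.cat along dim=1, after scaling the non-categorical blocks.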

Usage

from augernet.feature_assembly import assemble_node_features, parse_feature_keys

feature_keys_parsed = parse_feature_keys('035')  # [0, 3, 5]

# Before creating the DataLoader: modifies data.x in-place
for data in data_list:
    assemble_node_features(data, feature_keys_parsed)

assemble_dataset(data_list, feature_keys, norm_stats=None)

Apply assemble_node_features to every graph in a list (in-place).

Parameters:

Name          Type             Description                                          Default
data_list     list             List of PyG Data objects.                            required
feature_keys  sequence of int  Which features to include.                           required
norm_stats    dict             Dataset-wide CEBE normalisation stats forwarded      None
                               to assemble_node_features for mol_be scaling.

Returns:

The same list, for convenience.
Source code in src/augernet/feature_assembly.py
def assemble_dataset(
    data_list: list,
    feature_keys: Sequence[int],
    norm_stats: Optional[Dict[str, float]] = None,
) -> list:
    """
    Apply ``assemble_node_features`` to every graph in a list (in-place).

    Parameters
    ----------
    data_list : list
        List of PyG Data objects.
    feature_keys : sequence of int
        Which features to include.
    norm_stats : dict, optional
        Dataset-wide CEBE normalisation stats forwarded to
        ``assemble_node_features`` for ``mol_be`` scaling.

    Returns the same list for convenience.
    """
    for data in data_list:
        assemble_node_features(data, feature_keys, inplace=True,
                               norm_stats=norm_stats)
    return data_list
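The norm_stats argument expects {'mean': ..., 'std': ...} in eV (see assemble_node_features below). A sketch of computing dataset-wide stats before calling assemble_dataset; the CEBE values here are made up, and whether the pipeline uses population or sample std is an assumption (population std is used below):

```python
import statistics

def compute_norm_stats(cebe_values):
    # Dataset-wide stats in the {'mean': ..., 'std': ...} shape
    # that assemble_dataset forwards for mol_be scaling.
    return {
        'mean': statistics.fmean(cebe_values),
        'std': statistics.pstdev(cebe_values),  # population std (assumption)
    }

stats = compute_norm_stats([290.1, 291.4, 292.0, 290.8])
```

Computing these once over the training split (rather than per graph) keeps the mol_be scaling consistent across graphs of different sizes.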

assemble_node_features(data, feature_keys, inplace=True, norm_stats=None)

Concatenate selected node features into data.x.

The existing data.x (category_feature, shape [N, 3]) is kept as the first columns. Selected features are scaled and appended.

Parameters:

Name          Type             Description                                            Default
data          Data             A single graph. Must have feature attributes set       required
                               during preparation.
feature_keys  sequence of int  Which features to include (see FEATURE_NAMES).         required
inplace       bool             If True, modifies data.x directly. If False,           True
                               returns a copy.
norm_stats    dict             {'mean': float, 'std': float}: dataset-wide CEBE       None
                               normalisation statistics (eV). When provided and
                               key 4 (mol_be) is in feature_keys, the mol_be
                               feature is scaled using these dataset-wide stats
                               instead of per-graph z-scoring.

Returns:

Name Type Description
data the (possibly modified) Data object.
Source code in src/augernet/feature_assembly.py
def assemble_node_features(
    data,
    feature_keys: Sequence[int],
    inplace: bool = True,
    norm_stats: Optional[Dict[str, float]] = None,
):
    """
    Concatenate selected node features into ``data.x``.

    The existing ``data.x`` (category_feature, shape [N, 3]) is kept as the
    **first** columns.  Selected features are scaled and appended.

    Parameters
    ----------
    data : torch_geometric.data.Data
        A single graph.  Must have feature attributes set during preparation.
    feature_keys : sequence of int
        Which features to include (see FEATURE_NAMES).
    inplace : bool
        If True, modifies ``data.x`` directly.  If False, returns a copy.
    norm_stats : dict, optional
        ``{'mean': float, 'std': float}`` — dataset-wide CEBE normalisation
        statistics (eV).  When provided and key 4 (``mol_be``) is in
        *feature_keys*, the ``mol_be`` feature is scaled using these
        dataset-wide stats instead of per-graph z-scoring.

    Returns
    -------
    data : the (possibly modified) Data object.
    """
    if not inplace:
        data = copy(data)

    import torch  # noqa: F811 — lazy import to keep module importable without torch

    # On first call, stash the original category_feature so that
    # subsequent calls (e.g. param search with different feature_keys)
    # always start from the base columns, not previously assembled ones.
    if not hasattr(data, '_category_feature'):
        data._category_feature = data.x.clone()

    parts = [data._category_feature]

    # Features that should NOT be scaled (categorical / pre-normalized)
    no_scale_keys = {0, 1, 2, 6, 7}

    for key in sorted(feature_keys):
        attr_name = FEATURE_NAMES[key]
        tensor = getattr(data, attr_name, None)
        if tensor is None:
            raise ValueError(
                f"Feature key {key} ({FEATURE_NAMES[key]}) not found on Data object."
            )

        # Ensure 2D
        if tensor.dim() == 1:
            tensor = tensor.unsqueeze(1)

        # Scale scalar features only
        if key in no_scale_keys:
            parts.append(tensor.float())
        else:
            parts.append(_scale_tensor(tensor.float()))

    data.x = torch.cat(parts, dim=1)
    return data
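_scale_tensor itself is not shown on this page. Assuming it performs per-graph z-scoring (the behaviour that norm_stats overrides for mol_be), the scaling rule can be sketched without torch; the guard against zero std is an assumption:

```python
def z_score(values, mean=None, std=None):
    # mean/std default to per-graph statistics; pass dataset-wide
    # values (e.g. from norm_stats) to override them.
    if mean is None:
        mean = sum(values) / len(values)
    if std is None:
        std = (sum((v - mean) ** 2 for v in values) / len(values)) ** 0.5
    std = std if std > 0 else 1.0  # avoid division by zero (assumption)
    return [(v - mean) / std for v in values]

scaled = z_score([290.1, 291.4, 292.0, 290.8])  # per-graph scaling
```

With per-graph scaling the output has zero mean and unit variance within each graph; passing dataset-wide mean/std instead preserves between-graph differences, which is the point of norm_stats for mol_be.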

compute_feature_tag(feature_keys)

Compute a compact filename-safe tag from sorted feature keys.

>>> compute_feature_tag([3, 0, 5])
'035'

Source code in src/augernet/feature_assembly.py
def compute_feature_tag(feature_keys: Sequence[int]) -> str:
    """
    Compute a compact filename-safe tag from sorted feature keys.

    >>> compute_feature_tag([3, 0, 5])
    '035'
    """
    return ''.join(str(k) for k in sorted(feature_keys))

describe_features(feature_keys)

Return a human-readable description of the selected feature set.

>>> describe_features([0, 3, 5])
'skipatom_200 + atomic_be + e_score'

Source code in src/augernet/feature_assembly.py
def describe_features(feature_keys: Sequence[int]) -> str:
    """
    Return a human-readable description of the selected feature set.

    >>> describe_features([0, 3, 5])
    'skipatom_200 + atomic_be + e_score'
    """
    parts = []
    for key in sorted(feature_keys):
        name = FEATURE_NAMES.get(key, f'unknown_{key}')
        parts.append(name)
    return ' + '.join(parts)

get_feature_dim(data, feature_keys)

Compute the total node-feature dimension that assemble_node_features will produce (category_feature columns + selected feature columns).

Parameters:

Name Type Description Default
data Data

A single graph from the dataset (used to read tensor shapes).

required
feature_keys sequence of int

Feature keys to include.

required

Returns:

Type Description
int

Total data.x width after assembly.

Source code in src/augernet/feature_assembly.py
def get_feature_dim(data, feature_keys: Sequence[int]) -> int:
    """
    Compute the total node-feature dimension that ``assemble_node_features``
    will produce (category_feature columns + selected feature columns).

    Parameters
    ----------
    data : torch_geometric.data.Data
        A single graph from the dataset (used to read tensor shapes).
    feature_keys : sequence of int
        Feature keys to include.

    Returns
    -------
    int
        Total ``data.x`` width after assembly.
    """
    # Use stashed category_feature if available (after first assembly),
    # otherwise fall back to current data.x (before first assembly).
    base = getattr(data, '_category_feature', data.x)
    cat_dim = base.size(1) if base is not None else 0
    feat_dim = 0
    for key in feature_keys:
        attr_name = FEATURE_NAMES[key]
        tensor = getattr(data, attr_name, None)
        if tensor is None:
            raise ValueError(
                f"Feature key {key} ({FEATURE_NAMES[key]}) not found on Data object. "
                f"Available attributes: {list(FEATURE_NAMES.values())}"
            )
        if tensor.dim() == 1:
            feat_dim += 1
        else:
            feat_dim += tensor.size(1)
    return cat_dim + feat_dim
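The returned width can be cross-checked against the feature catalogue at the top of this page: three base category_feature columns plus each selected feature's dim. The dims dict below is transcribed from that catalogue, not imported from the module:

```python
# Per-key widths, transcribed from the feature catalogue above
FEATURE_DIMS = {0: 200, 1: 30, 2: 5, 3: 1, 4: 1, 5: 1, 6: 36, 7: 256}
BASE_DIM = 3  # category_feature columns kept first in data.x

def expected_feature_dim(feature_keys):
    # Mirrors get_feature_dim's arithmetic without needing a Data object.
    return BASE_DIM + sum(FEATURE_DIMS[k] for k in feature_keys)

width = expected_feature_dim([0, 3, 5])  # 3 + 200 + 1 + 1 = 205
```

This is handy for sizing the first layer of a model before any graphs are loaded.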

parse_feature_keys(tag)

Parse a compact feature-key string into a sorted list of ints.

Each character in the string is one feature key digit.

>>> parse_feature_keys('035')
[0, 3, 5]
>>> parse_feature_keys('7')
[7]

Source code in src/augernet/feature_assembly.py
def parse_feature_keys(tag: str) -> List[int]:
    """
    Parse a compact feature-key string into a sorted list of ints.

    Each character in the string is one feature key digit.

    >>> parse_feature_keys('035')
    [0, 3, 5]
    >>> parse_feature_keys('7')
    [7]
    """
    tag = str(tag).strip()
    if not tag:
        return []
    keys = sorted(int(ch) for ch in tag)
    unknown = [k for k in keys if k not in FEATURE_NAMES]
    if unknown:
        raise ValueError(
            f"Unknown feature key(s) {unknown} in '{tag}'. "
            f"Valid keys: {sorted(FEATURE_NAMES.keys())}"
        )
    return keys
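A standalone copy of the parsing and validation logic, runnable without the package (VALID_KEYS stands in for FEATURE_NAMES.keys()). Note that the compact-tag scheme only works because every key is a single digit:

```python
VALID_KEYS = set(range(8))  # keys 0-7 from the feature catalogue

def parse_feature_keys(tag):
    # One character per key; sorted so tags are order-insensitive.
    tag = str(tag).strip()
    if not tag:
        return []
    keys = sorted(int(ch) for ch in tag)
    unknown = [k for k in keys if k not in VALID_KEYS]
    if unknown:
        raise ValueError(f"Unknown feature key(s) {unknown} in '{tag}'.")
    return keys

keys = parse_feature_keys('530')  # order-insensitive: [0, 3, 5]
```

Together with compute_feature_tag, this makes tags round-trip: parse_feature_keys(compute_feature_tag(keys)) returns the sorted keys.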