# Source code for prxteinmpnn.utils.concatenate
"""Concatenation utilities.
prxteinmpnn.utils.concatenate
"""
import jax
import jax.numpy as jnp
from .types import EdgeFeatures, NeighborIndices, NodeFeatures
# [docs]
@jax.jit
def concatenate_neighbor_nodes(
    node_features: NodeFeatures,
    edge_features: EdgeFeatures,
    neighbor_indices: NeighborIndices,
) -> EdgeFeatures:
    """Gather neighbor node features and append them to the edge features.

    For every node ``i`` and neighbor slot ``k``, the output vector is
    ``edge_features[i, k]`` followed by
    ``node_features[neighbor_indices[i, k]]``. Edge channels come FIRST,
    gathered node channels second.

    Leading batch dimensions shared by all three inputs are supported: the
    gather is applied independently to each batch element. With no batch
    dimensions this is exactly ``node_features[neighbor_indices]``.

    Args:
      node_features: (..., L, C_V) node features.
      edge_features: (..., L, K, C_E) edge features.
      neighbor_indices: (..., L, K) integer indices into the node axis (L)
        of ``node_features``. Indices are not range-checked here; JAX's
        indexing semantics for out-of-range values apply.

    Returns:
      (..., L, K, C_E + C_V) array: ``edge_features`` concatenated with the
      gathered neighbor node features along the last axis.

    """

    def gather(nodes, indices):
        # Unbatched gather: (L, C_V) indexed by (L, K) -> (L, K, C_V).
        return nodes[indices]

    # Lift the gather over each leading batch axis. ``ndim`` is static under
    # jit, so this loop runs once at trace time; with no batch axes the loop
    # body never executes and the original plain indexing is used.
    for _ in range(neighbor_indices.ndim - 2):
        gather = jax.vmap(gather)

    neighbor_features = gather(node_features, neighbor_indices)  # (..., L, K, C_V)
    return jnp.concatenate([edge_features, neighbor_features], axis=-1)