projection
Final projection layer for the PrxteinMPNN model.
- prxteinmpnn.model.projection.final_projection(model_parameters, node_features)
Convert node features to logits.
- Parameters:
model_parameters (PyTree[str, 'P']) – Model parameters for the final projection.
node_features (Int[Array, 'num_atoms num_features']) – Node features after the last MPNN layer.
- Returns:
The final logits for the model.
- Return type:
Logits
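The page does not show how the conversion is done. The sketch below illustrates what a final projection of this kind typically looks like. It assumes a single affine (dense) layer applied to each node independently, similar to the output layer of the original ProteinMPNN. The function name `final_projection_sketch`, the parameter keys `"w"` and `"b"`, the choice of 21 output classes, and the use of floating-point features are all illustrative assumptions. They do not describe PrxteinMPNN's actual parameter layout.

```python
import jax
import jax.numpy as jnp


def final_projection_sketch(params: dict, node_features: jax.Array) -> jax.Array:
    """Hypothetical dense projection: (num_atoms, num_features) -> (num_atoms, num_classes).

    Assumes `params` holds a weight matrix "w" of shape (num_features, num_classes)
    and a bias "b" of shape (num_classes,). These key names are illustrative only.
    """
    # One matrix multiply maps every node's feature vector to a row of logits;
    # the bias is broadcast across all nodes.
    return node_features @ params["w"] + params["b"]


# Usage with made-up shapes: 5 nodes, 128 features, 21 output classes (assumed).
key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
num_atoms, num_features, num_classes = 5, 128, 21
params = {
    "w": jax.random.normal(key_w, (num_features, num_classes)) * 0.01,
    "b": jnp.zeros((num_classes,)),
}
node_features = jax.random.normal(key_x, (num_atoms, num_features))

logits = final_projection_sketch(params, node_features)
# Softmax over the last axis turns each node's logits into a probability distribution.
probs = jax.nn.softmax(logits, axis=-1)
```

Because the projection acts row by row, each node's logits depend only on that node's final features. All interaction between nodes is assumed to have already happened in the preceding MPNN layers.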