projection

Final projection layer for the PrxteinMPNN model.

prxteinmpnn.model.projection.final_projection(model_parameters, node_features)

Convert node features to logits.

Parameters:
  • model_parameters (PyTree[str, 'P']) – Model parameters for the final projection.

  • node_features (Float[Array, 'num_atoms num_features']) – Node features after the last MPNN layer.

Returns:

The final logits for the model.

Return type:

Logits
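A minimal sketch of what a final projection of this kind typically does, assuming it is a single dense (affine) layer applied independently to each node's feature vector. The helper name final_projection_sketch, the parameter keys "w" and "b", and the 21-class output size are hypothetical and are not taken from PrxteinMPNN's actual parameter layout.

```python
import jax
import jax.numpy as jnp


def final_projection_sketch(model_parameters, node_features):
    """Hypothetical dense projection from node features to logits.

    Assumes model_parameters is a dict holding a weight matrix under "w"
    (shape [num_features, num_classes]) and a bias under "b"
    (shape [num_classes]). These key names are illustrative only.
    """
    w = model_parameters["w"]
    b = model_parameters["b"]
    # The same affine map is applied to every node (row) of node_features.
    return node_features @ w + b


# Illustrative shapes (assumed): 5 nodes, 128 features, 21 output classes.
key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
params = {
    "w": jax.random.normal(key_w, (128, 21)) * 0.01,
    "b": jnp.zeros((21,)),
}
node_features = jax.random.normal(key_x, (5, 128))

logits = final_projection_sketch(params, node_features)
# Softmax over the last axis turns each node's logits into a distribution.
probs = jax.nn.softmax(logits, axis=-1)
```

Because the projection acts row-wise, the output keeps one row of logits per node, and any normalization (such as the softmax above) is left to the caller.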