Source code for prxteinmpnn.model.projection
"""Final projection layer for the PrxteinMPNN model."""
import jax
import jax.numpy as jnp
from prxteinmpnn.utils.types import Logits, ModelParameters, NodeFeatures
[docs]
@jax.jit
def final_projection(
    model_parameters: ModelParameters,
    node_features: NodeFeatures,
) -> Logits:
    """Project per-node features onto the model's output logits.

    Applies the affine map stored under the ``"protein_mpnn/~/W_out"`` entry
    of ``model_parameters``, i.e. ``jnp.dot(node_features, w) + b``.

    Args:
        model_parameters: Parameter mapping that must contain the key
            ``"protein_mpnn/~/W_out"``, whose value holds a weight under
            ``"w"`` and a bias under ``"b"``.
        node_features: Node features produced by the last MPNN layer.

    Returns:
        Logits: ``jnp.dot(node_features, w) + b``.
    """
    # Fetch the output-layer sub-dict once and read both of its entries.
    output_layer = model_parameters["protein_mpnn/~/W_out"]
    weights = output_layer["w"]
    bias = output_layer["b"]

    # jnp.dot (rather than `@`) is kept deliberately: for N-D node_features it
    # sums over their last axis. The bias is then broadcast-added.
    projected = jnp.dot(node_features, weights)
    return projected + bias