# Source code for prxteinmpnn.model.dense

"""Dense layer implementation for ProteinMPNN."""

import jax
import jax.numpy as jnp

from prxteinmpnn.utils.gelu import GeLU
from prxteinmpnn.utils.types import (
  ModelParameters,
  NodeFeatures,
)


@jax.jit
def dense_layer(layer_parameters: ModelParameters, node_features: NodeFeatures) -> NodeFeatures:
  """Apply the two-affine-map feed-forward (dense) block to node features.

  Computes ``GeLU(node_features · W_in + b_in) · W_out + b_out``. Each ``·``
  is a ``jnp.dot`` contraction, and ``GeLU`` is the activation imported from
  ``prxteinmpnn.utils.gelu``. The function is compiled with ``jax.jit``.

  Args:
    layer_parameters: Mapping with the sub-mappings ``"dense_W_in"`` and
      ``"dense_W_out"``, each providing a weight under ``"w"`` and a bias
      under ``"b"``.
    node_features: Input features. Presumably shaped (..., features) so the
      last axis contracts with ``W_in``; TODO confirm against callers.

  Returns:
    The result of the second affine map applied to the activated hidden
    representation.
  """
  input_proj = layer_parameters["dense_W_in"]
  output_proj = layer_parameters["dense_W_out"]

  # First affine map (W_in, b_in), then the GeLU nonlinearity.
  hidden = GeLU(jnp.dot(node_features, input_proj["w"]) + input_proj["b"])

  # Second affine map (W_out, b_out); no activation on the output.
  return jnp.dot(hidden, output_proj["w"]) + output_proj["b"]