"""Decoder module for the PrxteinMPNN model."""
from __future__ import annotations
from collections.abc import Callable
from functools import partial
from typing import TYPE_CHECKING
import jax
import jax.numpy as jnp
from jaxtyping import PRNGKeyArray
if TYPE_CHECKING:
from jaxtyping import Array
from prxteinmpnn.utils.gelu import GeLU
from prxteinmpnn.utils.normalize import layer_normalization
from prxteinmpnn.utils.types import (
AtomMask,
AttentionMask,
AutoRegressiveMask,
EdgeFeatures,
Logits,
Message,
ModelParameters,
NeighborIndices,
NodeEdgeFeatures,
NodeFeatures,
OneHotProteinSequence,
ProteinSequence,
SequenceEdgeFeatures,
)
from .projection import final_projection
if TYPE_CHECKING:
from jaxtyping import Int
import enum
from prxteinmpnn.utils.concatenate import concatenate_neighbor_nodes
from .dense import dense_layer
from .masked_attention import MaskedAttentionEnum, mask_attention
class DecodingEnum(enum.Enum):
    """Decoding strategies for which ``make_decoder`` can build a runner.

    ``make_decoder`` dispatches on member identity (``is``) and returns a
    differently-shaped callable for each member.
    """

    # Decode with a caller-supplied one-hot sequence and an autoregressive mask.
    CONDITIONAL = "conditional"
    # Decode from node/edge features only (optionally with an attention mask).
    UNCONDITIONAL = "unconditional"
    # Sample a sequence position by position, driven by a PRNG key.
    AUTOREGRESSIVE = "autoregressive"
# Positional-argument tuples and the matching ``Callable`` aliases for every
# function this module defines or returns. Each ``*Inputs`` tuple lists the
# parameter types in call order; the paired ``*Fn`` alias star-unpacks it.
# NOTE(review): star-unpacking a ``tuple[...]`` alias at runtime presumably
# requires Python >= 3.11 — confirm against the supported interpreter versions.

# decode_message(node_features, edge_features, layer_params)
DecodeMessageInputs = tuple[
    NodeFeatures,
    EdgeFeatures,
    ModelParameters,
]
DecodeMessageFn = Callable[[*DecodeMessageInputs], Message]

# decoder_normalize(message, node_features, mask, layer_params, scale)
DecoderNormalizeInputs = tuple[
    Message,
    NodeFeatures,
    AtomMask,
    ModelParameters,
    float,
]
DecoderNormalizeFn = Callable[[*DecoderNormalizeInputs], NodeFeatures]

# Single decoder layer with an attention mask:
# (node_features, edge_features, mask, attention_mask, layer_params, scale)
MaskedAttentionDecoderInputs = tuple[
    NodeFeatures,
    EdgeFeatures,
    AtomMask,
    AttentionMask,
    ModelParameters,
    float,
]
MaskedAttentionDecoderFn = Callable[[*MaskedAttentionDecoderInputs], NodeFeatures]

# Single decoder layer without an attention mask:
# (node_features, edge_features, mask, layer_params, scale)
DecoderInputs = tuple[
    NodeFeatures,
    EdgeFeatures,
    AtomMask,
    ModelParameters,
    float,
]
DecoderFn = Callable[[*DecoderInputs], NodeFeatures]

# Full-stack runner returned by make_decoder for UNCONDITIONAL decoding with
# MaskedAttentionEnum.NONE: (node_features, edge_features, mask)
RunDecoderInputs = tuple[NodeFeatures, EdgeFeatures, AtomMask]
RunDecoderFn = Callable[[*RunDecoderInputs], NodeFeatures]

# Full-stack runner for UNCONDITIONAL decoding with an attention mask:
# (node_features, edge_features, mask, attention_mask)
RunMaskedAttentionDecoderInputs = tuple[
    NodeFeatures,
    EdgeFeatures,
    AtomMask,
    AttentionMask,
]
RunMaskedAttentionDecoderFn = Callable[[*RunMaskedAttentionDecoderInputs], NodeFeatures]

# Full-stack runner for CONDITIONAL decoding:
# (node_features, edge_features, neighbor_indices, mask, ar_mask, sequence)
# NOTE(review): run_conditional_decoder annotates its last parameter as
# OneHotProteinSequence while this tuple lists ProteinSequence — confirm.
RunConditionalDecoderInputs = tuple[
    NodeFeatures,
    EdgeFeatures,
    NeighborIndices,
    AtomMask,
    AutoRegressiveMask,
    ProteinSequence,
]
RunConditionalDecoderFn = Callable[[*RunConditionalDecoderInputs], NodeFeatures]

# Full-stack runner for AUTOREGRESSIVE sampling:
# (prng_key, node_features, edge_features, neighbor_indices, mask, ar_mask,
#  temperature) -> (one-hot sequence, logits)
RunAutoregressiveDecoderInputs = tuple[
    PRNGKeyArray,
    NodeFeatures,
    EdgeFeatures,
    NeighborIndices,
    AtomMask,
    AutoRegressiveMask,
    float,
]
RunAutoregressiveDecoderFn = Callable[
    [*RunAutoregressiveDecoderInputs],
    tuple[OneHotProteinSequence, Logits],
]
def decoder_parameter_pytree(
    model_parameters: ModelParameters,
    num_decoder_layers: int = 3,
) -> ModelParameters:
    """Gather the per-layer decoder weights into one layer-stacked PyTree.

    Layer 0 is read from the ``protein_mpnn/~/dec_layer`` scope and layer
    ``i > 0`` from ``protein_mpnn/~/dec_layer_{i}``; inside a scope every
    parameter name carries the ``dec{i}_`` prefix.

    Args:
      model_parameters: Flat mapping from parameter path to parameter subtree.
      num_decoder_layers: Number of decoder layers to collect.

    Returns:
      A dict with keys ``W1``, ``W2``, ``W3``, ``norm1``, ``dense_W_in``,
      ``dense_W_out`` and ``norm2`` in which every leaf has been stacked along
      a new leading axis of length ``num_decoder_layers``.

    Raises:
      KeyError: If an expected parameter path is absent from
        ``model_parameters``.

    """

    def _layer_parameters(i: int) -> dict:
        """Return the parameter subtrees that belong to decoder layer ``i``."""
        prefix = "protein_mpnn/~/dec_layer" + (f"_{i}" if i > 0 else "")
        layer_name_suffix = f"dec{i}"
        return {
            "W1": model_parameters[f"{prefix}/~/{layer_name_suffix}_W1"],
            "W2": model_parameters[f"{prefix}/~/{layer_name_suffix}_W2"],
            "W3": model_parameters[f"{prefix}/~/{layer_name_suffix}_W3"],
            "norm1": model_parameters[f"{prefix}/~/{layer_name_suffix}_norm1"],
            "dense_W_in": model_parameters[
                f"{prefix}/~/position_wise_feed_forward/~/{layer_name_suffix}_dense_W_in"
            ],
            "dense_W_out": model_parameters[
                f"{prefix}/~/position_wise_feed_forward/~/{layer_name_suffix}_dense_W_out"
            ],
            "norm2": model_parameters[f"{prefix}/~/{layer_name_suffix}_norm2"],
        }

    def _stack_leaves(*leaves):  # noqa: ANN002, ANN202
        """Stack the same leaf taken from every layer along a new axis 0."""
        return jnp.stack(leaves)

    per_layer_params = [_layer_parameters(i) for i in range(num_decoder_layers)]
    return jax.tree_util.tree_map(_stack_leaves, *per_layer_params)
def embed_sequence(
    model_parameters: ModelParameters,
    one_hot_sequence: OneHotProteinSequence,
) -> NodeFeatures:
    """Project a one-hot sequence through the token-embedding matrix.

    Args:
      model_parameters: Mapping that holds the embedding matrix under
        ``"protein_mpnn/~/embed_token"`` / ``"W_s"``.
      one_hot_sequence: One-hot encoded sequence whose last axis matches the
        first axis of ``W_s``.

    Returns:
      The matrix product ``one_hot_sequence @ W_s``.

    """
    token_embedding = model_parameters["protein_mpnn/~/embed_token"]
    embedding_matrix = token_embedding["W_s"]
    embedded = one_hot_sequence @ embedding_matrix
    return embedded
@jax.jit
def initialize_conditional_decoder(
    one_hot_sequence: OneHotProteinSequence,
    node_features: NodeFeatures,
    edge_features: EdgeFeatures,
    neighbor_indices: NeighborIndices,
    layer_params: ModelParameters,
) -> tuple[NodeEdgeFeatures, SequenceEdgeFeatures]:
    """Build the two edge-feature tensors used by the conditional decoder.

    Args:
      one_hot_sequence: One-hot encoded sequence of shape
        (num_residues, num_classes).
      node_features: Node features of shape (num_atoms, num_features).
      edge_features: Edge features of shape
        (num_atoms, num_neighbors, num_features).
      neighbor_indices: Indices of neighboring nodes of shape
        (num_atoms, num_neighbors).
      layer_params: Parameters passed to ``embed_sequence``; they must contain
        the ``"protein_mpnn/~/embed_token"`` entry.

    Returns:
      ``(node_edge_features, sequence_edge_features)``. The first combines the
      edge features with an all-zero node tensor and then with the real node
      features; the second combines the edge features with the embedded
      sequence.

    """

    def _attach_neighbors(nodes, edges):  # noqa: ANN001, ANN202
        """Combine ``nodes`` with ``edges`` using the shared neighbor indices."""
        return concatenate_neighbor_nodes(nodes, edges, neighbor_indices)

    # Sequence branch: embedded tokens attached to the raw edge features.
    sequence_embedding = embed_sequence(layer_params, one_hot_sequence)
    sequence_edge_features = _attach_neighbors(sequence_embedding, edge_features)

    # Structure branch: a zero placeholder is attached first, then the real
    # node features are attached on top of that intermediate result.
    zero_nodes = jnp.zeros_like(node_features)
    zero_padded_edges = _attach_neighbors(zero_nodes, edge_features)
    node_edge_features = _attach_neighbors(node_features, zero_padded_edges)

    return node_edge_features, sequence_edge_features
@jax.jit
def decode_message(
    node_features: NodeFeatures,
    edge_features: EdgeFeatures,
    layer_params: ModelParameters,
) -> Message:
    """Compute per-edge messages with the decoder's three-layer MLP.

    Args:
      node_features: Node features of shape (num_atoms, num_features).
      edge_features: Edge features of shape
        (num_atoms, num_neighbors, num_features).
      layer_params: Parameters of one decoder layer; the ``"W1"``, ``"W2"``
        and ``"W3"`` entries each provide a weight ``"w"`` and a bias ``"b"``.

    Returns:
      Messages of shape (num_atoms, num_neighbors, num_features).

    """
    # Repeat every node along the neighbor axis so it can be concatenated
    # with that node's edge features.
    num_neighbors = edge_features.shape[-2]
    tiled_nodes = jnp.tile(
        jnp.expand_dims(node_features, -2),
        [1, num_neighbors, 1],
    )
    hidden = jnp.concatenate([tiled_nodes, edge_features], -1)

    # Two GeLU-activated affine layers ...
    for dense_name in ("W1", "W2"):
        dense = layer_params[dense_name]
        hidden = GeLU(jnp.dot(hidden, dense["w"]) + dense["b"])

    # ... followed by a final affine layer without activation.
    output_dense = layer_params["W3"]
    return jnp.dot(hidden, output_dense["w"]) + output_dense["b"]
@partial(jax.jit, static_argnames=("scale",))
def decoder_normalize(
    message: Message,
    node_features: NodeFeatures,
    mask: AtomMask,
    layer_params: ModelParameters,
    scale: float = 30.0,
) -> NodeFeatures:
    """Fold aggregated messages into the node features and normalize them.

    Args:
      message: Decoded messages of shape
        (num_atoms, num_neighbors, num_features).
      node_features: Node features of shape (num_atoms, num_features).
      mask: Atom mask; multiplied into the result so masked rows become zero.
      layer_params: Parameters of one decoder layer (``"norm1"``, ``"norm2"``
        and whatever ``dense_layer`` reads).
      scale: Divisor applied to the neighbor-summed messages. Treated as a
        static argument by ``jax.jit``.

    Returns:
      Updated, masked node features.

    """
    # Residual update with the scaled sum of messages over the neighbor axis.
    aggregated_message = jnp.sum(message, -2) / scale
    residual = node_features + aggregated_message
    normalized = layer_normalization(residual, layer_params["norm1"])

    # Position-wise feed-forward block, again with a residual connection.
    feed_forward_out = dense_layer(layer_params, normalized)
    normalized = layer_normalization(normalized + feed_forward_out, layer_params["norm2"])

    return mask[:, None] * normalized
def make_decode_layer(
    attention_mask_enum: MaskedAttentionEnum,
) -> MaskedAttentionDecoderFn | DecoderFn:
    """Build the jitted function that applies a single decoder layer.

    Args:
      attention_mask_enum: Attention-masking mode. ``NONE`` and ``CROSS``
        select the layer without an attention mask; every other member selects
        the layer that masks the messages before aggregation.

    Returns:
      For ``NONE``/``CROSS`` a function
      ``(node_features, edge_features, mask, layer_params, scale=30.0)``;
      otherwise a function ``(node_features, edge_features, mask,
      attention_mask, layer_params, scale=30.0)``. Both return the updated
      node features, and ``scale`` is a static jit argument in both.

    """
    uses_attention_mask = (
        attention_mask_enum is not MaskedAttentionEnum.NONE
        and attention_mask_enum is not MaskedAttentionEnum.CROSS
    )

    if uses_attention_mask:

        @partial(jax.jit, static_argnames=("scale",))
        def masked_attn_decoder_fn(
            node_features: NodeFeatures,
            edge_features: EdgeFeatures,
            mask: AtomMask,
            attention_mask: AttentionMask,
            layer_params: ModelParameters,
            scale: float = 30.0,
        ) -> NodeFeatures:
            """Apply one decoder layer, masking the messages with ``attention_mask``."""
            raw_message = decode_message(node_features, edge_features, layer_params)
            masked_message = mask_attention(raw_message, attention_mask)
            return decoder_normalize(masked_message, node_features, mask, layer_params, scale)

        return masked_attn_decoder_fn

    @partial(jax.jit, static_argnames=("scale",))
    def decoder_fn(
        node_features: NodeFeatures,
        edge_features: EdgeFeatures,
        mask: AtomMask,
        layer_params: ModelParameters,
        scale: float = 30.0,
    ) -> NodeFeatures:
        """Apply one decoder layer without any attention masking."""
        raw_message = decode_message(node_features, edge_features, layer_params)
        return decoder_normalize(raw_message, node_features, mask, layer_params, scale)

    return decoder_fn
def setup_decoder(
    model_parameters: ModelParameters,
    attention_mask_enum: MaskedAttentionEnum,
    decoding_enum: DecodingEnum,
    num_decoder_layers: int = 3,
) -> tuple[ModelParameters, Callable[..., NodeFeatures]]:
    """Stack the decoder parameters and pick the matching single-layer function.

    Args:
      model_parameters: Flat mapping of all model parameters.
      attention_mask_enum: Attention-masking mode requested by the caller. It
        is only honoured for ``DecodingEnum.UNCONDITIONAL``.
      decoding_enum: Decoding strategy the layer function will be used with.
      num_decoder_layers: Number of decoder layers to stack.

    Returns:
      ``(all_decoder_layer_params, decode_layer_fn)`` where the first item is
      the layer-stacked parameter PyTree and the second applies one layer.

    """
    all_decoder_layer_params = decoder_parameter_pytree(model_parameters, num_decoder_layers)

    # Both the conditional and the autoregressive runners built by
    # ``make_decoder`` pass an attention mask to every layer call, so they need
    # the masked-attention layer whatever ``attention_mask_enum`` says.
    # Previously only CONDITIONAL was forced; AUTOREGRESSIVE combined with
    # MaskedAttentionEnum.NONE/CROSS received the unmasked layer and failed with
    # a TypeError on the unexpected ``attention_mask`` keyword.
    always_masked = (
        decoding_enum is DecodingEnum.CONDITIONAL or decoding_enum is DecodingEnum.AUTOREGRESSIVE
    )
    layer_mask_enum = MaskedAttentionEnum.CONDITIONAL if always_masked else attention_mask_enum

    decode_layer_fn = make_decode_layer(attention_mask_enum=layer_mask_enum)
    return all_decoder_layer_params, decode_layer_fn
def _check_enums(
    attention_mask_enum: MaskedAttentionEnum,
    decoding_enum: DecodingEnum,
) -> None:
    """Validate the types of both configuration enums.

    Args:
      attention_mask_enum: Must be a ``MaskedAttentionEnum`` member.
      decoding_enum: Must be a ``DecodingEnum`` member.

    Raises:
      TypeError: If either argument is not an instance of its enum class. The
        attention-mask argument is checked first.

    """
    attention_mask_ok = isinstance(attention_mask_enum, MaskedAttentionEnum)
    if not attention_mask_ok:
        msg = f"Unknown attention mask enum: {attention_mask_enum}"
        raise TypeError(msg)

    if isinstance(decoding_enum, DecodingEnum):
        return

    msg = f"Unknown decoding enum: {decoding_enum}"
    raise TypeError(msg)
def make_decoder(
    model_parameters: ModelParameters,
    attention_mask_enum: MaskedAttentionEnum,
    decoding_enum: DecodingEnum = DecodingEnum.UNCONDITIONAL,
    num_decoder_layers: int = 3,
    scale: float = 30.0,
) -> (
    RunDecoderFn | RunMaskedAttentionDecoderFn | RunAutoregressiveDecoderFn | RunConditionalDecoderFn
):
    """Build the jitted function that runs the whole decoder stack.

    The returned callable closes over the stacked layer parameters, the
    single-layer function chosen by ``setup_decoder``, ``num_decoder_layers``
    and ``scale``.

    Args:
      model_parameters: Flat mapping of all model parameters.
      attention_mask_enum: Attention-masking mode; forwarded to
        ``setup_decoder`` and used to pick the UNCONDITIONAL runner.
      decoding_enum: Decoding strategy; selects which runner is returned.
      num_decoder_layers: Number of decoder layers to apply.
      scale: Message scaling factor forwarded to every layer call.

    Returns:
      Depending on ``decoding_enum``:

      * AUTOREGRESSIVE: ``(prng_key, node_features, edge_features,
        neighbor_indices, mask, ar_mask, temperature=1.0) -> (sequence, logits)``
      * CONDITIONAL: ``(node_features, edge_features, neighbor_indices, mask,
        ar_mask, one_hot_sequence) -> node_features``
      * UNCONDITIONAL with ``MaskedAttentionEnum.NONE``:
        ``(node_features, edge_features, mask) -> node_features``
      * UNCONDITIONAL otherwise:
        ``(node_features, edge_features, mask, attention_mask) -> node_features``

    Raises:
      TypeError: If either enum argument has the wrong type (raised by
        ``_check_enums``).
      ValueError: If ``decoding_enum`` matches none of the branches below.

    """
    _check_enums(
        attention_mask_enum,
        decoding_enum,
    )
    all_decoder_layer_params, decode_layer_fn = setup_decoder(
        model_parameters,
        attention_mask_enum,
        decoding_enum,
        num_decoder_layers,
    )
    if decoding_enum is DecodingEnum.AUTOREGRESSIVE:

        @jax.jit
        def run_autoregressive_decoder(
            prng_key: PRNGKeyArray,
            node_features: NodeFeatures,
            edge_features: EdgeFeatures,
            neighbor_indices: NeighborIndices,
            mask: AtomMask,
            ar_mask: AutoRegressiveMask,
            temperature: float = 1.0,
        ) -> tuple[OneHotProteinSequence, Logits]:
            """Sample a sequence one position at a time with ``jax.lax.scan``.

            Each scan step recomputes every decoder layer for a single
            position, samples a residue there and writes its embedding back
            into the running sequence state.

            Returns:
              ``(final_sequence, final_all_logits)``: the one-hot samples
              re-ordered into residue order, and the per-position values
              stored during sampling (see the review note where they are
              written).

            """
            # Per-neighbor slice of the autoregressive mask.
            attention_mask = jnp.take_along_axis(ar_mask, neighbor_indices, axis=1)
            mask_1d = mask[:, None]
            # mask_bw keeps the neighbors selected by attention_mask; mask_fw
            # keeps the complementary ones.
            mask_bw = mask_1d * attention_mask
            mask_fw = mask_1d * (1 - attention_mask)
            # Positions are visited in ascending order of their ar_mask row sum
            # (presumably positions with fewer visible predecessors go first —
            # TODO confirm against how ar_mask is constructed).
            decoding_order = jnp.argsort(jnp.sum(ar_mask, axis=1))
            # Sequence-independent context: edge features combined with a zero
            # placeholder and then with the encoder node features, kept only
            # where mask_fw is set.
            context_edge_features = concatenate_neighbor_nodes(
                jnp.zeros_like(node_features),
                edge_features,
                neighbor_indices,
            )
            fixed_context_features = concatenate_neighbor_nodes(
                node_features,
                context_edge_features,
                neighbor_indices,
            )
            fixed_context_features = mask_fw[..., None] * fixed_context_features

            def autoregressive_step(
                carry: tuple[NodeFeatures, NodeFeatures, Logits],
                scan_inputs: tuple[Array, PRNGKeyArray],
            ) -> tuple[
                tuple[NodeFeatures, NodeFeatures, Logits],
                OneHotProteinSequence,  # one-hot residue sampled at this step
            ]:
                """Decode and sample a single position.

                ``carry`` holds the node features of every layer, the embedded
                sequence sampled so far and the stored logits; ``scan_inputs``
                is the position to decode and the PRNG key for this step.
                """
                all_layers_node_features, embedded_sequence_state, all_logits = carry
                position, key = scan_inputs
                # Slice out everything that concerns only this position.
                fixed_context_features_position = fixed_context_features[position]
                position_neighborhood_indices = neighbor_indices[position]
                mask_position = mask[position]
                mask_bw_position = mask_bw[position]
                edge_features_position = edge_features[position]
                # Embeddings of the neighbors as sampled so far (zeros for
                # positions that have not been written yet).
                neighbor_sequence_features = embedded_sequence_state[position_neighborhood_indices]
                sequence_edge_features_position = concatenate_neighbor_nodes(
                    neighbor_sequence_features,
                    edge_features_position,
                    position_neighborhood_indices,
                )

                def decoder_layer_loop(
                    layer_num: int,
                    loop_all_layers_features: NodeFeatures,
                ) -> NodeFeatures:
                    """Apply layer ``layer_num`` at ``position`` and store the result in slot ``layer_num + 1``."""
                    layer_input_node_features = loop_all_layers_features[layer_num]
                    decoder_input_features_position = concatenate_neighbor_nodes(
                        layer_input_node_features,
                        sequence_edge_features_position,
                        position_neighborhood_indices,
                    )
                    # mask_bw-selected neighbors contribute sequence-aware
                    # features; the rest come from the fixed context.
                    total_input_features_position = (
                        mask_bw_position[..., None] * decoder_input_features_position
                    ) + fixed_context_features_position
                    layer_params = jax.tree_util.tree_map(lambda x: x[layer_num], all_decoder_layer_params)
                    # Add a leading axis of size 1 so the layer function sees
                    # a one-element batch of positions.
                    nf_batch = jnp.expand_dims(layer_input_node_features[position], 0)
                    ef_batch = jnp.expand_dims(total_input_features_position, 0)
                    mask_batch = jnp.expand_dims(mask_position, 0)
                    att_mask_batch = jnp.expand_dims(mask_bw_position, 0)
                    # decode_layer_fn must accept an ``attention_mask`` keyword
                    # here, i.e. be the masked-attention layer variant.
                    updated_node_features_position = decode_layer_fn(
                        nf_batch,
                        ef_batch,
                        mask=mask_batch,
                        layer_params=layer_params,
                        attention_mask=att_mask_batch,
                        scale=scale,
                    )
                    updated_node_features_position = updated_node_features_position.squeeze(0)
                    return loop_all_layers_features.at[layer_num + 1, position].set(
                        updated_node_features_position,
                    )

                final_all_layers_node_features = jax.lax.fori_loop(
                    0,
                    num_decoder_layers,
                    decoder_layer_loop,
                    all_layers_node_features,
                )
                # The last slot holds the output of the final layer.
                final_node_features_position = final_all_layers_node_features[-1, position]
                logits_position = final_projection(model_parameters, final_node_features_position)
                # Temperature-scaled logits plus Gumbel noise; the argmax below
                # then draws the sample.
                sampled_logits = logits_position / temperature + jax.random.gumbel(
                    key,
                    logits_position.shape,
                )
                # Only the first 20 classes can be sampled; the one-hot vector
                # still has 21 entries.
                sampled_idx = sampled_logits[..., :20].argmax(-1)
                sequence_position = jax.nn.one_hot(sampled_idx, num_classes=21)
                embedded_sequence_position = embed_sequence(model_parameters, sequence_position)
                next_embedded_sequence_state = embedded_sequence_state.at[position].set(
                    embedded_sequence_position,
                )
                # NOTE(review): this stores the temperature-scaled,
                # Gumbel-perturbed values, not ``logits_position`` — confirm
                # that callers expect noised logits.
                all_logits = all_logits.at[position].set(sampled_logits)
                next_carry = (final_all_layers_node_features, next_embedded_sequence_state, all_logits)
                outputs_position_collect = sequence_position
                return next_carry, outputs_position_collect

            num_residues = node_features.shape[0]
            # Slot 0 holds the encoder output; slots 1..num_decoder_layers are
            # filled in position by position during the scan.
            initial_all_layers_node_features = jnp.array(
                [node_features] + [jnp.zeros_like(node_features)] * num_decoder_layers,
            )
            # A throw-away embedding is computed only to obtain the shape and
            # dtype of the all-zero initial sequence state.
            initial_embedded_sequence_state = jnp.zeros_like(
                embed_sequence(
                    model_parameters,
                    jax.nn.one_hot(jnp.zeros(num_residues, dtype=jnp.int32), num_classes=21),
                ),
            )
            all_logits = jnp.zeros((num_residues, 21), dtype=jnp.float32)
            initial_carry = (
                initial_all_layers_node_features,
                initial_embedded_sequence_state,
                all_logits,
            )
            # One PRNG key per decoded position.
            scan_inputs = (decoding_order, jax.random.split(prng_key, num_residues))
            final_carry, sequence_in_decoding_order = jax.lax.scan(
                autoregressive_step,
                initial_carry,
                scan_inputs,
            )
            final_all_logits = final_carry[2]
            # The scan emits samples in decoding order; scatter them back so
            # that row i corresponds to residue i.
            final_sequence = jnp.zeros_like(sequence_in_decoding_order)
            final_sequence = final_sequence.at[decoding_order].set(sequence_in_decoding_order)
            return final_sequence, final_all_logits

        return run_autoregressive_decoder
    if decoding_enum is DecodingEnum.CONDITIONAL:

        @jax.jit
        def run_conditional_decoder(
            node_features: NodeFeatures,
            edge_features: EdgeFeatures,
            neighbor_indices: NeighborIndices,
            mask: AtomMask,
            ar_mask: AutoRegressiveMask,
            one_hot_sequence: OneHotProteinSequence,
        ) -> NodeFeatures:
            """Run every decoder layer conditioned on a given one-hot sequence.

            Returns the node features produced by the last layer.
            """
            node_edge_features, sequence_edge_features = initialize_conditional_decoder(
                one_hot_sequence,
                node_features,
                edge_features,
                neighbor_indices,
                model_parameters,
            )
            # Per-neighbor slice of the autoregressive mask.
            attention_mask = jnp.take_along_axis(
                ar_mask,
                neighbor_indices,
                axis=1,
            )
            mask_bw = mask[:, None] * attention_mask
            mask_fw = mask[:, None] * (1 - attention_mask)
            # Sequence-independent part of the edge input; constant across layers.
            masked_node_edge_features = mask_fw[..., None] * node_edge_features

            def decoder_loop_body(
                i: Int,
                carry: NodeFeatures,
            ) -> NodeFeatures:
                """Apply decoder layer ``i`` to the carried node features."""
                loop_node_features = carry
                current_layer_params = jax.tree_util.tree_map(lambda x: x[i], all_decoder_layer_params)
                current_features = concatenate_neighbor_nodes(
                    loop_node_features,
                    sequence_edge_features,
                    neighbor_indices,
                )
                # mask_bw-selected neighbors use the sequence-aware features;
                # the others use the fixed structural features.
                loop_edge_features = (mask_bw[..., None] * current_features) + masked_node_edge_features
                return decode_layer_fn(
                    loop_node_features,
                    loop_edge_features,
                    mask,
                    attention_mask,
                    current_layer_params,
                    scale,
                )

            return jax.lax.fori_loop(
                0,
                num_decoder_layers,
                decoder_loop_body,
                node_features,
            )

        return run_conditional_decoder
    if decoding_enum is DecodingEnum.UNCONDITIONAL:
        if attention_mask_enum is MaskedAttentionEnum.NONE:

            @jax.jit
            def run_decoder(
                node_features: NodeFeatures,
                edge_features: EdgeFeatures,
                mask: AtomMask,
            ) -> NodeFeatures:
                """Run every decoder layer without sequence input or attention mask."""
                # Repeat each node along the neighbor axis.
                nodes_expanded = jnp.tile(
                    jnp.expand_dims(node_features, -2),
                    [1, edge_features.shape[1], 1],
                )
                # All-zero block of the same shape, placed between the node and
                # edge features (presumably where the sequence features go in
                # the conditional decoders — TODO confirm).
                zeros_expanded = jnp.tile(
                    jnp.expand_dims(jnp.zeros_like(node_features), -2),
                    [1, edge_features.shape[1], 1],
                )
                # Built once from the initial node features and reused
                # unchanged by every layer.
                decoder_input_features = jnp.concatenate(
                    [nodes_expanded, zeros_expanded, edge_features],
                    -1,
                )

                def decoder_loop_body(
                    i: Int,
                    carry: NodeFeatures,
                ) -> NodeFeatures:
                    """Apply decoder layer ``i`` to the carried node features."""
                    loop_node_features = carry
                    current_layer_params = jax.tree_util.tree_map(lambda x: x[i], all_decoder_layer_params)
                    return decode_layer_fn(
                        loop_node_features,
                        decoder_input_features,
                        mask,
                        current_layer_params,
                        scale,
                    )

                return jax.lax.fori_loop(
                    0,
                    num_decoder_layers,
                    decoder_loop_body,
                    node_features,
                )

            return run_decoder

        @jax.jit
        def run_masked_attention_decoder(
            node_features: NodeFeatures,
            edge_features: EdgeFeatures,
            mask: AtomMask,
            attention_mask: AttentionMask,
        ) -> NodeFeatures:
            """Run every decoder layer with a caller-supplied attention mask."""

            def decoder_loop_body(
                i: Int,
                carry: NodeFeatures,
            ) -> NodeFeatures:
                """Apply decoder layer ``i`` to the carried node features."""
                loop_node_features = carry
                current_layer_params = jax.tree_util.tree_map(lambda x: x[i], all_decoder_layer_params)
                # Six positional arguments: this needs the masked-attention
                # layer signature.
                # NOTE(review): make_decode_layer returns the five-argument
                # layer for MaskedAttentionEnum.CROSS, so CROSS together with
                # UNCONDITIONAL would mis-bind these arguments — confirm the
                # intended pairing.
                return decode_layer_fn(
                    loop_node_features,
                    edge_features,
                    mask,
                    attention_mask,
                    current_layer_params,
                    scale,
                )

            return jax.lax.fori_loop(
                0,
                num_decoder_layers,
                decoder_loop_body,
                node_features,
            )

        return run_masked_attention_decoder
    # Defensive fallback for a DecodingEnum member not handled above.
    msg = f"Unknown decoding enum: {decoding_enum}"
    raise ValueError(msg)