openfold.model.evoformer

Classes

EvoformerBlock(c_m, c_z, c_hidden_msa_att, ...)

EvoformerStack(c_m, c_z, c_hidden_msa_att, ...)

Main Evoformer trunk.

ExtraMSABlock(c_m, c_z, c_hidden_msa_att, ...)

Almost identical to the standard EvoformerBlock, except that ExtraMSABlock uses GlobalAttention for MSA column attention and requires more fine-grained control over checkpointing.

ExtraMSAStack(c_m, c_z, c_hidden_msa_att, ...)

Implements Algorithm 18.

MSABlock(c_m, c_z, c_hidden_msa_att, ...)

MSATransition(c_m, n)

Feed-forward network applied to MSA activations after attention.

PairStack(c_z, c_hidden_mul, ...)

class EvoformerBlock(c_m, c_z, c_hidden_msa_att, c_hidden_opm, c_hidden_mul, c_hidden_pair_att, no_heads_msa, no_heads_pair, transition_n, msa_dropout, pair_dropout, no_column_attention, opm_first, fuse_projection_weights, inf, eps)

Bases: MSABlock

Parameters:
  • c_m (int)

  • c_z (int)

  • c_hidden_msa_att (int)

  • c_hidden_opm (int)

  • c_hidden_mul (int)

  • c_hidden_pair_att (int)

  • no_heads_msa (int)

  • no_heads_pair (int)

  • transition_n (int)

  • msa_dropout (float)

  • pair_dropout (float)

  • no_column_attention (bool)

  • opm_first (bool)

  • fuse_projection_weights (bool)

  • inf (float)

  • eps (float)
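The parameter list does not explain `opm_first` or `no_column_attention`. Assuming they do what their names suggest (move the outer product mean to the start of the block, and skip MSA column attention), and that the default ordering follows Algorithm 6 of the AlphaFold 2 supplement, one block pass would run its sub-steps in the order returned by the hypothetical helper below. The helper and its step labels are illustrative only, not OpenFold attribute names.

```python
from typing import List


def evoformer_block_steps(opm_first: bool, no_column_attention: bool) -> List[str]:
    """Return the assumed order of sub-steps in one EvoformerBlock pass.

    Hypothetical helper for illustration; labels are descriptive, not
    OpenFold module names.
    """
    steps: List[str] = []
    if opm_first:
        # Assumption: opm_first=True moves the outer product mean to the front.
        steps.append("outer_product_mean")
    steps.append("msa_row_attention_with_pair_bias")
    if not no_column_attention:
        # Assumption: no_column_attention=True drops this step entirely.
        steps.append("msa_column_attention")
    steps.append("msa_transition")
    if not opm_first:
        # Algorithm 6 ordering: outer product mean after the MSA updates.
        steps.append("outer_product_mean")
    # Pair-stack updates, in Algorithm 6 order.
    steps += [
        "triangle_multiplication_outgoing",
        "triangle_multiplication_incoming",
        "triangle_attention_starting_node",
        "triangle_attention_ending_node",
        "pair_transition",
    ]
    return steps


print(evoformer_block_steps(opm_first=False, no_column_attention=False))
```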

forward(m, z, msa_mask, pair_mask, chunk_size=None, use_deepspeed_evo_attention=False, use_cuequivariance_attention=False, use_cuequivariance_multiplicative_update=False, use_lma=False, use_flash=False, inplace_safe=False, _mask_trans=True, _attn_chunk_size=None, _offload_inference=False, _offloadable_inputs=None)
Parameters:
  • m (Tensor | None)

  • z (Tensor | None)

  • msa_mask (Tensor)

  • pair_mask (Tensor)

  • chunk_size (int | None)

  • use_deepspeed_evo_attention (bool)

  • use_cuequivariance_attention (bool)

  • use_cuequivariance_multiplicative_update (bool)

  • use_lma (bool)

  • use_flash (bool)

  • inplace_safe (bool)

  • _mask_trans (bool)

  • _attn_chunk_size (int | None)

  • _offload_inference (bool)

  • _offloadable_inputs (Sequence[Tensor] | None)

Return type:

Tuple[Tensor, Tensor]

class EvoformerStack(c_m, c_z, c_hidden_msa_att, c_hidden_opm, c_hidden_mul, c_hidden_pair_att, c_s, no_heads_msa, no_heads_pair, no_blocks, transition_n, msa_dropout, pair_dropout, no_column_attention, opm_first, fuse_projection_weights, blocks_per_ckpt, inf, eps, clear_cache_between_blocks=False, tune_chunk_size=False, **kwargs)

Bases: Module

Main Evoformer trunk.

Implements Algorithm 6.
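Algorithm 6 of the AlphaFold 2 supplement applies the Evoformer blocks in sequence and then derives the single representation `s` by a linear projection of the first MSA row (the target sequence). A minimal NumPy sketch of that data flow, with identity functions standing in for EvoformerBlock and random projection weights; the function name and the `w_s`/`b_s` arguments are hypothetical, and the real module is a PyTorch `nn.Module`.

```python
import numpy as np


def evoformer_stack_sketch(m, z, blocks, w_s, b_s):
    """Schematic trunk: run the blocks, then project the first MSA row.

    m: [*, N_seq, N_res, C_m]; z: [*, N_res, N_res, C_z]
    blocks: callables (m, z) -> (m, z) standing in for EvoformerBlock
    w_s, b_s: weights of a hypothetical Linear(C_m -> C_s)
    """
    for block in blocks:
        m, z = block(m, z)
    # Single representation from the first MSA row: [*, N_res, C_s].
    s = m[..., 0, :, :] @ w_s + b_s
    return m, z, s


def identity_block(m, z):
    """Placeholder standing in for a real EvoformerBlock."""
    return m, z


rng = np.random.default_rng(0)
n_seq, n_res, c_m, c_z, c_s = 4, 6, 8, 5, 3
m = rng.standard_normal((n_seq, n_res, c_m))
z = rng.standard_normal((n_res, n_res, c_z))
w_s = rng.standard_normal((c_m, c_s))
b_s = np.zeros(c_s)
m_out, z_out, s = evoformer_stack_sketch(m, z, [identity_block] * 2, w_s, b_s)
print(m_out.shape, z_out.shape, s.shape)  # (4, 6, 8) (6, 6, 5) (6, 3)
```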

Parameters:
  • c_m (int)

  • c_z (int)

  • c_hidden_msa_att (int)

  • c_hidden_opm (int)

  • c_hidden_mul (int)

  • c_hidden_pair_att (int)

  • c_s (int)

  • no_heads_msa (int)

  • no_heads_pair (int)

  • no_blocks (int)

  • transition_n (int)

  • msa_dropout (float)

  • pair_dropout (float)

  • no_column_attention (bool)

  • opm_first (bool)

  • fuse_projection_weights (bool)

  • blocks_per_ckpt (int)

  • inf (float)

  • eps (float)

  • clear_cache_between_blocks (bool)

  • tune_chunk_size (bool)
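`blocks_per_ckpt` is not described here. Assuming it is the number of consecutive blocks wrapped in one activation checkpoint during training (the usual meaning of the name), the blocks would be grouped as in the hypothetical helper below. In PyTorch, each segment would then typically be run under `torch.utils.checkpoint.checkpoint`, so only segment boundaries keep activations and the rest are recomputed in the backward pass.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def checkpoint_segments(blocks: Sequence[T], blocks_per_ckpt: int) -> List[List[T]]:
    """Split blocks into consecutive segments, one activation checkpoint each.

    Hypothetical helper; assumes blocks_per_ckpt counts blocks per checkpoint.
    The last segment is shorter when len(blocks) is not a multiple.
    """
    if blocks_per_ckpt < 1:
        raise ValueError("blocks_per_ckpt must be a positive integer")
    return [
        list(blocks[i : i + blocks_per_ckpt])
        for i in range(0, len(blocks), blocks_per_ckpt)
    ]


print(checkpoint_segments(list(range(48)), 1)[:3])  # [[0], [1], [2]]
print([len(seg) for seg in checkpoint_segments(list(range(10)), 4)])  # [4, 4, 2]
```

Smaller segments save more memory but recompute more often; `blocks_per_ckpt=1` checkpoints every block.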

forward(m, z, msa_mask, pair_mask, chunk_size=None, use_deepspeed_evo_attention=False, use_cuequivariance_attention=False, use_cuequivariance_multiplicative_update=False, use_lma=False, use_flash=False, inplace_safe=False, _mask_trans=True)
Parameters:
  • m (Tensor) – [*, N_seq, N_res, C_m] MSA embedding

  • z (Tensor) – [*, N_res, N_res, C_z] pair embedding

  • msa_mask (Tensor) – [*, N_seq, N_res] MSA mask

  • pair_mask (Tensor) – [*, N_res, N_res] pair mask

  • chunk_size (int | None) – Inference-time subbatch size. Acts as a minimum if self.tune_chunk_size is True

  • use_deepspeed_evo_attention (bool) – Whether to use the DeepSpeed memory-efficient kernel. Mutually exclusive with use_lma and use_flash.

  • use_lma (bool) – Whether to use low-memory attention during inference. Mutually exclusive with use_flash and use_deepspeed_evo_attention.

  • use_flash (bool) – Whether to use FlashAttention where possible. Mutually exclusive with use_lma and use_deepspeed_evo_attention.

  • use_cuequivariance_attention (bool)

  • use_cuequivariance_multiplicative_update (bool)

  • inplace_safe (bool)

  • _mask_trans (bool)

Returns:

  • m – [*, N_seq, N_res, C_m] MSA embedding

  • z – [*, N_res, N_res, C_z] pair embedding

  • s – [*, N_res, C_s] single embedding (or None if extra MSA stack)

Return type:

Tuple[Tensor, Tensor, Tensor | None]
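The three attention-backend flags of this forward are documented as mutually exclusive. A caller-side guard such as the hypothetical `check_attention_flags` below makes a misconfiguration fail early; it is not an OpenFold function, and OpenFold's own handling of conflicting flags may differ.

```python
def check_attention_flags(
    use_deepspeed_evo_attention: bool = False,
    use_lma: bool = False,
    use_flash: bool = False,
) -> None:
    """Raise ValueError if more than one mutually exclusive backend is enabled."""
    enabled = [
        name
        for name, flag in (
            ("use_deepspeed_evo_attention", use_deepspeed_evo_attention),
            ("use_lma", use_lma),
            ("use_flash", use_flash),
        )
        if flag
    ]
    if len(enabled) > 1:
        raise ValueError(
            "Mutually exclusive options set together: " + ", ".join(enabled)
        )


check_attention_flags(use_flash=True)  # fine: only one backend selected
```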

class ExtraMSABlock(c_m, c_z, c_hidden_msa_att, c_hidden_opm, c_hidden_mul, c_hidden_pair_att, no_heads_msa, no_heads_pair, transition_n, msa_dropout, pair_dropout, opm_first, fuse_projection_weights, inf, eps, ckpt)

Bases: MSABlock

Almost identical to the standard EvoformerBlock, except that ExtraMSABlock uses GlobalAttention for MSA column attention and requires more fine-grained control over checkpointing. It is kept separate from EvoformerBlock so that the latter remains TorchScript-compatible.
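Global column attention (Algorithm 19 in the AlphaFold 2 supplement) averages the queries of each MSA column over sequences into a single query per head, and all heads share one key and one value projection; this keeps column attention cheap for the large extra MSA. A minimal NumPy sketch of that core, omitting the LayerNorm, the per-sequence sigmoid gating and the output projection. The function and weight names are hypothetical; this is not OpenFold's GlobalAttention code.

```python
import numpy as np


def global_column_attention(m, mask, w_q, w_k, w_v, no_heads, inf=1e9):
    """Single-query ("global") attention along each MSA column.

    m:    [N_seq, N_res, C_m] MSA activations (LayerNorm omitted)
    mask: [N_seq, N_res], 1 for valid entries and 0 for padding
    w_q:  [C_m, no_heads * c]; w_k, w_v: [C_m, c] (shared by all heads)
    Returns [N_res, no_heads, c]: one output per column, shared by all sequences.
    """
    n_seq, n_res, _ = m.shape
    c = w_k.shape[1]
    q = (m @ w_q).reshape(n_seq, n_res, no_heads, c)  # [S, R, H, c]
    k = m @ w_k  # [S, R, c]
    v = m @ w_v  # [S, R, c]
    # Masked mean of queries over sequences: one query per column and head.
    mk = mask[..., None, None]  # [S, R, 1, 1]
    q_avg = (q * mk).sum(0) / np.maximum(mk.sum(0), 1e-10)  # [R, H, c]
    # Logits over sequences t for each column i and head h: [R, H, S].
    logits = np.einsum("ihc,tic->iht", q_avg, k) / np.sqrt(c)
    logits = logits + (mask.T[:, None, :] - 1.0) * inf  # padded rows -> ~ -inf
    logits = logits - logits.max(-1, keepdims=True)  # numerically stable softmax
    a = np.exp(logits)
    a = a / a.sum(-1, keepdims=True)
    return np.einsum("iht,tic->ihc", a, v)  # [R, H, c]


rng = np.random.default_rng(0)
n_seq, n_res, c_m, no_heads, c = 5, 7, 16, 4, 8
o = global_column_attention(
    rng.standard_normal((n_seq, n_res, c_m)),
    np.ones((n_seq, n_res)),
    rng.standard_normal((c_m, no_heads * c)),
    rng.standard_normal((c_m, c)),
    rng.standard_normal((c_m, c)),
    no_heads,
)
print(o.shape)  # (7, 4, 8)
```

In Algorithm 19 this per-column output is then gated separately for each sequence before the final linear projection, which the sketch leaves out.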

forward(m, z, msa_mask, pair_mask, chunk_size=None, use_deepspeed_evo_attention=False, use_cuequivariance_attention=False, use_cuequivariance_multiplicative_update=False, use_lma=False, inplace_safe=False, _mask_trans=True, _attn_chunk_size=None, _offload_inference=False, _offloadable_inputs=None)
Parameters:
  • m (Tensor | None)

  • z (Tensor | None)

  • msa_mask (Tensor)

  • pair_mask (Tensor)

  • chunk_size (int | None)

  • use_deepspeed_evo_attention (bool)

  • use_cuequivariance_attention (bool)

  • use_cuequivariance_multiplicative_update (bool)

  • use_lma (bool)

  • inplace_safe (bool)

  • _mask_trans (bool)

  • _attn_chunk_size (int | None)

  • _offload_inference (bool)

  • _offloadable_inputs (Sequence[Tensor] | None)

Return type:

Tuple[Tensor, Tensor]

class ExtraMSAStack(c_m, c_z, c_hidden_msa_att, c_hidden_opm, c_hidden_mul, c_hidden_pair_att, no_heads_msa, no_heads_pair, no_blocks, transition_n, msa_dropout, pair_dropout, opm_first, fuse_projection_weights, inf, eps, ckpt, clear_cache_between_blocks=False, tune_chunk_size=False, **kwargs)

Bases: Module

Implements Algorithm 18.

Parameters:
  • c_m (int)

  • c_z (int)

  • c_hidden_msa_att (int)

  • c_hidden_opm (int)

  • c_hidden_mul (int)

  • c_hidden_pair_att (int)

  • no_heads_msa (int)

  • no_heads_pair (int)

  • no_blocks (int)

  • transition_n (int)

  • msa_dropout (float)

  • pair_dropout (float)

  • opm_first (bool)

  • fuse_projection_weights (bool)

  • inf (float)

  • eps (float)

  • ckpt (bool)

  • clear_cache_between_blocks (bool)

  • tune_chunk_size (bool)

forward(m, z, msa_mask, pair_mask, chunk_size=None, use_deepspeed_evo_attention=False, use_cuequivariance_attention=False, use_cuequivariance_multiplicative_update=False, use_lma=False, inplace_safe=False, _mask_trans=True)
Parameters:
  • m (Tensor) – [*, N_extra, N_res, C_m] extra MSA embedding

  • z (Tensor) – [*, N_res, N_res, C_z] pair embedding

  • chunk_size (int | None) – Inference-time subbatch size for Evoformer modules

  • use_deepspeed_evo_attention (bool) – Whether to use DeepSpeed memory-efficient kernel

  • use_lma (bool) – Whether to use low-memory attention during inference

  • msa_mask (Tensor | None) – Optional [*, N_extra, N_res] MSA mask

  • pair_mask (Tensor | None) – Optional [*, N_res, N_res] pair mask

  • use_cuequivariance_attention (bool)

  • use_cuequivariance_multiplicative_update (bool)

  • inplace_safe (bool)

  • _mask_trans (bool)

Returns:

[*, N_res, N_res, C_z] pair update

Return type:

Tensor
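`chunk_size` is an inference-time sub-batch size: instead of processing a whole batch-like dimension at once, a module processes slices sequentially and concatenates the results, lowering peak memory at the cost of more sequential work. A minimal NumPy sketch of the idea with a hypothetical `chunked_apply` helper; OpenFold uses its own chunking utilities, which this does not reproduce.

```python
import numpy as np


def chunked_apply(fn, x, chunk_size=None):
    """Apply fn to x in slices of at most chunk_size rows along axis 0.

    Hypothetical helper. Only valid when fn treats rows along axis 0
    independently, so slicing does not change the result.
    """
    if chunk_size is None:
        return fn(x)
    return np.concatenate(
        [fn(x[i : i + chunk_size]) for i in range(0, x.shape[0], chunk_size)],
        axis=0,
    )


rng = np.random.default_rng(0)
x = rng.standard_normal((10, 4))
w = rng.standard_normal((4, 3))


def row_fn(t):
    """Row-wise map, so chunking leaves the output unchanged."""
    return np.tanh(t @ w)


full = chunked_apply(row_fn, x)
chunked = chunked_apply(row_fn, x, chunk_size=3)  # slices of 3, 3, 3, 1 rows
print(np.allclose(full, chunked))  # True
```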

class MSABlock(c_m, c_z, c_hidden_msa_att, c_hidden_opm, c_hidden_mul, c_hidden_pair_att, no_heads_msa, no_heads_pair, transition_n, msa_dropout, pair_dropout, opm_first, fuse_projection_weights, inf, eps)

Bases: Module, ABC

Parameters:
  • c_m (int)

  • c_z (int)

  • c_hidden_msa_att (int)

  • c_hidden_opm (int)

  • c_hidden_mul (int)

  • c_hidden_pair_att (int)

  • no_heads_msa (int)

  • no_heads_pair (int)

  • transition_n (int)

  • msa_dropout (float)

  • pair_dropout (float)

  • opm_first (bool)

  • fuse_projection_weights (bool)

  • inf (float)

  • eps (float)

abstract forward(m, z, msa_mask, pair_mask, chunk_size=None, use_deepspeed_evo_attention=False, use_cuequivariance_attention=False, use_lma=False, use_flash=False, inplace_safe=False, _mask_trans=True, _attn_chunk_size=None, _offload_inference=False, _offloadable_inputs=None)
Return type:

Tuple[Tensor, Tensor]

class MSATransition(c_m, n)

Bases: Module

Feed-forward network applied to MSA activations after attention.

Implements Algorithm 9.

forward(m, mask=None, chunk_size=None)
Parameters:
  • m (Tensor) – [*, N_seq, N_res, C_m] MSA activation

  • mask (Tensor | None) – Optional [*, N_seq, N_res] MSA mask

  • chunk_size (int | None)

Returns:

[*, N_seq, N_res, C_m] MSA activation update

Return type:

Tensor
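Algorithm 9 of the AlphaFold 2 supplement defines the MSA transition as a LayerNorm followed by a two-layer MLP that widens the channel dimension by the constructor's factor `n` and projects back to `C_m`. A minimal NumPy sketch of that update, assuming the mask zeroes the update at padded positions and the caller adds the residual; the functions are hypothetical, the LayerNorm has no learnable scale or offset, and the real module is PyTorch.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    """LayerNorm over the last axis, without learnable scale/offset (simplification)."""
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def msa_transition(m, w1, b1, w2, b2, mask=None):
    """LayerNorm -> Linear(C_m -> n*C_m) -> ReLU -> Linear(n*C_m -> C_m).

    m: [*, N_seq, N_res, C_m]; mask: [*, N_seq, N_res], all ones if omitted.
    Returns the update; the residual add is left to the caller.
    """
    if mask is None:
        mask = np.ones(m.shape[:-1])
    h = np.maximum(layer_norm(m) @ w1 + b1, 0.0)  # [*, N_seq, N_res, n*C_m]
    return (h @ w2 + b2) * mask[..., None]  # zero the update at padded positions


rng = np.random.default_rng(0)
n_seq, n_res, c_m, n = 3, 5, 8, 4
m = rng.standard_normal((n_seq, n_res, c_m))
w1, b1 = rng.standard_normal((c_m, n * c_m)), np.zeros(n * c_m)
w2, b2 = rng.standard_normal((n * c_m, c_m)), np.zeros(c_m)
mask = np.ones((n_seq, n_res))
mask[:, -1] = 0  # treat the last residue as padding
update = msa_transition(m, w1, b1, w2, b2, mask)
m_new = m + update  # residual connection
print(update.shape)  # (3, 5, 8)
```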

class PairStack(c_z, c_hidden_mul, c_hidden_pair_att, no_heads_pair, transition_n, pair_dropout, fuse_projection_weights, inf, eps)

Bases: Module

forward(z, pair_mask, chunk_size=None, use_deepspeed_evo_attention=False, use_cuequivariance_attention=False, use_cuequivariance_multiplicative_update=False, use_lma=False, inplace_safe=False, _mask_trans=True, _attn_chunk_size=None)
Parameters:
  • z (Tensor)

  • pair_mask (Tensor)

  • chunk_size (int | None)

  • use_deepspeed_evo_attention (bool)

  • use_cuequivariance_attention (bool)

  • use_cuequivariance_multiplicative_update (bool)

  • use_lma (bool)

  • inplace_safe (bool)

  • _mask_trans (bool)

  • _attn_chunk_size (int | None)

Return type:

Tensor