openfold.model.evoformer¶
Classes

| Class | Description |
| --- | --- |
| EvoformerBlock | |
| EvoformerStack | Main Evoformer trunk. |
| ExtraMSABlock | Almost identical to the standard EvoformerBlock, except that the ExtraMSABlock uses GlobalAttention for MSA column attention and requires more fine-grained control over checkpointing. |
| ExtraMSAStack | Implements Algorithm 18. |
| MSABlock | |
| MSATransition | Feed-forward network applied to MSA activations after attention. |
| PairStack | |
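All of the classes documented below are importable from the module named in the page title; a quick import sketch:

```python
from openfold.model.evoformer import (
    EvoformerBlock,
    EvoformerStack,
    ExtraMSABlock,
    ExtraMSAStack,
    MSABlock,
    MSATransition,
    PairStack,
)
```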
- class EvoformerBlock(c_m, c_z, c_hidden_msa_att, c_hidden_opm, c_hidden_mul, c_hidden_pair_att, no_heads_msa, no_heads_pair, transition_n, msa_dropout, pair_dropout, no_column_attention, opm_first, fuse_projection_weights, inf, eps)¶
Bases: MSABlock
- Parameters:
c_m (int)
c_z (int)
c_hidden_msa_att (int)
c_hidden_opm (int)
c_hidden_mul (int)
c_hidden_pair_att (int)
no_heads_msa (int)
no_heads_pair (int)
transition_n (int)
msa_dropout (float)
pair_dropout (float)
no_column_attention (bool)
opm_first (bool)
fuse_projection_weights (bool)
inf (float)
eps (float)
- forward(m, z, msa_mask, pair_mask, chunk_size=None, use_deepspeed_evo_attention=False, use_cuequivariance_attention=False, use_cuequivariance_multiplicative_update=False, use_lma=False, use_flash=False, inplace_safe=False, _mask_trans=True, _attn_chunk_size=None, _offload_inference=False, _offloadable_inputs=None)¶
- Parameters:
m (Tensor | None)
z (Tensor | None)
msa_mask (Tensor)
pair_mask (Tensor)
chunk_size (int | None)
use_deepspeed_evo_attention (bool)
use_cuequivariance_attention (bool)
use_cuequivariance_multiplicative_update (bool)
use_lma (bool)
use_flash (bool)
inplace_safe (bool)
_mask_trans (bool)
_attn_chunk_size (int | None)
_offload_inference (bool)
- Return type:
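A minimal usage sketch for a single block follows. The channel widths, head counts, and dropout rates are illustrative placeholders rather than a published configuration, and the unpacking of the output into updated (m, z) tensors is an assumption, since the return type is not documented above.

```python
import torch
from openfold.model.evoformer import EvoformerBlock

# Illustrative (non-production) hyperparameters; see the constructor
# parameters documented above for their meanings.
block = EvoformerBlock(
    c_m=32, c_z=16,
    c_hidden_msa_att=8, c_hidden_opm=8,
    c_hidden_mul=16, c_hidden_pair_att=8,
    no_heads_msa=4, no_heads_pair=2,
    transition_n=2,
    msa_dropout=0.15, pair_dropout=0.25,
    no_column_attention=False, opm_first=False,
    fuse_projection_weights=False,
    inf=1e9, eps=1e-8,
)

n_seq, n_res = 8, 10
m = torch.rand(1, n_seq, n_res, 32)    # [*, N_seq, N_res, C_m] MSA embedding
z = torch.rand(1, n_res, n_res, 16)    # [*, N_res, N_res, C_z] pair embedding
msa_mask = torch.ones(1, n_seq, n_res)
pair_mask = torch.ones(1, n_res, n_res)

# Assumed to return the updated MSA and pair embeddings.
m, z = block(m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=None)
```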
- class EvoformerStack(c_m, c_z, c_hidden_msa_att, c_hidden_opm, c_hidden_mul, c_hidden_pair_att, c_s, no_heads_msa, no_heads_pair, no_blocks, transition_n, msa_dropout, pair_dropout, no_column_attention, opm_first, fuse_projection_weights, blocks_per_ckpt, inf, eps, clear_cache_between_blocks=False, tune_chunk_size=False, **kwargs)¶
Bases: Module
Main Evoformer trunk.
Implements Algorithm 6.
- Parameters:
c_m (int)
c_z (int)
c_hidden_msa_att (int)
c_hidden_opm (int)
c_hidden_mul (int)
c_hidden_pair_att (int)
c_s (int)
no_heads_msa (int)
no_heads_pair (int)
no_blocks (int)
transition_n (int)
msa_dropout (float)
pair_dropout (float)
no_column_attention (bool)
opm_first (bool)
fuse_projection_weights (bool)
blocks_per_ckpt (int)
inf (float)
eps (float)
clear_cache_between_blocks (bool)
tune_chunk_size (bool)
- forward(m, z, msa_mask, pair_mask, chunk_size=None, use_deepspeed_evo_attention=False, use_cuequivariance_attention=False, use_cuequivariance_multiplicative_update=False, use_lma=False, use_flash=False, inplace_safe=False, _mask_trans=True)¶
- Parameters:
chunk_size (int | None) – Inference-time subbatch size. Acts as a minimum if self.tune_chunk_size is True
use_deepspeed_evo_attention (bool) – Whether to use DeepSpeed memory efficient kernel. Mutually exclusive with use_lma and use_flash.
use_lma (bool) – Whether to use low-memory attention during inference. Mutually exclusive with use_flash and use_deepspeed_evo_attention.
use_flash (bool) – Whether to use FlashAttention where possible. Mutually exclusive with use_lma and use_deepspeed_evo_attention.
use_cuequivariance_attention (bool)
use_cuequivariance_multiplicative_update (bool)
inplace_safe (bool)
_mask_trans (bool)
- Returns:
- m: [*, N_seq, N_res, C_m] MSA embedding
- z: [*, N_res, N_res, C_z] pair embedding
- s: [*, N_res, C_s] single embedding (or None if extra MSA stack)
- Return type:
Tuple[Tensor, Tensor, Tensor | None]
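A hedged end-to-end sketch of the trunk follows; all sizes are illustrative rather than the published AlphaFold configuration, and the three-way unpacking follows the Returns entry above.

```python
import torch
from openfold.model.evoformer import EvoformerStack

# Illustrative (non-production) sizes.
stack = EvoformerStack(
    c_m=32, c_z=16,
    c_hidden_msa_att=8, c_hidden_opm=8,
    c_hidden_mul=16, c_hidden_pair_att=8,
    c_s=64,
    no_heads_msa=4, no_heads_pair=2,
    no_blocks=2, transition_n=2,
    msa_dropout=0.15, pair_dropout=0.25,
    no_column_attention=False, opm_first=False,
    fuse_projection_weights=False,
    blocks_per_ckpt=1,
    inf=1e9, eps=1e-8,
)

n_seq, n_res = 8, 10
m = torch.rand(1, n_seq, n_res, 32)    # [*, N_seq, N_res, C_m]
z = torch.rand(1, n_res, n_res, 16)    # [*, N_res, N_res, C_z]
msa_mask = torch.ones(1, n_seq, n_res)
pair_mask = torch.ones(1, n_res, n_res)

# Inference-style call. At most one of the attention kernels
# (DeepSpeed / LMA / FlashAttention) may be enabled, per the
# mutually exclusive flags documented above; all are left off here.
with torch.no_grad():
    m, z, s = stack(m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=4)

print(s.shape)  # [1, n_res, 64], i.e. [*, N_res, C_s] single embedding
```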
- class ExtraMSABlock(c_m, c_z, c_hidden_msa_att, c_hidden_opm, c_hidden_mul, c_hidden_pair_att, no_heads_msa, no_heads_pair, transition_n, msa_dropout, pair_dropout, opm_first, fuse_projection_weights, inf, eps, ckpt)¶
Bases: MSABlock
Almost identical to the standard EvoformerBlock, except that the ExtraMSABlock uses GlobalAttention for MSA column attention and requires more fine-grained control over checkpointing. Separated from its twin to preserve the TorchScript-ability of the latter.
- Parameters:
c_m (int)
c_z (int)
c_hidden_msa_att (int)
c_hidden_opm (int)
c_hidden_mul (int)
c_hidden_pair_att (int)
no_heads_msa (int)
no_heads_pair (int)
transition_n (int)
msa_dropout (float)
pair_dropout (float)
opm_first (bool)
fuse_projection_weights (bool)
inf (float)
eps (float)
ckpt (bool)
- forward(m, z, msa_mask, pair_mask, chunk_size=None, use_deepspeed_evo_attention=False, use_cuequivariance_attention=False, use_cuequivariance_multiplicative_update=False, use_lma=False, inplace_safe=False, _mask_trans=True, _attn_chunk_size=None, _offload_inference=False, _offloadable_inputs=None)¶
- Parameters:
m (Tensor | None)
z (Tensor | None)
msa_mask (Tensor)
pair_mask (Tensor)
chunk_size (int | None)
use_deepspeed_evo_attention (bool)
use_cuequivariance_attention (bool)
use_cuequivariance_multiplicative_update (bool)
use_lma (bool)
inplace_safe (bool)
_mask_trans (bool)
_attn_chunk_size (int | None)
_offload_inference (bool)
- Return type:
- class ExtraMSAStack(c_m, c_z, c_hidden_msa_att, c_hidden_opm, c_hidden_mul, c_hidden_pair_att, no_heads_msa, no_heads_pair, no_blocks, transition_n, msa_dropout, pair_dropout, opm_first, fuse_projection_weights, inf, eps, ckpt, clear_cache_between_blocks=False, tune_chunk_size=False, **kwargs)¶
Bases: Module
Implements Algorithm 18.
- Parameters:
c_m (int)
c_z (int)
c_hidden_msa_att (int)
c_hidden_opm (int)
c_hidden_mul (int)
c_hidden_pair_att (int)
no_heads_msa (int)
no_heads_pair (int)
no_blocks (int)
transition_n (int)
msa_dropout (float)
pair_dropout (float)
opm_first (bool)
fuse_projection_weights (bool)
inf (float)
eps (float)
ckpt (bool)
clear_cache_between_blocks (bool)
tune_chunk_size (bool)
- forward(m, z, msa_mask, pair_mask, chunk_size=None, use_deepspeed_evo_attention=False, use_cuequivariance_attention=False, use_cuequivariance_multiplicative_update=False, use_lma=False, inplace_safe=False, _mask_trans=True)¶
- Parameters:
chunk_size (int | None) – Inference-time subbatch size for Evoformer modules
use_deepspeed_evo_attention (bool) – Whether to use DeepSpeed memory-efficient kernel
use_lma (bool) – Whether to use low-memory attention during inference
msa_mask (Tensor | None) – Optional [*, N_extra, N_res] MSA mask
pair_mask (Tensor | None) – Optional [*, N_res, N_res] pair mask
use_cuequivariance_attention (bool)
use_cuequivariance_multiplicative_update (bool)
inplace_safe (bool)
_mask_trans (bool)
- Returns:
[*, N_res, N_res, C_z] pair update
- Return type:
Tensor
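A short usage sketch; sizes are illustrative, and note that, per the Returns entry above, only the updated pair representation is returned.

```python
import torch
from openfold.model.evoformer import ExtraMSAStack

# Illustrative (non-production) sizes.
extra_stack = ExtraMSAStack(
    c_m=16, c_z=16,
    c_hidden_msa_att=8, c_hidden_opm=8,
    c_hidden_mul=16, c_hidden_pair_att=8,
    no_heads_msa=4, no_heads_pair=2,
    no_blocks=2, transition_n=2,
    msa_dropout=0.15, pair_dropout=0.25,
    opm_first=False, fuse_projection_weights=False,
    inf=1e9, eps=1e-8,
    ckpt=False,
)

n_extra, n_res = 16, 10
a = torch.rand(1, n_extra, n_res, 16)   # [*, N_extra, N_res, C_m] extra-MSA embedding
z = torch.rand(1, n_res, n_res, 16)     # [*, N_res, N_res, C_z] pair embedding
msa_mask = torch.ones(1, n_extra, n_res)
pair_mask = torch.ones(1, n_res, n_res)

# Only the pair update is returned; the extra-MSA activations are discarded.
with torch.no_grad():
    z = extra_stack(a, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=4)
```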
- class MSABlock(c_m, c_z, c_hidden_msa_att, c_hidden_opm, c_hidden_mul, c_hidden_pair_att, no_heads_msa, no_heads_pair, transition_n, msa_dropout, pair_dropout, opm_first, fuse_projection_weights, inf, eps)¶
Bases: Module, ABC
- Parameters:
c_m (int)
c_z (int)
c_hidden_msa_att (int)
c_hidden_opm (int)
c_hidden_mul (int)
c_hidden_pair_att (int)
no_heads_msa (int)
no_heads_pair (int)
transition_n (int)
msa_dropout (float)
pair_dropout (float)
opm_first (bool)
fuse_projection_weights (bool)
inf (float)
eps (float)
- abstract forward(m, z, msa_mask, pair_mask, chunk_size=None, use_deepspeed_evo_attention=False, use_cuequivariance_attention=False, use_lma=False, use_flash=False, inplace_safe=False, _mask_trans=True, _attn_chunk_size=None, _offload_inference=False, _offloadable_inputs=None)¶
- Parameters:
- Return type:
- class MSATransition(c_m, n)¶
Bases: Module
Feed-forward network applied to MSA activations after attention.
Implements Algorithm 9.
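A minimal sketch of the transition layer on its own. c_m and n are illustrative values (n is assumed to be the hidden-layer expansion factor of Algorithm 9), and the forward call below assumes the layer accepts the MSA activations followed by an optional mask, since its signature is not reproduced on this page.

```python
import torch
from openfold.model.evoformer import MSATransition

transition = MSATransition(c_m=32, n=4)   # illustrative channel width and expansion factor

m = torch.rand(1, 8, 10, 32)              # [*, N_seq, N_res, C_m] MSA activations
mask = torch.ones(1, 8, 10)               # per-position MSA mask (assumed optional)

m = transition(m, mask)                   # shape is preserved: [1, 8, 10, 32]
```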
- class PairStack(c_z, c_hidden_mul, c_hidden_pair_att, no_heads_pair, transition_n, pair_dropout, fuse_projection_weights, inf, eps)¶
Bases: Module
- Parameters:
c_z (int)
c_hidden_mul (int)
c_hidden_pair_att (int)
no_heads_pair (int)
transition_n (int)
pair_dropout (float)
fuse_projection_weights (bool)
inf (float)
eps (float)
- forward(z, pair_mask, chunk_size=None, use_deepspeed_evo_attention=False, use_cuequivariance_attention=False, use_cuequivariance_multiplicative_update=False, use_lma=False, inplace_safe=False, _mask_trans=True, _attn_chunk_size=None)¶
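A short sketch of the pair-only stack; sizes are illustrative, and the return of an updated pair embedding is an assumption, since no return type is documented above.

```python
import torch
from openfold.model.evoformer import PairStack

# Illustrative (non-production) sizes.
pair_stack = PairStack(
    c_z=16, c_hidden_mul=16, c_hidden_pair_att=8,
    no_heads_pair=2, transition_n=2,
    pair_dropout=0.25, fuse_projection_weights=False,
    inf=1e9, eps=1e-8,
)

n_res = 10
z = torch.rand(1, n_res, n_res, 16)     # [*, N_res, N_res, C_z] pair embedding
pair_mask = torch.ones(1, n_res, n_res)

# Assumed to return the updated pair embedding.
with torch.no_grad():
    z = pair_stack(z, pair_mask, chunk_size=None)
```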