openfold.model.msa

Classes

MSAAttention(c_in, c_hidden, no_heads[, ...])

MSAColumnAttention(c_m, c_hidden, no_heads) – Implements Algorithm 8.

MSAColumnGlobalAttention(c_in, c_hidden, ...)

MSARowAttentionWithPairBias(c_m, c_z, ...[, inf]) – Implements Algorithm 7.

class MSAAttention(c_in, c_hidden, no_heads, pair_bias=False, c_z=None, inf=1000000000.0)

Bases: Module

forward(m, z=None, mask=None, chunk_size=None, use_memory_efficient_kernel=False, use_deepspeed_evo_attention=False, use_cuequivariance_attention=False, use_lma=False, use_flash=False, inplace_safe=False, _chunk_logits=None, _checkpoint_chunks=None)
Parameters:
  • m (Tensor) – [*, N_seq, N_res, C_m] MSA embedding

  • z (Tensor | None) – [*, N_res, N_res, C_z] pair embedding. Required only if pair_bias is True

  • mask (Tensor | None) – [*, N_seq, N_res] MSA mask

  • chunk_size (int | None) – Size of chunks into which the inputs are split along their batch dimensions. A low value decreases memory overhead at the cost of slower execution. Chunking is not performed by default.

  • use_memory_efficient_kernel (bool) – Whether to use the memory-efficient attention kernel

  • use_deepspeed_evo_attention (bool) – Whether to use the DeepSpeed Evoformer attention kernel

  • use_cuequivariance_attention (bool) – Whether to use the cuEquivariance attention kernel

  • use_lma (bool) – Whether to use low-memory attention (LMA)

  • use_flash (bool) – Whether to use FlashAttention

  • inplace_safe (bool) – Whether it is safe to perform operations in place

  • _chunk_logits (int | None)

  • _checkpoint_chunks (bool | None)

Return type:

Tensor
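
A minimal usage sketch for MSAAttention without pair bias. The tensor sizes and the c_hidden/no_heads values below are illustrative assumptions, not library defaults; the final shape reflects the expectation that the module returns an updated MSA embedding with the same shape as m.

>>> import torch
>>> from openfold.model.msa import MSAAttention
>>> n_seq, n_res, c_m = 8, 16, 32          # illustrative sizes (assumptions)
>>> attn = MSAAttention(c_in=c_m, c_hidden=16, no_heads=4)
>>> m = torch.rand(n_seq, n_res, c_m)      # [*, N_seq, N_res, C_m] MSA embedding
>>> msa_mask = torch.ones(n_seq, n_res)    # [*, N_seq, N_res] MSA mask
>>> attn(m, mask=msa_mask).shape
torch.Size([8, 16, 32])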

class MSAColumnAttention(c_m, c_hidden, no_heads, inf=1000000000.0)

Bases: Module

Implements Algorithm 8 (MSA column-wise gated self-attention) from the AlphaFold 2 supplement.

By rights, this should also be a subclass of MSAAttention. Alas, most inheritance isn’t supported by TorchScript.

forward(m, mask=None, chunk_size=None, use_deepspeed_evo_attention=False, use_cuequivariance_attention=False, use_lma=False, use_flash=False)
Parameters:
  • m (Tensor) – [*, N_seq, N_res, C_m] MSA embedding

  • mask (Tensor | None) – [*, N_seq, N_res] MSA mask

  • chunk_size (int | None) – Size of chunks into which the inputs are split along their batch dimensions. A low value decreases memory overhead at the cost of slower execution. Chunking is not performed by default.

  • use_deepspeed_evo_attention (bool) – Whether to use the DeepSpeed Evoformer attention kernel

  • use_cuequivariance_attention (bool) – Whether to use the cuEquivariance attention kernel

  • use_lma (bool) – Whether to use low-memory attention (LMA)

  • use_flash (bool) – Whether to use FlashAttention

Return type:

Tensor
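
A minimal sketch of MSAColumnAttention, which attends along the sequence (column) dimension. The sizes below are illustrative assumptions, and the output is expected to match the shape of m.

>>> import torch
>>> from openfold.model.msa import MSAColumnAttention
>>> n_seq, n_res, c_m = 8, 16, 32          # illustrative sizes (assumptions)
>>> col_attn = MSAColumnAttention(c_m=c_m, c_hidden=16, no_heads=4)
>>> m = torch.rand(n_seq, n_res, c_m)      # [*, N_seq, N_res, C_m] MSA embedding
>>> msa_mask = torch.ones(n_seq, n_res)    # [*, N_seq, N_res] MSA mask
>>> col_attn(m, mask=msa_mask).shape
torch.Size([8, 16, 32])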

class MSAColumnGlobalAttention(c_in, c_hidden, no_heads, inf=1000000000.0, eps=1e-10)

Bases: Module

forward(m, mask=None, chunk_size=None, use_lma=False)
Parameters:
  • m (Tensor) – [*, N_seq, N_res, C_m] MSA embedding

  • mask (Tensor | None) – [*, N_seq, N_res] MSA mask

  • chunk_size (int | None) – Size of chunks into which the inputs are split along their batch dimensions. A low value decreases memory overhead at the cost of slower execution. Chunking is not performed by default.

  • use_lma (bool) – Whether to use low-memory attention (LMA)

Return type:

Tensor
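
A minimal sketch of MSAColumnGlobalAttention with chunking enabled. The chunk size and tensor sizes are illustrative assumptions, and the output is expected to match the shape of m.

>>> import torch
>>> from openfold.model.msa import MSAColumnGlobalAttention
>>> n_seq, n_res, c_m = 8, 16, 32          # illustrative sizes (assumptions)
>>> global_attn = MSAColumnGlobalAttention(c_in=c_m, c_hidden=16, no_heads=4)
>>> m = torch.rand(n_seq, n_res, c_m)      # [*, N_seq, N_res, C_m] MSA embedding
>>> msa_mask = torch.ones(n_seq, n_res)    # [*, N_seq, N_res] MSA mask
>>> global_attn(m, mask=msa_mask, chunk_size=4).shape
torch.Size([8, 16, 32])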

class MSARowAttentionWithPairBias(c_m, c_z, c_hidden, no_heads, inf=1000000000.0)

Bases: MSAAttention

Implements Algorithm 7 (MSA row-wise gated self-attention with pair bias) from the AlphaFold 2 supplement.
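
A minimal sketch of MSARowAttentionWithPairBias, which reuses MSAAttention.forward and adds an attention bias derived from the pair embedding z. The sizes below are illustrative assumptions, and the output is expected to match the shape of m.

>>> import torch
>>> from openfold.model.msa import MSARowAttentionWithPairBias
>>> n_seq, n_res, c_m, c_z = 8, 16, 32, 24           # illustrative sizes (assumptions)
>>> row_attn = MSARowAttentionWithPairBias(c_m=c_m, c_z=c_z, c_hidden=16, no_heads=4)
>>> m = torch.rand(n_seq, n_res, c_m)                # [*, N_seq, N_res, C_m] MSA embedding
>>> z = torch.rand(n_res, n_res, c_z)                # [*, N_res, N_res, C_z] pair embedding
>>> msa_mask = torch.ones(n_seq, n_res)              # [*, N_seq, N_res] MSA mask
>>> row_attn(m, z=z, mask=msa_mask).shape
torch.Size([8, 16, 32])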