openfold.model.primitives¶
Classes

- Attention – Standard multi-head attention using AlphaFold's default layer initialization.
- Linear – A Linear layer with built-in nonstandard initializations.

Functions

- final_init_
- gating_init_
- glorot_uniform_init_
- he_normal_init_
- ipa_point_weights_init_
- lecun_normal_init_
- normal_init_
- softmax_no_cast – Softmax, but without automatic casting to fp32 when the input is of type bfloat16
- trunc_normal_init_
- class Attention(c_q, c_k, c_v, c_hidden, no_heads, gating=True, inf=1000000000.0)¶
Bases: Module
Standard multi-head attention using AlphaFold’s default layer initialization. Allows multiple bias vectors.
- forward(q_x, kv_x, biases=None, use_memory_efficient_kernel=False, use_deepspeed_evo_attention=False, use_cuequivariance_attention=False, use_lma=False, lma_q_chunk_size=1024, lma_kv_chunk_size=4096, use_flash=False, flash_mask=None)¶
- Parameters:
q_x (Tensor) – [*, Q, C_q] query data
kv_x (Tensor) – [*, K, C_k] key data
biases (List[Tensor] | None) – List of biases that broadcast to [*, H, Q, K]
use_memory_efficient_kernel (bool) – Whether to use a custom memory-efficient attention kernel. This should be the default choice for most users. If none of the “use_<…>” flags are True, a stock PyTorch implementation is used instead
use_deepspeed_evo_attention (bool) – Whether to use the DeepSpeed memory-efficient attention kernel. If none of the “use_<…>” flags are True, a stock PyTorch implementation is used instead
use_lma (bool) – Whether to use low-memory attention (Rabe & Staats, 2021). If none of the “use_<…>” flags are True, a stock PyTorch implementation is used instead
lma_q_chunk_size (int) – Query chunk size (for LMA)
lma_kv_chunk_size (int) – Key/Value chunk size (for LMA)
use_cuequivariance_attention (bool) – Whether to use the cuEquivariance attention kernel. When enabled, biases[0] must contain a 0/1 mask tensor (0 marks invalid positions)
use_flash (bool) – Whether to use the FlashAttention kernel
flash_mask (Tensor | None) – Attention mask for the FlashAttention kernel
- Return type:
Tensor
- Returns:
[*, Q, C_q] attention update
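The snippet below is a minimal usage sketch, not taken from the OpenFold documentation: the channel sizes, sequence lengths, and the conversion of a 0/1 key mask into an additive bias are illustrative assumptions. With no “use_<…>” flag set, the stock PyTorch implementation is used.

# Illustrative sketch only: dimensions and the mask-bias construction are
# assumptions, not values prescribed by the API.
import torch

from openfold.model.primitives import Attention

c_in, c_hidden, no_heads = 32, 16, 4
attn = Attention(c_q=c_in, c_k=c_in, c_v=c_in, c_hidden=c_hidden, no_heads=no_heads)

batch, n_q, n_k = 2, 12, 12
q_x = torch.randn(batch, n_q, c_in)   # [*, Q, C_q] query data
kv_x = torch.randn(batch, n_k, c_in)  # [*, K, C_k] key/value data

# Convert a 0/1 key mask into an additive bias that broadcasts to [*, H, Q, K]:
# masked positions receive a large negative value before the softmax.
key_mask = torch.ones(batch, n_k)
mask_bias = (1e9 * (key_mask - 1))[:, None, None, :]  # [*, 1, 1, K]

# No "use_<...>" flag is set, so the stock PyTorch path runs.
out = attn(q_x, kv_x, biases=[mask_bias])
print(out.shape)  # torch.Size([2, 12, 32]), i.e. [*, Q, C_q]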
- class Linear(in_dim, out_dim, bias=True, init='default', init_fn=None, precision=None)¶
Bases: Linear
A Linear layer with built-in nonstandard initializations. Called just like torch.nn.Linear.
Implements the initializers described in Section 1.11.4 of the AlphaFold 2 supplement, plus some additional ones found in the code.
- Parameters:
in_dim (int) – The final dimension of inputs to the layer
out_dim (int) – The final dimension of layer outputs
bias (bool) – Whether to learn an additive bias. True by default
init (str) – The initializer to use. Choose from: "default" (LeCun fan-in truncated normal), "relu" (He initialization with a truncated normal distribution), "glorot" (fan-average Glorot uniform), "gating" (weights = 0, bias = 1), "normal" (normal distribution with std = 1/sqrt(fan_in)), "final" (weights = 0, bias = 0). Overridden by init_fn if the latter is not None
init_fn (Callable | None) – A custom initializer taking weight and bias as inputs. Overrides init if not None
precision – Optional precision (dtype) override for the layer's computation
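A brief usage sketch of the layer (the sizes are arbitrary; the comments assume the init semantics listed above):

# Sketch of the Linear layer's nonstandard initializations; sizes are arbitrary.
import torch

from openfold.model.primitives import Linear

proj = Linear(64, 128, init="relu")      # He-style init, for layers followed by a ReLU
gate = Linear(64, 128, init="gating")    # weights start at 0, bias at 1
final = Linear(64, 3, init="final")      # weights and bias start at 0

x = torch.randn(2, 10, 64)
print(proj(x).shape)                     # torch.Size([2, 10, 128])
print(torch.all(final(x) == 0).item())   # True: a "final" layer initially outputs zeros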
- final_init_(weights)¶
- gating_init_(weights)¶
- glorot_uniform_init_(weights)¶
- he_normal_init_(weights)¶
- ipa_point_weights_init_(weights)¶
- lecun_normal_init_(weights)¶
- normal_init_(weights)¶
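These initializer helpers follow the PyTorch convention of modifying the given weight tensor in place (trailing underscore). A minimal sketch, using an arbitrary torch.nn.Linear weight:

import torch

from openfold.model.primitives import final_init_, glorot_uniform_init_

# The weight shape below is arbitrary; only the in-place call pattern matters.
layer = torch.nn.Linear(64, 128)
glorot_uniform_init_(layer.weight)           # fan-average Glorot uniform, in place
final_init_(layer.weight)                    # zero-initialize, in place
print(torch.all(layer.weight == 0).item())   # True after final_init_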
- softmax_no_cast(t, dim=-1)¶
Softmax, but without automatic casting to fp32 when the input is of type bfloat16
- trunc_normal_init_(weights, scale=1.0, fan='fan_in')¶
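A short sketch of these two helpers; the shapes are arbitrary, and the dtype check simply illustrates the docstring above.

import torch

from openfold.model.primitives import softmax_no_cast, trunc_normal_init_

# softmax_no_cast keeps bfloat16 inputs in bfloat16 rather than upcasting to fp32.
logits = torch.randn(4, 8, dtype=torch.bfloat16)
probs = softmax_no_cast(logits, dim=-1)
print(probs.dtype)  # torch.bfloat16

# trunc_normal_init_ fills a weight tensor in place with a fan-scaled truncated normal.
w = torch.empty(128, 64)
trunc_normal_init_(w, scale=1.0, fan="fan_in")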