openfold.model.template¶
Classes
| TemplatePairStack | Implements Algorithm 16. |
| TemplatePairStackBlock | A single block of the template pair stack. |
| TemplatePointwiseAttention | Implements Algorithm 17. |
Functions
| embed_templates_average | Memory-efficient approximation of the AlphaFold “embed_templates” method that averages embeddings over groups of templates. |
| embed_templates_offload | Version of the AlphaFold “embed_templates” method that offloads the template pair tensor to CPU. |
- class TemplatePairStack(c_t, c_hidden_tri_att, c_hidden_tri_mul, no_blocks, no_heads, pair_transition_n, dropout_rate, tri_mul_first, fuse_projection_weights, blocks_per_ckpt, tune_chunk_size=False, inf=1000000000.0, **kwargs)¶
Bases: Module
Implements Algorithm 16.
- Parameters:
tune_chunk_size (bool)
- forward(t, mask, chunk_size, use_deepspeed_evo_attention=False, use_cuequivariance_attention=False, use_cuequivariance_multiplicative_update=False, use_lma=False, inplace_safe=False, _mask_trans=True)¶
- Parameters:
t (tensor) – [*, N_templ, N_res, N_res, C_t] template embedding
mask (tensor) – [*, N_templ, N_res, N_res] mask
chunk_size (int)
use_deepspeed_evo_attention (bool)
use_cuequivariance_attention (bool)
use_cuequivariance_multiplicative_update (bool)
use_lma (bool)
inplace_safe (bool)
_mask_trans (bool)
- Returns:
[*, N_templ, N_res, N_res, C_t] template embedding update
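A minimal usage sketch, assuming a working OpenFold installation; the tensor sizes and hyperparameter values below are illustrative and are not taken from any OpenFold config.

```python
import torch

from openfold.model.template import TemplatePairStack

n_templ, n_res, c_t = 4, 32, 64

# Illustrative hyperparameters only.
stack = TemplatePairStack(
    c_t=c_t,
    c_hidden_tri_att=16,
    c_hidden_tri_mul=64,
    no_blocks=2,
    no_heads=4,
    pair_transition_n=2,
    dropout_rate=0.25,
    tri_mul_first=False,
    fuse_projection_weights=False,
    blocks_per_ckpt=None,  # None disables activation checkpointing
)

t = torch.rand(1, n_templ, n_res, n_res, c_t)  # [*, N_templ, N_res, N_res, C_t]
mask = torch.ones(1, n_templ, n_res, n_res)    # [*, N_templ, N_res, N_res]

with torch.no_grad():
    t = stack(t, mask, chunk_size=32)  # [*, N_templ, N_res, N_res, C_t] update
```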
- class TemplatePairStackBlock(c_t, c_hidden_tri_att, c_hidden_tri_mul, no_heads, pair_transition_n, dropout_rate, tri_mul_first, fuse_projection_weights, inf, **kwargs)¶
Bases: Module
- forward(z, mask, chunk_size=None, use_deepspeed_evo_attention=False, use_cuequivariance_attention=False, use_cuequivariance_multiplicative_update=False, use_lma=False, inplace_safe=False, _mask_trans=True, _attn_chunk_size=None)¶
- tri_att_start_end(single, _attn_chunk_size, single_mask, use_deepspeed_evo_attention, use_cuequivariance_attention, use_lma, inplace_safe)¶
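For completeness, a single block can also be applied on its own, with the same illustrative shapes and hyperparameters as above; in normal use, blocks are instantiated internally by TemplatePairStack.

```python
import torch

from openfold.model.template import TemplatePairStackBlock

block = TemplatePairStackBlock(
    c_t=64,
    c_hidden_tri_att=16,
    c_hidden_tri_mul=64,
    no_heads=4,
    pair_transition_n=2,
    dropout_rate=0.25,
    tri_mul_first=False,
    fuse_projection_weights=False,
    inf=1e9,
)

z = torch.rand(1, 4, 32, 32, 64)  # [*, N_templ, N_res, N_res, C_t]
mask = torch.ones(1, 4, 32, 32)   # [*, N_templ, N_res, N_res]

with torch.no_grad():
    z = block(z, mask, chunk_size=None)
```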
- class TemplatePointwiseAttention(c_t, c_z, c_hidden, no_heads, inf, **kwargs)¶
Bases: Module
Implements Algorithm 17.
- forward(t, z, template_mask=None, chunk_size=256, use_lma=False)¶
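A minimal usage sketch, assuming the usual AlphaFold shapes for the inputs not documented here: z as a [*, N_res, N_res, C_z] pair embedding and template_mask as a [*, N_templ] per-template mask. Channel sizes are illustrative.

```python
import torch

from openfold.model.template import TemplatePointwiseAttention

tpa = TemplatePointwiseAttention(c_t=64, c_z=128, c_hidden=16, no_heads=4, inf=1e9)

n_templ, n_res = 4, 32
t = torch.rand(1, n_templ, n_res, n_res, 64)  # [*, N_templ, N_res, N_res, C_t] template embedding
z = torch.rand(1, n_res, n_res, 128)          # [*, N_res, N_res, C_z] pair embedding (assumed shape)
template_mask = torch.ones(1, n_templ)        # [*, N_templ] per-template mask (assumed shape)

with torch.no_grad():
    z_update = tpa(t, z, template_mask=template_mask)  # [*, N_res, N_res, C_z]
```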
- embed_templates_average(model, batch, z, pair_mask, templ_dim, templ_group_size=2, inplace_safe=False)¶
A memory-efficient approximation of the “embed_templates” method of the AlphaFold class. Instead of running pointwise attention over the pair embeddings of all templates at once, it splits the templates into groups of size templ_group_size, computes embeddings for each group normally, and then averages the group embeddings. In our experiments, this approximation has a minimal effect on the quality of the resulting embedding, while its low memory footprint allows the number of templates to scale almost indefinitely.
- Parameters:
model – An AlphaFold model object
batch – An AlphaFold input batch. See the documentation of AlphaFold.
z – A [*, N, N, C_z] pair embedding
pair_mask – A [*, N, N] pair mask
templ_dim – The template dimension of the template tensors in batch
templ_group_size – Granularity of the approximation. Larger values use more memory but track the exact embedding more closely.
- Returns:
A dictionary of template pair and angle embeddings.
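A sketch of how this function might be called in place of AlphaFold's own embed_templates during inference. The variables model, batch, z, and pair_mask are assumed to come from an in-progress AlphaFold forward pass and are not constructed here; the templ_dim value assumes a single leading batch dimension in the template feature tensors.

```python
from openfold.model.template import embed_templates_average

# `model`, `batch`, `z`, and `pair_mask` are assumed to already exist as part
# of an AlphaFold forward pass; they are not constructed in this sketch.
template_embeds = embed_templates_average(
    model,
    batch,
    z,                    # [*, N, N, C_z] pair embedding
    pair_mask,            # [*, N, N] pair mask
    templ_dim=1,          # assumes one leading batch dimension in the template tensors
    templ_group_size=2,   # larger groups: more memory, closer to the exact embedding
    inplace_safe=False,
)
# `template_embeds` is a dictionary of template pair and angle embeddings,
# as described above.
```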
- embed_templates_offload(model, batch, z, pair_mask, templ_dim, template_chunk_size=256, inplace_safe=False)¶
A version of the “embed_templates” method of the AlphaFold class that offloads the large template pair tensor to CPU. Slower, but more frugal with GPU memory than the original. Useful for long-sequence inference.
- Parameters:
model – An AlphaFold model object
batch – An AlphaFold input batch. See the documentation of AlphaFold.
z – A [*, N, N, C_z] pair embedding
pair_mask – A [*, N, N] pair mask
templ_dim – The template dimension of the template tensors in batch
template_chunk_size – Integer controlling how much of the offloaded pair embedding tensor is brought back into GPU memory at a time. In dire straits, can be lowered to reduce this function's memory consumption even further.
- Returns:
A dictionary of template pair and angle embeddings.
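A corresponding sketch for the offloading variant, under the same assumptions as above: model, batch, z, and pair_mask come from an in-progress AlphaFold forward pass.

```python
from openfold.model.template import embed_templates_offload

# As above, the surrounding AlphaFold forward-pass variables are assumed, not built here.
template_embeds = embed_templates_offload(
    model,
    batch,
    z,
    pair_mask,
    templ_dim=1,              # assumes one leading batch dimension in the template tensors
    template_chunk_size=128,  # lowered from the default 256 to save additional GPU memory
    inplace_safe=False,
)
```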