openfold.model.template

Classes

TemplatePairStack(c_t, c_hidden_tri_att, ...)

Implements Algorithm 16.

TemplatePairStackBlock(c_t, ...)

TemplatePointwiseAttention(c_t, c_z, ...)

Implements Algorithm 17.

Functions

embed_templates_average(model, batch, z, ...)

A memory-efficient approximation of the “embed_templates” method of the AlphaFold class.

embed_templates_offload(model, batch, z, ...)

A version of the “embed_templates” method of the AlphaFold class that offloads the large template pair tensor to CPU.

class TemplatePairStack(c_t, c_hidden_tri_att, c_hidden_tri_mul, no_blocks, no_heads, pair_transition_n, dropout_rate, tri_mul_first, fuse_projection_weights, blocks_per_ckpt, tune_chunk_size=False, inf=1000000000.0, **kwargs)

Bases: Module

Implements Algorithm 16.

Parameters:

tune_chunk_size (bool)

forward(t, mask, chunk_size, use_deepspeed_evo_attention=False, use_cuequivariance_attention=False, use_cuequivariance_multiplicative_update=False, use_lma=False, inplace_safe=False, _mask_trans=True)
Parameters:
  • t (Tensor) – [*, N_templ, N_res, N_res, C_t] template embedding

  • mask (Tensor) – [*, N_templ, N_res, N_res] mask

  • chunk_size (int)

  • use_deepspeed_evo_attention (bool)

  • use_cuequivariance_attention (bool)

  • use_cuequivariance_multiplicative_update (bool)

  • use_lma (bool)

  • inplace_safe (bool)

  • _mask_trans (bool)

Returns:

[*, N_templ, N_res, N_res, C_t] template embedding update
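A minimal usage sketch under assumed hyperparameters (the channel and block counts below are illustrative, not OpenFold defaults):

```python
import torch
from openfold.model.template import TemplatePairStack

# Illustrative hyperparameters; real values come from the model config.
stack = TemplatePairStack(
    c_t=64,                      # template embedding channels (C_t)
    c_hidden_tri_att=16,         # hidden channels for triangular attention
    c_hidden_tri_mul=64,         # hidden channels for triangular multiplicative updates
    no_blocks=2,                 # number of TemplatePairStackBlock layers
    no_heads=4,
    pair_transition_n=2,
    dropout_rate=0.25,
    tri_mul_first=False,
    fuse_projection_weights=False,
    blocks_per_ckpt=1,           # activation-checkpointing granularity during training
)

n_templ, n_res = 4, 32
t = torch.randn(1, n_templ, n_res, n_res, 64)  # [*, N_templ, N_res, N_res, C_t]
mask = torch.ones(1, n_templ, n_res, n_res)    # [*, N_templ, N_res, N_res]

# Output has the same shape as t: an updated template embedding.
t = stack(t, mask, chunk_size=None)  # an int chunk_size bounds memory per attention chunk
```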

class TemplatePairStackBlock(c_t, c_hidden_tri_att, c_hidden_tri_mul, no_heads, pair_transition_n, dropout_rate, tri_mul_first, fuse_projection_weights, inf, **kwargs)

Bases: Module

Parameters:
  • c_t (int)

  • c_hidden_tri_att (int)

  • c_hidden_tri_mul (int)

  • no_heads (int)

  • pair_transition_n (int)

  • dropout_rate (float)

  • tri_mul_first (bool)

  • fuse_projection_weights (bool)

  • inf (float)

forward(z, mask, chunk_size=None, use_deepspeed_evo_attention=False, use_cuequivariance_attention=False, use_cuequivariance_multiplicative_update=False, use_lma=False, inplace_safe=False, _mask_trans=True, _attn_chunk_size=None)
Parameters:
  • z (Tensor)

  • mask (Tensor)

  • chunk_size (int | None)

  • use_deepspeed_evo_attention (bool)

  • use_cuequivariance_attention (bool)

  • use_cuequivariance_multiplicative_update (bool)

  • use_lma (bool)

  • inplace_safe (bool)

  • _mask_trans (bool)

  • _attn_chunk_size (int | None)

tri_att_start_end(single, _attn_chunk_size, single_mask, use_deepspeed_evo_attention, use_cuequivariance_attention, use_lma, inplace_safe)
Parameters:
  • single (Tensor)

  • _attn_chunk_size (int | None)

  • single_mask (Tensor)

  • use_deepspeed_evo_attention (bool)

  • use_cuequivariance_attention (bool)

  • use_lma (bool)

  • inplace_safe (bool)

tri_mul_out_in(single, single_mask, use_cuequivariance_multiplicative_update, inplace_safe)
Parameters:
  • single (Tensor)

  • single_mask (Tensor)

  • use_cuequivariance_multiplicative_update (bool)

  • inplace_safe (bool)
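For reference, a single block can also be applied on its own; this sketch reuses the illustrative channel sizes from the stack example above and a batch with one template:

```python
import torch
from openfold.model.template import TemplatePairStackBlock

# Illustrative hyperparameters; real values come from the model config.
block = TemplatePairStackBlock(
    c_t=64,
    c_hidden_tri_att=16,
    c_hidden_tri_mul=64,
    no_heads=4,
    pair_transition_n=2,
    dropout_rate=0.25,
    tri_mul_first=False,
    fuse_projection_weights=False,
    inf=1e9,
)

z = torch.randn(1, 1, 32, 32, 64)  # [*, N_templ=1, N_res, N_res, C_t]
mask = torch.ones(1, 1, 32, 32)    # [*, N_templ=1, N_res, N_res]

# Same shape out; the full TemplatePairStack applies no_blocks of these in sequence.
z = block(z, mask, chunk_size=None)
```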

class TemplatePointwiseAttention(c_t, c_z, c_hidden, no_heads, inf, **kwargs)

Bases: Module

Implements Algorithm 17.

forward(t, z, template_mask=None, chunk_size=256, use_lma=False)
Parameters:
  • t (Tensor) – [*, N_templ, N_res, N_res, C_t] template embedding

  • z (Tensor) – [*, N_res, N_res, C_z] pair embedding

  • template_mask (Tensor | None) – [*, N_templ] template mask

  • chunk_size (int | None)

  • use_lma (bool)

Returns:

[*, N_res, N_res, C_z] pair embedding update

Return type:

Tensor
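A minimal usage sketch with illustrative channel sizes; the returned update is typically added to the pair embedding:

```python
import torch
from openfold.model.template import TemplatePointwiseAttention

# Illustrative hyperparameters; real values come from the model config.
tpa = TemplatePointwiseAttention(
    c_t=64,       # template embedding channels (C_t)
    c_z=128,      # pair embedding channels (C_z)
    c_hidden=16,  # per-head hidden dimension
    no_heads=4,
    inf=1e9,
)

n_templ, n_res = 4, 32
t = torch.randn(1, n_templ, n_res, n_res, 64)  # [*, N_templ, N_res, N_res, C_t]
z = torch.randn(1, n_res, n_res, 128)          # [*, N_res, N_res, C_z]
template_mask = torch.ones(1, n_templ)         # [*, N_templ]

# [*, N_res, N_res, C_z] pointwise-attention update over the templates.
update = tpa(t, z, template_mask=template_mask)
z = z + update
```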

embed_templates_average(model, batch, z, pair_mask, templ_dim, templ_group_size=2, inplace_safe=False)

A memory-efficient approximation of the “embed_templates” method of the AlphaFold class. Instead of running pointwise attention over pair embeddings for all of the templates at the same time, it splits templates into groups of size templ_group_size, computes embeddings for each group normally, and then averages the group embeddings. In our experiments, this approximation has a minimal effect on the quality of the resulting embedding, while its low memory footprint allows the number of templates to scale almost indefinitely.

Parameters:
  • model – An AlphaFold model object

  • batch – An AlphaFold input batch. See documentation of AlphaFold.

  • z – A [*, N, N, C_z] pair embedding

  • pair_mask – A [*, N, N] pair mask

  • templ_dim – The template dimension of the template tensors in batch

  • templ_group_size – Granularity of the approximation. Larger groups use more memory but produce results closer to those of the original function

Returns:

A dictionary of template pair and angle embeddings.

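The grouping-and-averaging idea described above can be sketched independently of the model. In the snippet below, embed_group is a hypothetical stand-in for the normal per-group template embedding; only the splitting and averaging logic is illustrated:

```python
import torch

def averaged_group_embedding(t, templ_dim, templ_group_size, embed_group):
    """Embed templates in groups of templ_group_size and average the group embeddings."""
    n_templ = t.shape[templ_dim]
    total, n_groups = None, 0
    for start in range(0, n_templ, templ_group_size):
        idx = torch.arange(start, min(start + templ_group_size, n_templ))
        group = torch.index_select(t, templ_dim, idx)  # a small slice of templates
        emb = embed_group(group)                       # embed this group normally
        total = emb if total is None else total + emb
        n_groups += 1
    return total / n_groups                            # average of the group embeddings
```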

embed_templates_offload(model, batch, z, pair_mask, templ_dim, template_chunk_size=256, inplace_safe=False)

A version of the “embed_templates” method of the AlphaFold class that offloads the large template pair tensor to CPU. Slower but more frugal with GPU memory than the original. Useful for long-sequence inference.

Parameters:
  • model – An AlphaFold model object

  • batch – An AlphaFold input batch. See documentation of AlphaFold.

  • z – A [*, N, N, C_z] pair embedding

  • pair_mask – A [*, N, N] pair mask

  • templ_dim – The template dimension of the template tensors in batch

  • template_chunk_size – Integer controlling how large a slice of the offloaded template pair tensor is brought back into GPU memory at a time. When memory is extremely tight, it can be lowered to reduce this function's memory consumption even further.

Returns:

A dictionary of template pair and angle embeddings.

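The underlying offload pattern can be sketched as follows; attend_chunk is a hypothetical callable standing in for the pointwise-attention step, and only one slice of the offloaded tensor occupies GPU memory at a time:

```python
import torch

def attend_with_offloaded_templates(t_cpu, z, attend_chunk, template_chunk_size=256):
    """t_cpu: [*, N_templ, N, N, C_t] template pair tensor kept on CPU.
    z: [*, N, N, C_z] pair embedding on the GPU.
    attend_chunk: hypothetical callable applying pointwise attention to one slice.
    """
    n_res = z.shape[-3]
    out_chunks = []
    for start in range(0, n_res, template_chunk_size):
        end = start + template_chunk_size
        t_chunk = t_cpu[..., start:end, :, :].to(z.device)  # bring one slice back to GPU
        out_chunks.append(attend_chunk(t_chunk, z[..., start:end, :, :]))
        del t_chunk                                          # free it before the next slice
    return torch.cat(out_chunks, dim=-3)
```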