openfold.model.embedders

Classes

ExtraMSAEmbedder(c_in, c_out, **kwargs)

Embeds unclustered MSA sequences.

InputEmbedder(tf_dim, msa_dim, c_z, c_m, ...)

Embeds a subset of the input features.

InputEmbedderMultimer(tf_dim, msa_dim, c_z, ...)

Embeds a subset of the input features.

PreembeddingEmbedder(tf_dim, ...)

Embeds the sequence pre-embedding passed to the model and the target_feat features.

RecyclingEmbedder(c_m, c_z, min_bin, ...[, inf])

Embeds the output of an iteration of the model for recycling.

TemplateEmbedder(config)

TemplateEmbedderMultimer(config)

TemplatePairEmbedder(c_in, c_out, **kwargs)

Embeds "template_pair_feat" features.

TemplatePairEmbedderMultimer(c_in, c_out, ...)

TemplateSingleEmbedder(c_in, c_out, **kwargs)

Embeds the "template_angle_feat" feature.

TemplateSingleEmbedderMultimer(c_in, c_out)

class ExtraMSAEmbedder(c_in, c_out, **kwargs)

Bases: Module

Embeds unclustered MSA sequences.

Implements Algorithm 2, line 15.

Parameters:
  • c_in (int) – input channel dimension, C_in

  • c_out (int) – output channel dimension, C_out

forward(x)
Parameters:

x (Tensor) – [*, N_extra_seq, N_res, C_in] “extra_msa_feat” features

Returns:

[*, N_extra_seq, N_res, C_out] embedding

Return type:

Tensor
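A minimal NumPy sketch of the shape contract, assuming the module reduces to one learned linear projection over the channel axis, as in Algorithm 2, line 15 of the AlphaFold 2 supplement. The helper `linear`, the random weights, and the sizes are hypothetical stand-ins for the trained PyTorch layer, not part of OpenFold's API.

```python
import numpy as np

def linear(x, w, b):
    # y = x @ w + b over the last (channel) axis; leading axes pass through.
    return x @ w + b

rng = np.random.default_rng(0)
n_extra_seq, n_res, c_in, c_out = 4, 10, 25, 64  # illustrative sizes only

# Hypothetical random stand-ins for the learned projection parameters.
w = rng.standard_normal((c_in, c_out))
b = np.zeros(c_out)

extra_msa_feat = rng.standard_normal((n_extra_seq, n_res, c_in))
emb = linear(extra_msa_feat, w, b)
print(emb.shape)  # (4, 10, 64)
```

Because the projection acts only on the last axis, the same weights apply independently to every (sequence, residue) position.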

class InputEmbedder(tf_dim, msa_dim, c_z, c_m, relpos_k, **kwargs)

Bases: Module

Embeds a subset of the input features.

Implements Algorithms 3 (InputEmbedder) and 4 (relpos).

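Parameters:
  • tf_dim (int) – final dimension of the “target_feat” features

  • msa_dim (int) – final dimension of the “msa_feat” features

  • c_z (int) – pair embedding dimension, C_z

  • c_m (int) – MSA embedding dimension, C_m

  • relpos_k (int) – window size of the relative positional encoding
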
forward(tf, ri, msa, inplace_safe=False)
Parameters:
  • tf (Tensor) – “target_feat” features of shape [*, N_res, tf_dim]

  • ri (Tensor) – “residue_index” features of shape [*, N_res]

  • msa (Tensor) – “msa_feat” features of shape [*, N_clust, N_res, msa_dim]

  • inplace_safe (bool)

Returns:

  • msa_emb – [*, N_clust, N_res, C_m] MSA embedding

  • pair_emb – [*, N_res, N_res, C_z] pair embedding

Return type:

Tuple[Tensor, Tensor]
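Algorithm 3 builds the pair embedding as an outer sum of two projections of the target features and adds a projection of the target features to every MSA row. The sketch below illustrates that broadcasting pattern in NumPy under stated assumptions: the weight names are hypothetical, biases are omitted, batch dimensions are dropped, and the relative-position term is passed in precomputed (zeros in the usage here) rather than computed by relpos.

```python
import numpy as np

rng = np.random.default_rng(0)
n_clust, n_res = 3, 8
tf_dim, msa_dim, c_z, c_m = 22, 49, 16, 32  # illustrative sizes only

# Hypothetical random stand-ins for the four learned projections (biases omitted).
w_tf_z_i = rng.standard_normal((tf_dim, c_z))
w_tf_z_j = rng.standard_normal((tf_dim, c_z))
w_tf_m = rng.standard_normal((tf_dim, c_m))
w_msa_m = rng.standard_normal((msa_dim, c_m))

def input_embed(tf, msa, relpos_term):
    # tf: [N_res, tf_dim], msa: [N_clust, N_res, msa_dim],
    # relpos_term: [N_res, N_res, c_z] relative positional encoding.
    a = tf @ w_tf_z_i                         # [N_res, c_z]
    b = tf @ w_tf_z_j                         # [N_res, c_z]
    pair_emb = a[:, None, :] + b[None, :, :]  # outer sum -> [N_res, N_res, c_z]
    pair_emb = pair_emb + relpos_term
    # The target-feature projection is broadcast across every MSA row.
    msa_emb = msa @ w_msa_m + (tf @ w_tf_m)[None, :, :]  # [N_clust, N_res, c_m]
    return msa_emb, pair_emb

tf = rng.standard_normal((n_res, tf_dim))
msa = rng.standard_normal((n_clust, n_res, msa_dim))
msa_emb, pair_emb = input_embed(tf, msa, np.zeros((n_res, n_res, c_z)))
print(msa_emb.shape, pair_emb.shape)  # (3, 8, 32) (8, 8, 16)
```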

relpos(ri)

Computes relative positional encodings.

Implements Algorithm 4.

Parameters:

ri (Tensor) – “residue_index” features of shape [*, N]
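Algorithm 4 one-hot encodes the residue-index offset r_i - r_j after snapping it to the nearest of the bins -k, ..., k; for integer indices this is the same as clipping the offset to [-k, k]. A minimal NumPy sketch, assuming relpos_k plays the role of k; the helper name `relpos_one_hot` is hypothetical, and the learned linear layer that maps the 2k + 1 channels to C_z is left out.

```python
import numpy as np

def relpos_one_hot(ri, relpos_k):
    # ri: [N] integer residue indices.
    # Returns [N, N, 2 * relpos_k + 1] one-hot encodings of r_i - r_j.
    d = ri[:, None] - ri[None, :]
    # For integer offsets, snapping to the nearest bin in -k..k equals clipping.
    idx = np.clip(d, -relpos_k, relpos_k) + relpos_k
    return np.eye(2 * relpos_k + 1)[idx]

ri = np.array([0, 1, 2, 50])
p = relpos_one_hot(ri, relpos_k=32)
print(p.shape)           # (4, 4, 65)
print(p[0, 1].argmax())  # 31 (offset -1)
print(p[3, 0].argmax())  # 64 (offset +50, clipped to +32)
```

Clipping means every residue farther than k positions away shares a single boundary bin, so the encoding only resolves local sequence separation.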

class InputEmbedderMultimer(tf_dim, msa_dim, c_z, c_m, max_relative_idx, use_chain_relative, max_relative_chain, **kwargs)

Bases: Module

Embeds a subset of the input features.

Implements Algorithms 3 (InputEmbedder) and 4 (relpos).

Parameters:
  • tf_dim (int)

  • msa_dim (int)

  • c_z (int)

  • c_m (int)

  • max_relative_idx (int)

  • use_chain_relative (bool)

  • max_relative_chain (int)

forward(batch)
Return type:

Tuple[Tensor, Tensor]

relpos(batch)
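The multimer relpos(batch) also takes chain identity into account. The sketch below follows the relative encoding described for AlphaFold-Multimer and is offered only as an illustration of how max_relative_idx, use_chain_relative and max_relative_chain could shape the features. The batch keys residue_index, asym_id, entity_id and sym_id, the helper name `multimer_relpos_features`, and the feature layout are assumptions, not read from this module. As in the single-chain case, a learned linear layer would then project the features to C_z.

```python
import numpy as np

def multimer_relpos_features(residue_index, asym_id, entity_id, sym_id,
                             max_relative_idx, max_relative_chain,
                             use_chain_relative=True):
    # All inputs are [N_res] integer arrays; returns [N_res, N_res, F] features.
    offset = residue_index[:, None] - residue_index[None, :]
    clipped = np.clip(offset + max_relative_idx, 0, 2 * max_relative_idx)
    if not use_chain_relative:
        return np.eye(2 * max_relative_idx + 1)[clipped]

    # Pairs on different chains share one extra "no sequence offset" class.
    same_chain = asym_id[:, None] == asym_id[None, :]
    final_offset = np.where(same_chain, clipped, 2 * max_relative_idx + 1)
    rel_pos = np.eye(2 * max_relative_idx + 2)[final_offset]

    # Binary flag: do residues i and j belong to the same entity?
    same_entity = entity_id[:, None] == entity_id[None, :]
    same_entity_feat = same_entity.astype(rel_pos.dtype)[..., None]

    # Clipped relative copy index between chains of the same entity.
    rel_sym = sym_id[:, None] - sym_id[None, :]
    clipped_chain = np.clip(rel_sym + max_relative_chain, 0, 2 * max_relative_chain)
    final_chain = np.where(same_entity, clipped_chain, 2 * max_relative_chain + 1)
    rel_chain = np.eye(2 * max_relative_chain + 2)[final_chain]

    return np.concatenate([rel_pos, same_entity_feat, rel_chain], axis=-1)

residue_index = np.array([0, 1, 2, 0, 1])
asym_id = np.array([1, 1, 1, 2, 2])    # two chains
entity_id = np.array([1, 1, 1, 1, 1])  # homodimer: both chains share one entity
sym_id = np.array([1, 1, 1, 2, 2])     # copy index within the entity
feats = multimer_relpos_features(residue_index, asym_id, entity_id, sym_id,
                                 max_relative_idx=2, max_relative_chain=1)
print(feats.shape)  # (5, 5, 11)
```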

class PreembeddingEmbedder(tf_dim, preembedding_dim, c_z, c_m, relpos_k, **kwargs)

Bases: Module

Embeds the sequence pre-embedding passed to the model and the target_feat features.

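Parameters:
  • tf_dim (int) – final dimension of the “target_feat” features

  • preembedding_dim (int) – final dimension of the sequence pre-embedding

  • c_z (int) – pair embedding dimension, C_z

  • c_m (int) – MSA embedding dimension, C_m

  • relpos_k (int) – window size of the relative positional encoding

forward(tf, ri, preemb, inplace_safe=False)
Parameters:
  • tf (Tensor) – “target_feat” features of shape [*, N_res, tf_dim]

  • ri (Tensor) – “residue_index” features of shape [*, N_res]

  • preemb (Tensor) – sequence pre-embedding, with final dimension preembedding_dim

  • inplace_safe (bool)
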
Return type:

Tuple[Tensor, Tensor]

relpos(ri)

Computes relative positional encodings.

Parameters:

ri (Tensor) – “residue_index” feature of shape [*, N]

Returns:

Relative positional encoding of the protein, computed from the residue_index feature

class RecyclingEmbedder(c_m, c_z, min_bin, max_bin, no_bins, inf=100000000.0, **kwargs)

Bases: Module

Embeds the output of an iteration of the model for recycling.

Implements Algorithm 32.

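Parameters:
  • c_m (int) – MSA embedding dimension, C_m

  • c_z (int) – pair embedding dimension, C_z

  • min_bin (float) – smallest distance bin edge

  • max_bin (float) – largest distance bin edge

  • no_bins (int) – number of distance bins

  • inf (float) – large value used as the upper bound of the final, open-ended bin
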
forward(m, z, x, inplace_safe=False)
Parameters:
  • m (Tensor) – [*, N_res, C_m] first row of the MSA embedding

  • z (Tensor) – [*, N_res, N_res, C_z] pair embedding

  • x (Tensor) – [*, N_res, 3] predicted C_beta coordinates

  • inplace_safe (bool)

Returns:

  • m – [*, N_res, C_m] MSA embedding update

  • z – [*, N_res, N_res, C_z] pair embedding update

Return type:

Tuple[Tensor, Tensor]
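Algorithm 32 layer-normalizes the recycled MSA row and pair embedding and adds an embedding of binned C_beta distances to the pair update. A NumPy sketch under stated assumptions: the learned LayerNorm scale and offset are dropped, `w_dgram` is a hypothetical stand-in for the distogram projection, batch dimensions are omitted, and the bin edges are placed with np.linspace(min_bin, max_bin, no_bins) and compared on squared distances, with inf closing the last bin. The bin settings in the usage are illustrative only.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize over the channel axis; the learned scale and offset are omitted.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def distogram_one_hot(x, min_bin, max_bin, no_bins, inf=1e8):
    # x: [N_res, 3] C_beta coordinates -> [N_res, N_res, no_bins] one-hot bins.
    lower = np.linspace(min_bin, max_bin, no_bins) ** 2  # squared lower edges
    upper = np.concatenate([lower[1:], [inf]])           # inf closes the last bin
    d2 = np.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1, keepdims=True)
    # Distances at or below min_bin fall in no bin and get an all-zero row.
    return ((d2 > lower) & (d2 < upper)).astype(x.dtype)

def recycle(m, z, x, w_dgram, min_bin, max_bin, no_bins):
    # m: [N_res, C_m] first MSA row; z: [N_res, N_res, C_z] pair embedding;
    # w_dgram: [no_bins, C_z] hypothetical stand-in for the learned projection.
    m_update = layer_norm(m)
    z_update = layer_norm(z) + distogram_one_hot(x, min_bin, max_bin, no_bins) @ w_dgram
    return m_update, z_update

rng = np.random.default_rng(0)
x = np.array([[0.0, 0.0, 0.0], [5.0, 0.0, 0.0], [0.0, 30.0, 0.0]])
dgram = distogram_one_hot(x, min_bin=3.25, max_bin=20.75, no_bins=15)
print(dgram[0, 1].argmax())  # 1  (5 lies between the edges 4.5 and 5.75)
print(dgram[0, 2].argmax())  # 14 (30 lands in the open-ended last bin)
m_upd, z_upd = recycle(rng.standard_normal((3, 8)), rng.standard_normal((3, 3, 4)),
                       x, rng.standard_normal((15, 4)), 3.25, 20.75, 15)
print(m_upd.shape, z_upd.shape)  # (3, 8) (3, 3, 4)
```

Comparing squared distances against squared edges avoids a square root per residue pair while giving the same bin assignment.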

class TemplateEmbedder(config)

Bases: Module

forward(batch, z, pair_mask, templ_dim, chunk_size, _mask_trans=True, use_deepspeed_evo_attention=False, use_cuequivariance_attention=False, use_cuequivariance_multiplicative_update=False, use_lma=False, inplace_safe=False)
Parameters:
  • use_cuequivariance_attention (bool)

  • use_cuequivariance_multiplicative_update (bool)

class TemplateEmbedderMultimer(config)

Bases: Module

forward(batch, z, padding_mask_2d, templ_dim, chunk_size, multichain_mask_2d, _mask_trans=True, use_deepspeed_evo_attention=False, use_cuequivariance_attention=False, use_cuequivariance_multiplicative_update=False, use_lma=False, inplace_safe=False)
Parameters:
  • use_cuequivariance_attention (bool)

  • use_cuequivariance_multiplicative_update (bool)

class TemplatePairEmbedder(c_in, c_out, **kwargs)

Bases: Module

Embeds “template_pair_feat” features.

Implements Algorithm 2, line 9.

Parameters:
  • c_in (int) – input channel dimension, C_in

  • c_out (int) – output channel dimension, C_out

forward(x)
Parameters:

x (Tensor) – [*, C_in] input tensor

Returns:

[*, C_out] output tensor

Return type:

Tensor

class TemplatePairEmbedderMultimer(c_in, c_out, c_dgram, c_aatype)

Bases: Module

forward(template_dgram, aatype_one_hot, query_embedding, pseudo_beta_mask, backbone_mask, multichain_mask_2d, unit_vector)
Return type:

Tensor

class TemplateSingleEmbedder(c_in, c_out, **kwargs)

Bases: Module

Embeds the “template_angle_feat” feature.

Implements Algorithm 2, line 7.

Parameters:
  • c_in (int) – input channel dimension, C_in

  • c_out (int) – output channel dimension, C_out

forward(x)
Parameters:

x (Tensor) – [*, N_templ, N_res, C_in] “template_angle_feat” features

Returns:

[*, N_templ, N_res, C_out] embedding

Return type:

Tensor
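Algorithm 2, line 7 embeds the template angle features with two linear layers separated by a ReLU. A minimal NumPy sketch; the weight arguments, the hidden width of C_out, and the sizes are assumptions for illustration, not OpenFold's parameters.

```python
import numpy as np

def template_single_embed(x, w1, b1, w2, b2):
    # x: [N_templ, N_res, C_in] "template_angle_feat" -> [N_templ, N_res, C_out]
    h = np.maximum(x @ w1 + b1, 0.0)  # first linear layer, then ReLU
    return h @ w2 + b2                 # second linear layer

rng = np.random.default_rng(0)
n_templ, n_res, c_in, c_out = 2, 6, 57, 64  # illustrative sizes only
w1, b1 = rng.standard_normal((c_in, c_out)), np.zeros(c_out)
w2, b2 = rng.standard_normal((c_out, c_out)), np.zeros(c_out)
out = template_single_embed(rng.standard_normal((n_templ, n_res, c_in)), w1, b1, w2, b2)
print(out.shape)  # (2, 6, 64)
```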

class TemplateSingleEmbedderMultimer(c_in, c_out)

Bases: Module

forward(batch, atom_pos, aatype_one_hot)