openfold.model.pair_transition

Classes

PairTransition(c_z, n)

Implements Algorithm 15.

class PairTransition(c_z, n)

Bases: torch.nn.Module

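Implements Algorithm 15 (the transition layer in the pair stack) from the AlphaFold 2 supplementary information. The layer applies LayerNorm, then a position-wise two-layer MLP: it expands the channel dimension from C_z to n * C_z, applies ReLU, and projects back to C_z.

Parameters:
  • c_z (int) – pair embedding channel dimension C_z

  • n (int) – factor by which c_z is multiplied to obtain the hidden channel dimension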
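The NumPy sketch below illustrates roughly what Algorithm 15 computes. It applies LayerNorm over the channel axis, a linear expansion from C_z to n * C_z channels, ReLU, and a linear projection back to C_z. It then zeroes the update wherever the pair mask is 0. This is not OpenFold's implementation. The following are all assumptions made for illustration:

- the function names `layer_norm` and `pair_transition`;
- the explicit weight and bias arguments;
- the omission of LayerNorm's learned scale and offset;
- multiplying the output by the mask.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    # Normalize over the channel (last) axis. ASSUMPTION: the learned
    # scale and offset of a real LayerNorm are omitted.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def pair_transition(z, w1, b1, w2, b2, mask=None):
    """Hypothetical sketch of Algorithm 15 on a [*, N_res, N_res, C_z] tensor.

    w1: [C_z, n * C_z] and w2: [n * C_z, C_z] are illustrative weights.
    """
    if mask is None:
        # Treat every residue pair as valid.
        mask = np.ones(z.shape[:-1], dtype=z.dtype)
    z = layer_norm(z)                         # [*, N_res, N_res, C_z]
    a = np.maximum(z @ w1 + b1, 0.0)          # relu(Linear(z)): [*, N_res, N_res, n * C_z]
    update = a @ w2 + b2                      # Linear back to [*, N_res, N_res, C_z]
    # ASSUMPTION: the mask zeroes the update for invalid pairs.
    return update * mask[..., None]


rng = np.random.default_rng(0)
c_z, n, n_res = 8, 4, 5
z = rng.standard_normal((n_res, n_res, c_z))
w1 = rng.standard_normal((c_z, n * c_z)) * 0.1
b1 = np.zeros(n * c_z)
w2 = rng.standard_normal((n * c_z, c_z)) * 0.1
b2 = np.zeros(c_z)

# Mark every pair involving the last residue as padding.
mask = np.ones((n_res, n_res))
mask[-1, :] = 0.0
mask[:, -1] = 0.0

update = pair_transition(z, w1, b1, w2, b2, mask)
print(update.shape)  # (5, 5, 8)
```

Every (i, j) pair vector is transformed independently, so masked pairs do not influence any other entry; they simply receive a zero update.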

forward(z, mask=None, chunk_size=None)
Parameters:
  • z (Tensor) – [*, N_res, N_res, C_z] pair embedding

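  • mask (Tensor | None) – [*, N_res, N_res] pair mask; the output update is multiplied by it. If None, every pair is treated as valid.

  • chunk_size (int | None) – if set, the transition is computed in chunks of this size over the leading dimensions to lower peak memory, without changing the result. If None, the whole tensor is processed at once.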

Returns:

[*, N_res, N_res, C_z] pair embedding update (the residual addition to z is left to the caller)

Return type:

Tensor
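Because the transition acts on each pair vector independently, it can be evaluated one slice at a time without changing the result. This is what makes a `chunk_size` option possible. The sketch below assumes that chunking splits the flattened leading dimensions ([*, N_res]) into slices of `chunk_size` rows. Under that assumption, only a [chunk_size, N_res, n * C_z] hidden activation exists at any one time. The helpers `transition` and `chunked_transition` are hypothetical, and LayerNorm and biases are omitted for brevity.

```python
import numpy as np


def transition(z, w1, w2, mask):
    # Position-wise MLP applied to every [C_z] pair vector independently.
    # ASSUMPTION: LayerNorm and biases are omitted to keep the sketch short.
    hidden = np.maximum(z @ w1, 0.0)          # [..., n * C_z]
    return (hidden @ w2) * mask[..., None]    # [..., C_z]


def chunked_transition(z, w1, w2, mask, chunk_size):
    """Hypothetical chunked evaluation over the leading [*, N_res] dimensions."""
    lead = z.shape[:-2]                                  # [*, N_res]
    z_flat = z.reshape((-1,) + z.shape[-2:])             # [rows, N_res, C_z]
    m_flat = mask.reshape((-1,) + mask.shape[-1:])       # [rows, N_res]
    pieces = [
        # Only a [chunk_size, N_res, n * C_z] hidden tensor exists per step.
        transition(z_flat[i:i + chunk_size], w1, w2, m_flat[i:i + chunk_size])
        for i in range(0, z_flat.shape[0], chunk_size)
    ]
    return np.concatenate(pieces, axis=0).reshape(lead + z.shape[-2:])


rng = np.random.default_rng(0)
c_z, n, n_res = 8, 4, 6
z = rng.standard_normal((2, n_res, n_res, c_z))     # batch of 2 pair embeddings
w1 = rng.standard_normal((c_z, n * c_z))
w2 = rng.standard_normal((n * c_z, c_z))
mask = np.ones((2, n_res, n_res))

full = transition(z, w1, w2, mask)
chunked = chunked_transition(z, w1, w2, mask, chunk_size=5)  # 12 rows -> 5, 5, 2
print(np.allclose(full, chunked))  # True
```

The trade-off is smaller peak memory in exchange for a loop over chunks. Smaller chunks save more memory but issue more, smaller operations.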