openfold.model.triangular_multiplicative_update

Classes

BaseTriangleMultiplicativeUpdate(c_z, ...)

Implements Algorithms 11 and 12.

FusedTriangleMultiplicationIncoming(c_z, ...)

Implements Algorithm 12.

FusedTriangleMultiplicationOutgoing(c_z, ...)

Implements Algorithm 11.

FusedTriangleMultiplicativeUpdate(c_z, c_hidden)

Implements Algorithms 11 and 12.

TriangleMultiplicationIncoming(c_z, c_hidden, *)

Implements Algorithm 12.

TriangleMultiplicationOutgoing(c_z, c_hidden, *)

Implements Algorithm 11.

TriangleMultiplicativeUpdate(c_z, c_hidden)

Implements Algorithms 11 and 12.
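Algorithms 11 and 12 differ only in which two edges of each triangle (i, j, k) are combined. The following plain-Python sketch isolates that per-channel contraction, using nested lists in place of tensors. triangle_contract is a hypothetical helper and is not part of openfold.model.triangular_multiplicative_update. The layer norms, gated projections and output gate that surround the contraction are left out.

```python
from typing import List

# Nested lists stand in for tensors: t[i][j][c] is residue pair (i, j), channel c.
Tensor3 = List[List[List[float]]]


def triangle_contract(a: Tensor3, b: Tensor3, outgoing: bool) -> Tensor3:
    """Per-channel triangle contraction (hypothetical helper).

    outgoing (Algorithm 11): x[i][j][c] = sum_k a[i][k][c] * b[j][k][c]
    incoming (Algorithm 12): x[i][j][c] = sum_k a[k][i][c] * b[k][j][c]
    """
    n = len(a)
    c_hidden = len(a[0][0])
    x = [[[0.0] * c_hidden for _ in range(n)] for _ in range(n)]
    for i in range(n):
        for j in range(n):
            for c in range(c_hidden):
                if outgoing:
                    # edges i->k and j->k share the third node k
                    x[i][j][c] = sum(a[i][k][c] * b[j][k][c] for k in range(n))
                else:
                    # edges k->i and k->j share the third node k
                    x[i][j][c] = sum(a[k][i][c] * b[k][j][c] for k in range(n))
    return x


# Tiny example: N_res = 2, C_hidden = 1.
a = [[[1.0], [2.0]], [[3.0], [4.0]]]
print(triangle_contract(a, a, outgoing=True)[0][1])   # [11.0] = 1*3 + 2*4
print(triangle_contract(a, a, outgoing=False)[0][1])  # [14.0] = 1*2 + 3*4
```

The only difference between the two branches is the role of the shared index k. The outgoing update sums over residues that both i and j point to, and the incoming update sums over residues that point to both i and j.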

class BaseTriangleMultiplicativeUpdate(c_z, c_hidden, _outgoing)

Bases: Module, ABC

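Implements Algorithms 11 and 12 of the AlphaFold 2 supplementary information (triangular multiplicative update using "outgoing" and "incoming" edges, respectively). The _outgoing flag selects which of the two an instance computes.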

abstract forward(z, mask=None, inplace_safe=False, use_cuequivariance_multiplicative_update=False, _add_with_inplace=False)
Parameters:
  • z (Tensor) – [*, N_res, N_res, C_z] input tensor

  • mask (Tensor | None) – [*, N_res, N_res] input mask

  • inplace_safe (bool)

  • use_cuequivariance_multiplicative_update (bool)

  • _add_with_inplace (bool)

Returns:

[*, N_res, N_res, C_z] output tensor

Return type:

Tensor
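The mask argument has no channel dimension. The following minimal sketch shows one way such a mask can enter the computation. It assumes, without quoting OpenFold, that the mask is broadcast over channels and multiplied into the gated projections sigmoid(Linear(z)) * Linear(z) that produce a and b. linear, sigmoid and gated_masked_projection are hypothetical helpers, and biases and the input layer norm are omitted.

```python
import math
from typing import List

Vec = List[float]
Weight = List[List[float]]  # [C_in][C_out]; bias omitted for brevity


def linear(v: Vec, w: Weight) -> Vec:
    # y[o] = sum_i v[i] * w[i][o]
    return [sum(v[i] * w[i][o] for i in range(len(v))) for o in range(len(w[0]))]


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def gated_masked_projection(z, mask, w_p, w_g):
    """a[i][j] = mask[i][j] * sigmoid(z[i][j] W_g) * (z[i][j] W_p).

    z: [N_res][N_res][C_z]; mask: [N_res][N_res] of 0/1 values, broadcast over channels.
    """
    n = len(z)
    out = []
    for i in range(n):
        row = []
        for j in range(n):
            p = linear(z[i][j], w_p)  # projection
            g = linear(z[i][j], w_g)  # gate logits
            row.append([mask[i][j] * sigmoid(gc) * pc for gc, pc in zip(g, p)])
        out.append(row)
    return out


z = [[[2.0, 4.0], [1.0, 1.0]], [[0.5, 0.5], [3.0, 0.0]]]
mask = [[1, 0], [1, 1]]
w_p = [[1.0], [0.0]]  # C_z = 2 -> C_hidden = 1, keeps channel 0
w_g = [[0.0], [0.0]]  # zero gate logits, so sigmoid(0) = 0.5
a = gated_masked_projection(z, mask, w_p, w_g)
print(a[0][0])  # [1.0]: 0.5 * 2.0
print(a[0][1])  # [0.0]: pair (0, 1) is masked out
```

Under this assumption, both a and b are zero at masked pairs, so a masked pair (i, k) drops out of every triangle sum that would otherwise use it.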

class FusedTriangleMultiplicationIncoming(c_z, c_hidden, *, _outgoing=False)

Bases: FusedTriangleMultiplicativeUpdate

Implements Algorithm 12.

class FusedTriangleMultiplicationOutgoing(c_z, c_hidden, *, _outgoing=True)

Bases: FusedTriangleMultiplicativeUpdate

Implements Algorithm 11.

class FusedTriangleMultiplicativeUpdate(c_z, c_hidden, _outgoing=True)

Bases: BaseTriangleMultiplicativeUpdate

Implements Algorithms 11 and 12.

forward(z, mask=None, inplace_safe=False, use_cuequivariance_multiplicative_update=False, _add_with_inplace=False, _inplace_chunk_size=256)
Parameters:
  • z (Tensor) – [*, N_res, N_res, C_z] input tensor

  • mask (Tensor | None) – [*, N_res, N_res] input mask

  • inplace_safe (bool)

  • use_cuequivariance_multiplicative_update (bool)

  • _add_with_inplace (bool)

  • _inplace_chunk_size (int | None)

Returns:

[*, N_res, N_res, C_z] output tensor

Return type:

Tensor
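This reference does not say what the Fused variants fuse. One common technique is to compute the a and b projections with a single linear layer of width 2 * c_hidden and then split the output, which replaces two matrix multiplications with one larger one. That technique is assumed here for illustration only. The sketch checks that it produces the same numbers as two separate projections. All names are hypothetical.

```python
from typing import List, Tuple

Vec = List[float]
Weight = List[List[float]]  # [C_in][C_out]


def linear(v: Vec, w: Weight) -> Vec:
    # y[o] = sum_i v[i] * w[i][o]
    return [sum(v[i] * w[i][o] for i in range(len(v))) for o in range(len(w[0]))]


def fuse_weights(w_a: Weight, w_b: Weight) -> Weight:
    # Concatenate along the output dimension: two [C_z][C_hidden] -> one [C_z][2*C_hidden]
    return [row_a + row_b for row_a, row_b in zip(w_a, w_b)]


def fused_projection(v: Vec, w_ab: Weight) -> Tuple[Vec, Vec]:
    # One wider projection, then split the output channels in half
    out = linear(v, w_ab)
    half = len(out) // 2
    return out[:half], out[half:]


w_a = [[1.0, 2.0], [3.0, 4.0]]
w_b = [[5.0, 6.0], [7.0, 8.0]]
v = [1.0, 2.0]
print(fused_projection(v, fuse_weights(w_a, w_b)))  # ([7.0, 10.0], [19.0, 22.0])
print(linear(v, w_a), linear(v, w_b))               # [7.0, 10.0] [19.0, 22.0]
```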

class TriangleMultiplicationIncoming(c_z, c_hidden, *, _outgoing=False)

Bases: TriangleMultiplicativeUpdate

Implements Algorithm 12.

class TriangleMultiplicationOutgoing(c_z, c_hidden, *, _outgoing=True)

Bases: TriangleMultiplicativeUpdate

Implements Algorithm 11.
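TriangleMultiplicationIncoming and TriangleMultiplicationOutgoing differ only in the default value of _outgoing. The incoming contraction equals the outgoing contraction applied to inputs whose two residue dimensions have been swapped, which helps explain why a single implementation can serve both. The plain-Python check below illustrates that identity for the contraction step alone. The helper names are hypothetical, and the sketch does not claim that OpenFold implements the switch this way.

```python
from typing import List

Tensor3 = List[List[List[float]]]  # [N_res][N_res][C]


def outgoing(a: Tensor3, b: Tensor3) -> Tensor3:
    # Algorithm 11 contraction: x[i][j][c] = sum_k a[i][k][c] * b[j][k][c]
    n, ch = len(a), len(a[0][0])
    return [[[sum(a[i][k][c] * b[j][k][c] for k in range(n)) for c in range(ch)]
             for j in range(n)] for i in range(n)]


def incoming(a: Tensor3, b: Tensor3) -> Tensor3:
    # Algorithm 12 contraction: x[i][j][c] = sum_k a[k][i][c] * b[k][j][c]
    n, ch = len(a), len(a[0][0])
    return [[[sum(a[k][i][c] * b[k][j][c] for k in range(n)) for c in range(ch)]
             for j in range(n)] for i in range(n)]


def swap_residue_dims(t: Tensor3) -> Tensor3:
    # t[i][j] -> t[j][i]; channels are untouched
    n = len(t)
    return [[t[j][i] for j in range(n)] for i in range(n)]


n, ch = 3, 2
a = [[[float(i + 2 * j + c) for c in range(ch)] for j in range(n)] for i in range(n)]
b = [[[float(3 * i - j + c) for c in range(ch)] for j in range(n)] for i in range(n)]
print(incoming(a, b) == outgoing(swap_residue_dims(a), swap_residue_dims(b)))  # True
```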

class TriangleMultiplicativeUpdate(c_z, c_hidden, _outgoing=True)

Bases: BaseTriangleMultiplicativeUpdate

Implements Algorithms 11 and 12.

forward(z, mask=None, inplace_safe=False, use_cuequivariance_multiplicative_update=False, _add_with_inplace=False, _inplace_chunk_size=256)
Parameters:
  • z (Tensor) – [*, N_res, N_res, C_z] input tensor

  • mask (Tensor | None) – [*, N_res, N_res] input mask

  • inplace_safe (bool)

  • use_cuequivariance_multiplicative_update (bool)

  • _add_with_inplace (bool)

  • _inplace_chunk_size (int | None)

Returns:

[*, N_res, N_res, C_z] output tensor

Return type:

Tensor
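The parameter list does not say what _inplace_chunk_size controls. Its name suggests that it limits how much of the pair representation is processed at once when inplace_safe is set, but that reading is an assumption. The sketch below illustrates only the general idea and is not OpenFold's code. It splits the outgoing contraction into chunks of rows i and checks that the result does not change. On a GPU, chunking would lower peak memory; in pure Python it only demonstrates that the chunked result is identical.

```python
from typing import Iterable, List

Tensor3 = List[List[List[float]]]  # [N_res][N_res][C]


def outgoing_rows(a: Tensor3, b: Tensor3, rows: Iterable[int]) -> Tensor3:
    # Outgoing contraction for a subset of rows i only
    n, ch = len(a), len(a[0][0])
    return [[[sum(a[i][k][c] * b[j][k][c] for k in range(n)) for c in range(ch)]
             for j in range(n)] for i in rows]


def chunked_outgoing(a: Tensor3, b: Tensor3, chunk_size: int) -> Tensor3:
    # Process at most chunk_size rows at a time and stitch the pieces together
    n = len(a)
    x: Tensor3 = []
    for start in range(0, n, chunk_size):
        x.extend(outgoing_rows(a, b, range(start, min(start + chunk_size, n))))
    return x


n, ch = 5, 2
a = [[[float(i * j - c) for c in range(ch)] for j in range(n)] for i in range(n)]
b = [[[float(i + j + c) for c in range(ch)] for j in range(n)] for i in range(n)]
full = outgoing_rows(a, b, range(n))
print(all(chunked_outgoing(a, b, s) == full for s in (1, 2, 3, 5, 8)))  # True
```

Chunking along i is safe here because each output row depends on the whole of b but on only one row of a.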