openfold.utils.checkpointing

Functions

checkpoint_blocks(blocks, args, blocks_per_ckpt)

Chunk a list of blocks and run each chunk with activation checkpointing.

get_checkpoint_fn()

checkpoint_blocks(blocks, args, blocks_per_ckpt)

Chunk a list of blocks and run each chunk with activation checkpointing. We define a “block” as a callable whose only inputs are the outputs of the previous block.

Implements Subsection 1.11.8 of the AlphaFold 2 supplementary information.
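The "block" convention above means each block's return value becomes the positional arguments of the next block. The minimal sketch below illustrates that chaining in plain Python, without any checkpointing. `run_blocks`, `add_one` and `double` are hypothetical names, not part of openfold. Treating a single non-tuple return value as a one-element tuple is an assumption of this sketch.

```python
from typing import Any, Callable, List, Tuple


def run_blocks(blocks: List[Callable[..., Any]], args: Tuple[Any, ...]) -> Tuple[Any, ...]:
    """Run blocks in order; each block receives the previous block's outputs."""
    for block in blocks:
        out = block(*args)
        # Assumption of this sketch: a single non-tuple return value is
        # treated as a one-element tuple so it can be unpacked into the next block.
        args = out if isinstance(out, tuple) else (out,)
    return args


# Hypothetical blocks that pass a pair of values (m, z) along the chain.
def add_one(m, z):
    return m + 1, z


def double(m, z):
    return m * 2, z * 2


print(run_blocks([add_one, double], (1, 10)))  # (4, 20)
```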

Parameters:
  • blocks (List[Callable]) – List of blocks to run in order.

  • args (List[Any]) – Tuple of positional arguments for the first block.

  • blocks_per_ckpt (int | None) – Number of consecutive blocks in each checkpointed chunk. A higher value produces fewer, larger chunks (fewer checkpoints), trading memory for speed. If None, no checkpointing is performed and the blocks are simply run in sequence.

Returns:

The output of the final block.

Return type:

List[Any]
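The chunking described above can be sketched as follows. This is a hypothetical re-implementation for illustration, not openfold's code: `checkpoint_blocks_sketch`, `_run`, `_as_tuple` and `recording_checkpoint` are invented names. The `checkpoint_fn` parameter is an assumption that stands in for a real activation-checkpoint wrapper with the calling convention `checkpoint_fn(function, *args)`, such as `torch.utils.checkpoint.checkpoint`. Its default simply calls the function, so the sketch runs without PyTorch. Unlike the documented return type, this sketch returns a tuple.

```python
from typing import Any, Callable, List, Optional, Tuple


def _as_tuple(x: Any) -> Tuple[Any, ...]:
    return x if isinstance(x, tuple) else (x,)


def _run(chunk: List[Callable[..., Any]], args: Tuple[Any, ...]) -> Tuple[Any, ...]:
    # Thread outputs through the blocks of one chunk.
    for block in chunk:
        args = _as_tuple(block(*args))
    return args


def checkpoint_blocks_sketch(
    blocks: List[Callable[..., Any]],
    args: Any,
    blocks_per_ckpt: Optional[int],
    checkpoint_fn: Callable[..., Any] = lambda fn, *a: fn(*a),
) -> Tuple[Any, ...]:
    """Hypothetical sketch; checkpoint_fn stands in for a real checkpoint wrapper."""
    args = _as_tuple(args)
    if blocks_per_ckpt is None:
        # No checkpointing: run every block directly.
        return _run(blocks, args)
    if blocks_per_ckpt < 1:
        raise ValueError("blocks_per_ckpt must be at least 1")
    for start in range(0, len(blocks), blocks_per_ckpt):
        chunk = blocks[start:start + blocks_per_ckpt]
        # One checkpoint call per chunk. With a real wrapper, only the chunk's
        # inputs are kept and its inner activations are recomputed in backward.
        # Binding `chunk` as a default avoids late binding if the wrapper
        # calls the function again later, as a real checkpoint does.
        args = _as_tuple(checkpoint_fn(lambda *a, c=chunk: _run(c, a), *args))
    return args


# Usage: a stub wrapper that records the inputs of each checkpointed chunk.
recorded = []


def recording_checkpoint(fn, *a):
    recorded.append(a)
    return fn(*a)


inc = [lambda x: x + 1] * 5
print(checkpoint_blocks_sketch(inc, 0, 2, recording_checkpoint))  # (5,)
print(recorded)  # [(0,), (2,), (4,)]
```

With blocks_per_ckpt=2, the five blocks form three chunks (blocks 0-1, 2-3 and 4), so the wrapper is called three times. With blocks_per_ckpt=5 it would be called once, which is the "fewer checkpoints" end of the trade-off.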

get_checkpoint_fn()