openfold.utils.loss

Classes

AlphaFoldLoss(config)

Aggregation of the various losses described in the supplement

Functions

backbone_loss(backbone_rigid_tensor, ...[, ...])

between_residue_bond_loss(...[, ...])

Flat-bottom loss to penalize structural violations between residues.

between_residue_clash_loss(...[, asym_id, ...])

Loss to penalize steric clashes between residues.

chain_center_of_mass_loss(all_atom_pred_pos, ...)

Computes chain centre-of-mass loss.

compute_fape(pred_frames, target_frames, ...)

Computes FAPE loss.

compute_plddt(logits)

compute_predicted_aligned_error(logits[, ...])

Computes aligned confidence metrics from logits.

compute_renamed_ground_truth(batch, ...[, eps])

Find optimal renaming of ground truth based on the predicted positions.

compute_tm(logits[, residue_weights, ...])

compute_violation_metrics(batch, ...)

Compute several metrics to assess the structural violations.

compute_violation_metrics_np(batch, ...)

distogram_loss(logits, pseudo_beta, ...[, ...])

experimentally_resolved_loss(logits, ...[, eps])

extreme_ca_ca_distance_violations(...[, ...])

Counts residues whose Ca is a large distance from its neighbour.

fape_loss(out, batch, config)

find_structural_violations(batch, ...)

Computes several checks for structural violations.

find_structural_violations_np(batch, ...)

lddt(all_atom_pred_pos, all_atom_positions, ...)

lddt_ca(all_atom_pred_pos, ...[, cutoff, ...])

lddt_loss(logits, all_atom_pred_pos, ...[, ...])

masked_msa_loss(logits, true_msa, bert_mask, ...)

Computes BERT-style masked MSA loss.

sidechain_loss(sidechain_frames, ...[, ...])

sigmoid_cross_entropy(logits, labels)

softmax_cross_entropy(logits, labels)

supervised_chi_loss(angles_sin_cos, ...[, eps])

Implements Algorithm 27 (torsionAngleLoss)

tm_loss(logits, final_affine_tensor, ...[, ...])

torsion_angle_loss(a, a_gt, a_alt_gt)

violation_loss(violations, atom14_atom_exists)

within_residue_violations(...[, ...])

Loss to penalize steric clashes within residues.

class AlphaFoldLoss(config)

Bases: Module

Aggregation of the various losses described in the supplement

forward(out, batch, _return_breakdown=False)
loss(out, batch, _return_breakdown=False)

The previous forward() was renamed to loss() so that it can be reused in subclasses.

backbone_loss(backbone_rigid_tensor, backbone_rigid_mask, traj, pair_mask=None, use_clamped_fape=None, clamp_distance=10.0, loss_unit_distance=10.0, eps=0.0001, **kwargs)
Return type:

Tensor

between_residue_bond_loss(pred_atom_positions, pred_atom_mask, residue_index, aatype, tolerance_factor_soft=12.0, tolerance_factor_hard=12.0, eps=1e-06)

Flat-bottom loss to penalize structural violations between residues.

This loss penalizes any violation of the geometry around the peptide bond between consecutive amino acids. It corresponds to Jumper et al. (2021), Suppl. Sec. 1.9.11, eqs. 44 and 45.

Parameters:
  • pred_atom_positions (Tensor) – Atom positions in atom37/14 representation

  • pred_atom_mask (Tensor) – Atom mask in atom37/14 representation

  • residue_index (Tensor) – Residue index for each amino acid; assumed to be monotonically increasing.

  • aatype (Tensor) – Amino acid type of each residue

  • tolerance_factor_soft – Soft tolerance factor, measured in standard deviations of PDB distributions

  • tolerance_factor_hard – Hard tolerance factor, measured in standard deviations of PDB distributions

Returns:

Dict containing:

  • 'c_n_loss_mean': Loss for peptide bond length violations

  • 'ca_c_n_loss_mean': Loss for violations of the bond angle around C spanned by CA, C, N

  • 'c_n_ca_loss_mean': Loss for violations of the bond angle around N spanned by C, N, CA

  • 'per_residue_loss_sum': Sum of all losses for each residue

  • 'per_residue_violation_mask': Mask denoting all residues with a violation present
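The flat-bottom idea can be illustrated on the C-N bond-length term alone. The NumPy sketch below is not OpenFold's implementation: it assumes per-residue C and N coordinates instead of atom37/14 tensors, omits the two angle terms and the per-atom masks, and the function name `flat_bottom_c_n_loss` is hypothetical. The reference length (1.329 Å) and standard deviation (0.014 Å) in the example call are illustrative values for a non-proline peptide bond.

```python
import numpy as np

def flat_bottom_c_n_loss(c_pos, n_pos, residue_index, gt_length, gt_stddev,
                         tolerance_factor_soft=12.0, tolerance_factor_hard=12.0,
                         eps=1e-6):
    """Hypothetical flat-bottom C-N bond-length loss.

    c_pos, n_pos: [N, 3] backbone C and N coordinates per residue.
    residue_index: [N] monotonically increasing residue indices.
    """
    # Only bonds between residues i and i+1 with no numbering gap are scored.
    no_gap = (residue_index[1:] - residue_index[:-1] == 1).astype(float)
    # The C of residue i is bonded to the N of residue i+1.
    length = np.sqrt(eps + np.sum((c_pos[:-1] - n_pos[1:]) ** 2, axis=-1))
    error = np.sqrt(eps + (length - gt_length) ** 2)
    # Zero loss inside the tolerance band ("flat bottom"), linear outside it.
    per_bond = np.maximum(error - tolerance_factor_soft * gt_stddev, 0.0)
    mean_loss = np.sum(no_gap * per_bond) / (np.sum(no_gap) + eps)
    violation = no_gap * (error > tolerance_factor_hard * gt_stddev)
    return mean_loss, violation

# Ideal bond: zero loss and no violation.
c = np.array([[0.0, 0.0, 0.0], [5.0, 0.0, 0.0]])
n = np.array([[-1.0, 0.0, 0.0], [1.329, 0.0, 0.0]])
loss, viol = flat_bottom_c_n_loss(c, n, np.array([0, 1]), 1.329, 0.014)
```

Because the loss is zero within `tolerance_factor_soft` standard deviations, small deviations from ideal geometry are not penalized at all; only clearly implausible bonds receive a gradient.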

between_residue_clash_loss(atom14_pred_positions, atom14_atom_exists, atom14_atom_radius, residue_index, asym_id=None, overlap_tolerance_soft=1.5, overlap_tolerance_hard=1.5, eps=1e-10)

Loss to penalize steric clashes between residues.

This loss penalizes steric clashes caused by non-bonded atoms in different residues coming too close together. It corresponds to the inter-residue part of Jumper et al. (2021), Suppl. Sec. 1.9.11, eq. 46.

Parameters:
  • atom14_pred_positions (Tensor) – Predicted positions of atoms in global prediction frame

  • atom14_atom_exists (Tensor) – Mask denoting whether the atom at each position exists for the given amino acid type

  • atom14_atom_radius (Tensor) – Van der Waals radius for each atom.

  • residue_index (Tensor) – Residue index for given amino acid.

  • overlap_tolerance_soft – Soft tolerance factor.

  • overlap_tolerance_hard – Hard tolerance factor.

  • asym_id (Tensor | None)

Returns:

Dict containing:

  • 'mean_loss': Average clash loss

  • 'per_atom_loss_sum': Sum of all clash losses per atom, shape (N, 14)

  • 'per_atom_clash_mask': Mask denoting whether each atom clashes with any other atom, shape (N, 14)
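The core of the inter-residue clash term is a one-sided penalty on atom pairs closer than the sum of their van der Waals radii minus a tolerance. The NumPy sketch below shows only that core on a flat list of atoms. It is a simplification, not OpenFold's code: it ignores the atom14 layout and the exclusions for bonded C-N pairs and disulfide bridges, and `inter_residue_clash` is a hypothetical name.

```python
import numpy as np

def inter_residue_clash(positions, radius, residue_of_atom,
                        overlap_tolerance_soft=1.5, overlap_tolerance_hard=1.5,
                        eps=1e-10):
    """Hypothetical clash penalty between atoms of different residues.

    positions: [A, 3] atom coordinates; radius: [A] van der Waals radii;
    residue_of_atom: [A] residue index each atom belongs to.
    """
    diff = positions[:, None, :] - positions[None, :, :]
    dists = np.sqrt(eps + np.sum(diff ** 2, axis=-1))
    # Count each unordered pair once, and only across different residues.
    upper = np.triu(np.ones_like(dists), k=1)
    different = (residue_of_atom[:, None] != residue_of_atom[None, :]).astype(float)
    pair_mask = upper * different
    lower_bound = radius[:, None] + radius[None, :]
    # Linear penalty once the overlap exceeds the soft tolerance.
    error = pair_mask * np.maximum(lower_bound - overlap_tolerance_soft - dists, 0.0)
    mean_loss = np.sum(error) / (1e-6 + np.sum(pair_mask))
    # Each pair contributes its penalty to both of its atoms.
    per_atom_loss_sum = np.sum(error, axis=0) + np.sum(error, axis=1)
    clash = pair_mask * (dists < lower_bound - overlap_tolerance_hard)
    per_atom_clash_mask = np.maximum(np.max(clash, axis=0), np.max(clash, axis=1))
    return {"mean_loss": mean_loss,
            "per_atom_loss_sum": per_atom_loss_sum,
            "per_atom_clash_mask": per_atom_clash_mask}
```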

chain_center_of_mass_loss(all_atom_pred_pos, all_atom_positions, all_atom_mask, asym_id, clamp_distance=-4.0, weight=0.05, eps=1e-10, **kwargs)

Computes chain centre-of-mass loss. Implements section 2.5, eqn 1 in the Multimer paper.

Parameters:
  • all_atom_pred_pos (Tensor) – [*, N_pts, 37, 3] All-atom predicted atom positions

  • all_atom_positions (Tensor) – [*, N_pts, 37, 3] Ground truth all-atom positions

  • all_atom_mask (Tensor) – [*, N_pts, 37] All-atom positions mask

  • asym_id (Tensor) – [*, N_pts] Chain asym IDs

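  • clamp_distance (float) – Offset applied to the difference between predicted and true chain-centre distances; only chain pairs whose predicted separation falls short of the true separation by more than -clamp_distance (4 Å with the default) are penalized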

  • weight (float) – Weight for loss

  • eps (float) – Small value used to regularize denominators

Returns:

[*] loss tensor

Return type:

Tensor
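A NumPy sketch of a one-sided chain centre-of-mass loss is shown below. It assumes CA coordinates per residue have already been extracted and that the penalty is `min(0, weight * (d_pred - d_true - clamp_distance))**2`, averaged over pairs of chains that exist. That form is an assumption about the implementation, and the function name is hypothetical. With the default `clamp_distance=-4.0`, chains are penalized only when they are predicted more than 4 Å closer together than in the ground truth.

```python
import numpy as np

def chain_centre_of_mass_loss(pred_ca, true_ca, ca_mask, asym_id,
                              clamp_distance=-4.0, weight=0.05, eps=1e-10):
    """pred_ca, true_ca: [N, 3] CA coordinates; ca_mask: [N]; asym_id: [N] chain ids."""
    chains = np.unique(asym_id)
    # [C, N] membership of each resolved residue in each chain.
    member = (asym_id[None, :] == chains[:, None]).astype(float) * ca_mask[None, :]
    counts = member.sum(axis=-1)
    chain_exists = (counts > 0).astype(float)

    def centres(pos):
        return member @ pos / (counts[:, None] + eps)  # [C, 3]

    def pairwise(c):
        return np.sqrt(np.sum((c[:, None, :] - c[None, :, :]) ** 2, axis=-1) + eps)

    pred_d = pairwise(centres(pred_ca))
    true_d = pairwise(centres(true_ca))
    # One-sided: chains pushed further apart than the truth are not penalized.
    losses = np.minimum(weight * (pred_d - true_d - clamp_distance), 0.0) ** 2
    pair_mask = chain_exists[:, None] * chain_exists[None, :]
    return np.sum(losses * pair_mask) / (np.sum(pair_mask) + eps)
```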

compute_fape(pred_frames, target_frames, frames_mask, pred_positions, target_positions, positions_mask, length_scale, pair_mask=None, l1_clamp_distance=None, eps=1e-08)

Computes FAPE loss.

Parameters:
  • pred_frames (Rigid) – [*, N_frames] Rigid object of predicted frames

  • target_frames (Rigid) – [*, N_frames] Rigid object of ground truth frames

  • frames_mask (Tensor) – [*, N_frames] binary mask for the frames

  • pred_positions (Tensor) – [*, N_pts, 3] predicted atom positions

  • target_positions (Tensor) – [*, N_pts, 3] ground truth positions

  • positions_mask (Tensor) – [*, N_pts] positions mask

  • length_scale (float) – Length scale by which the loss is divided

  • pair_mask (Tensor | None) – [*, N_frames, N_pts] mask to use for separating intra- from inter-chain losses.

  • l1_clamp_distance (float | None) – Cutoff above which distance errors are disregarded

  • eps – Small value used to regularize denominators

Returns:

[*] loss tensor

Return type:

Tensor
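To make the parameters above concrete, here is a NumPy sketch of FAPE. Frames are represented by plain rotation matrices and translations instead of OpenFold's `Rigid` objects (an assumption for self-containment), and `pair_mask` is omitted. Every point is expressed in the local coordinates of every frame, in both the predicted and true structures. The per-pair L2 error is clamped at `l1_clamp_distance`, divided by `length_scale`, and averaged over points and then frames.

```python
import numpy as np

def fape(pred_rot, pred_trans, true_rot, true_trans, pred_pos, true_pos,
         frames_mask, positions_mask, length_scale=10.0,
         l1_clamp_distance=10.0, eps=1e-8):
    """Frame Aligned Point Error (sketch).

    *_rot: [F, 3, 3] rotations; *_trans: [F, 3] translations, with frames
    mapping local to global as x_global = R @ x_local + t. *_pos: [P, 3].
    """
    def to_local(rot, trans, pos):
        # Inverse frame applied to every (frame, point) pair: R^T (x - t).
        diff = pos[None, :, :] - trans[:, None, :]        # [F, P, 3]
        return np.einsum("fij,fpi->fpj", rot, diff)       # [F, P, 3]

    local_pred = to_local(pred_rot, pred_trans, pred_pos)
    local_true = to_local(true_rot, true_trans, true_pos)
    error = np.sqrt(np.sum((local_pred - local_true) ** 2, axis=-1) + eps)
    if l1_clamp_distance is not None:
        error = np.clip(error, 0.0, l1_clamp_distance)
    error = error / length_scale
    error = error * frames_mask[:, None] * positions_mask[None, :]
    # Masked mean over points, then over frames.
    per_frame = np.sum(error, axis=-1) / (eps + np.sum(positions_mask))
    return np.sum(per_frame) / (eps + np.sum(frames_mask))
```

Because errors are measured in local frames, applying the same global rotation and translation to the predicted frames and positions leaves the loss unchanged. This invariance is what lets FAPE compare structures without a separate superposition step.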

compute_plddt(logits)
Parameters:

logits (Tensor)

Return type:

Tensor
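compute_plddt has no description here. A common way to turn binned lDDT logits into a per-residue confidence is to take the expected value over bin centres. The NumPy sketch below assumes the bins evenly partition [0, 1] and that the result is reported on a 0 to 100 scale. The function name is hypothetical.

```python
import numpy as np

def plddt_from_logits(logits):
    """logits: [..., num_bins] -> expected lDDT on a 0-100 scale."""
    num_bins = logits.shape[-1]
    bin_width = 1.0 / num_bins
    # Centres of equal-width bins covering [0, 1].
    bin_centres = (np.arange(num_bins) + 0.5) * bin_width
    shifted = logits - logits.max(axis=-1, keepdims=True)  # numerically stable softmax
    probs = np.exp(shifted) / np.sum(np.exp(shifted), axis=-1, keepdims=True)
    return 100.0 * np.sum(probs * bin_centres, axis=-1)
```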

compute_predicted_aligned_error(logits, max_bin=31, no_bins=64, **kwargs)

Computes aligned confidence metrics from logits.

Parameters:
  • logits (Tensor) – [*, num_res, num_res, num_bins] the logits output from PredictedAlignedErrorHead.

  • max_bin (int) – Maximum bin value

  • no_bins (int) – Number of bins

Returns:

  • aligned_confidence_probs: [*, num_res, num_res, num_bins] the predicted aligned error probabilities over bins for each residue pair.

  • predicted_aligned_error: [*, num_res, num_res] the expected aligned distance error for each pair of residues.

  • max_predicted_aligned_error: [*] the maximum predicted error possible.
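The three outputs above can be related through a short NumPy sketch. It assumes `no_bins - 1` evenly spaced bin edges from 0 to `max_bin`, with each bin represented by its midpoint and one extra open-ended bin beyond `max_bin`. Under that assumption, the largest representable error is the last bin centre (31.75 Å for the defaults). The function name is hypothetical.

```python
import numpy as np

def predicted_aligned_error(logits, max_bin=31, no_bins=64):
    """logits: [..., num_res, num_res, no_bins]."""
    # no_bins - 1 bin edges from 0 to max_bin; the final bin is open-ended.
    boundaries = np.linspace(0.0, max_bin, no_bins - 1)
    step = boundaries[1] - boundaries[0]
    bin_centres = np.append(boundaries + step / 2, boundaries[-1] + 1.5 * step)
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted) / np.sum(np.exp(shifted), axis=-1, keepdims=True)
    return {
        "aligned_confidence_probs": probs,
        # Expected error per residue pair under the predicted distribution.
        "predicted_aligned_error": np.sum(probs * bin_centres, axis=-1),
        "max_predicted_aligned_error": bin_centres[-1],
    }
```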

compute_renamed_ground_truth(batch, atom14_pred_positions, eps=1e-10)

Find optimal renaming of ground truth based on the predicted positions.

Alg. 26 “renameSymmetricGroundTruthAtoms”

This renamed ground truth is then used for all losses, such that each loss moves the atoms in the same direction.

Parameters:
  • batch (Dict[str, Tensor]) –

    Dictionary containing:

    • atom14_gt_positions: Ground truth positions.

    • atom14_alt_gt_positions: Ground truth positions with renaming swaps.

    • atom14_atom_is_ambiguous: 1.0 for atoms that are affected by renaming swaps.

    • atom14_gt_exists: Mask for which atoms exist in ground truth.

    • atom14_alt_gt_exists: Mask for which atoms exist in ground truth after renaming.

    • atom14_atom_exists: Mask for whether each atom is part of the given amino acid type.

  • atom14_pred_positions (Tensor) – Predicted atom positions in the global frame, in atom14 representation

Returns:

Dictionary containing:

  • alt_naming_is_better: Array with 1.0 where the alternative swap is better.

  • renamed_atom14_gt_positions: Array of optimal ground truth positions after renaming swaps are performed.

  • renamed_atom14_gt_exists: Mask after the renaming swap is performed.
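The renaming decision can be sketched per residue. For each residue, compare how well the predicted inter-atomic distances match the ground truth under each naming, using only pairs made of one ambiguous atom of that residue and one non-ambiguous atom anywhere. Then keep whichever naming fits better. The smooth per-pair error `sqrt((d_pred - d_gt)**2 + eps)` and the function name are assumptions of this NumPy sketch. It works for any number of atoms `A` per residue (14 in atom14).

```python
import numpy as np

def rename_symmetric_ground_truth(pred, gt, alt_gt, gt_exists, alt_gt_exists,
                                  is_ambiguous, eps=1e-10):
    """pred, gt, alt_gt: [N, A, 3]; gt_exists, alt_gt_exists, is_ambiguous: [N, A]."""
    def pairwise(x):
        d = x[:, :, None, None, :] - x[None, None, :, :, :]  # [N, A, N, A, 3]
        return np.sqrt(eps + np.sum(d ** 2, axis=-1))

    pred_d = pairwise(pred)
    gt_err = np.sqrt(eps + (pred_d - pairwise(gt)) ** 2)
    alt_err = np.sqrt(eps + (pred_d - pairwise(alt_gt)) ** 2)
    # Pairs: ambiguous atoms of residue i x non-ambiguous atoms of any residue.
    mask = ((gt_exists * is_ambiguous)[:, :, None, None]
            * (gt_exists * (1.0 - is_ambiguous))[None, None, :, :])
    per_res = np.sum(mask * gt_err, axis=(1, 2, 3))
    alt_per_res = np.sum(mask * alt_err, axis=(1, 2, 3))
    alt_better = (alt_per_res < per_res).astype(float)  # [N]
    w = alt_better[:, None]
    return {
        "alt_naming_is_better": alt_better,
        "renamed_atom14_gt_positions": (1 - w[..., None]) * gt + w[..., None] * alt_gt,
        "renamed_atom14_gt_exists": (1 - w) * gt_exists + w * alt_gt_exists,
    }
```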

compute_tm(logits, residue_weights=None, asym_id=None, interface=False, max_bin=31, no_bins=64, eps=1e-08, **kwargs)
Return type:

Tensor

compute_violation_metrics(batch, atom14_pred_positions, violations)

Compute several metrics to assess the structural violations.

Return type:

Dict[str, Tensor]

compute_violation_metrics_np(batch, atom14_pred_positions, violations)
Return type:

Dict[str, ndarray]

distogram_loss(logits, pseudo_beta, pseudo_beta_mask, min_bin=2.3125, max_bin=21.6875, no_bins=64, eps=1e-06, **kwargs)
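distogram_loss has no description here. The NumPy sketch below shows one standard way to build such a loss under stated assumptions: `no_bins - 1` evenly spaced edges between `min_bin` and `max_bin` (so the first and last bins are open-ended), the true bin taken from pseudo-β distances, and softmax cross-entropy averaged over valid residue pairs. The function name `distogram_loss_sketch` is hypothetical.

```python
import numpy as np

def distogram_loss_sketch(logits, pseudo_beta, pseudo_beta_mask,
                          min_bin=2.3125, max_bin=21.6875, no_bins=64, eps=1e-6):
    """logits: [N, N, no_bins]; pseudo_beta: [N, 3]; pseudo_beta_mask: [N]."""
    boundaries = np.linspace(min_bin, max_bin, no_bins - 1)
    diff = pseudo_beta[:, None, :] - pseudo_beta[None, :, :]
    dists = np.sqrt(np.sum(diff ** 2, axis=-1))
    # True bin = number of edges the distance exceeds.
    true_bins = np.sum(dists[..., None] > boundaries, axis=-1)  # [N, N]
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))
    errors = -np.take_along_axis(log_probs, true_bins[..., None], axis=-1)[..., 0]
    pair_mask = pseudo_beta_mask[:, None] * pseudo_beta_mask[None, :]
    return np.sum(errors * pair_mask) / (eps + np.sum(pair_mask))
```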
experimentally_resolved_loss(logits, atom37_atom_exists, all_atom_mask, resolution, min_resolution, max_resolution, eps=1e-08, **kwargs)
Return type:

Tensor

extreme_ca_ca_distance_violations(pred_atom_positions, pred_atom_mask, residue_index, max_angstrom_tolerance=1.5, eps=1e-06)

Counts residues whose Ca is a large distance from its neighbour.

Measures the fraction of CA-CA pairs between consecutive amino acids whose distance exceeds the ideal CA-CA distance by more than max_angstrom_tolerance.

Parameters:
  • pred_atom_positions (Tensor) – Atom positions in atom37/14 representation

  • pred_atom_mask (Tensor) – Atom mask in atom37/14 representation

  • residue_index (Tensor) – Residue index for each amino acid; assumed to be monotonically increasing.

  • max_angstrom_tolerance – Maximum deviation (in Å) above the ideal CA-CA distance that is not counted as a violation.

Returns:

Fraction of consecutive CA-CA pairs with violation.

Return type:

Tensor
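A NumPy sketch of this metric follows. The ideal CA-CA distance is not given in this entry, so the sketch assumes 3.80 Å (the usual value for a trans peptide) and stores it in the hypothetical constant `IDEAL_CA_CA`. Pairs that span a gap in `residue_index` or involve a missing CA are excluded.

```python
import numpy as np

IDEAL_CA_CA = 3.80  # assumed ideal trans CA-CA distance in Angstrom (illustrative)

def extreme_ca_ca_fraction(ca_pos, ca_mask, residue_index,
                           max_angstrom_tolerance=1.5, eps=1e-6):
    """ca_pos: [N, 3]; ca_mask, residue_index: [N]."""
    dist = np.sqrt(eps + np.sum((ca_pos[1:] - ca_pos[:-1]) ** 2, axis=-1))
    violations = (dist - IDEAL_CA_CA) > max_angstrom_tolerance
    # Only pairs where both CAs exist and residues are sequence neighbours.
    mask = ca_mask[1:] * ca_mask[:-1] * (residue_index[1:] - residue_index[:-1] == 1)
    return np.sum(mask * violations) / (np.sum(mask) + eps)
```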

fape_loss(out, batch, config)
Return type:

Tensor

find_structural_violations(batch, atom14_pred_positions, violation_tolerance_factor, clash_overlap_tolerance, **kwargs)

Computes several checks for structural violations.

Return type:

Dict[str, Tensor]

find_structural_violations_np(batch, atom14_pred_positions, config)
Return type:

Dict[str, ndarray]

lddt(all_atom_pred_pos, all_atom_positions, all_atom_mask, cutoff=15.0, eps=1e-10, per_residue=True)
Return type:

Tensor

lddt_ca(all_atom_pred_pos, all_atom_positions, all_atom_mask, cutoff=15.0, eps=1e-10, per_residue=True)
Return type:

Tensor
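lddt and lddt_ca have no descriptions here. The NumPy sketch below computes a per-residue lDDT score on one atom per residue (as an lDDT-CA variant would on CA atoms). It assumes the standard thresholds of 0.5, 1, 2 and 4 Å and scores only pairs that lie within `cutoff` in the true structure. The function name is hypothetical, and the sketch returns scores in [0, 1].

```python
import numpy as np

def lddt_per_residue(pred, true, mask, cutoff=15.0, eps=1e-10):
    """pred, true: [N, 3] coordinates (e.g. CA); mask: [N]. Returns [N] scores in [0, 1]."""
    def pairwise(x):
        return np.sqrt(eps + np.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1))

    true_d, pred_d = pairwise(true), pairwise(pred)
    n = len(mask)
    # Score pairs within the cutoff in the *true* structure, excluding self-pairs.
    to_score = (true_d < cutoff) * mask[:, None] * mask[None, :] * (1.0 - np.eye(n))
    l1 = np.abs(true_d - pred_d)
    # Fraction of the four thresholds that each pair's distance error satisfies.
    score = 0.25 * ((l1 < 0.5).astype(float) + (l1 < 1.0) + (l1 < 2.0) + (l1 < 4.0))
    norm = 1.0 / (eps + np.sum(to_score, axis=-1))
    return norm * (eps + np.sum(to_score * score, axis=-1))
```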

lddt_loss(logits, all_atom_pred_pos, all_atom_positions, all_atom_mask, resolution, cutoff=15.0, no_bins=50, min_resolution=0.1, max_resolution=3.0, eps=1e-10, **kwargs)
Return type:

Tensor

masked_msa_loss(logits, true_msa, bert_mask, num_classes, eps=1e-08, **kwargs)

Computes BERT-style masked MSA loss. Implements subsection 1.9.9.

Parameters:
  • logits – [*, N_seq, N_res, 23] predicted residue distribution

  • true_msa – [*, N_seq, N_res] true MSA

  • bert_mask – [*, N_seq, N_res] MSA mask

Returns:

Masked MSA loss
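In a BERT-style objective, the cross-entropy is averaged only over the MSA positions that were masked out. The NumPy sketch below assumes `true_msa` holds integer class indices below `num_classes` and that `bert_mask` is 1.0 at masked positions. The function name `masked_msa_loss_sketch` is hypothetical.

```python
import numpy as np

def masked_msa_loss_sketch(logits, true_msa, bert_mask, num_classes=23, eps=1e-8):
    """logits: [N_seq, N_res, num_classes]; true_msa, bert_mask: [N_seq, N_res]."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))
    one_hot = np.eye(num_classes)[true_msa]
    errors = -np.sum(one_hot * log_probs, axis=-1)  # [N_seq, N_res]
    # Average only over masked positions; unmasked ones contribute nothing.
    return np.sum(errors * bert_mask) / (eps + np.sum(bert_mask))
```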

sidechain_loss(sidechain_frames, sidechain_atom_pos, rigidgroups_gt_frames, rigidgroups_alt_gt_frames, rigidgroups_gt_exists, renamed_atom14_gt_positions, renamed_atom14_gt_exists, alt_naming_is_better, clamp_distance=10.0, length_scale=10.0, eps=0.0001, **kwargs)
Return type:

Tensor

sigmoid_cross_entropy(logits, labels)
softmax_cross_entropy(logits, labels)
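The two cross-entropy helpers above have no descriptions. The NumPy functions below restate the standard formulas these names usually denote. They are an assumption about what the OpenFold helpers compute, not a copy of them: softmax cross-entropy against (possibly soft) labels over the last axis, and elementwise binary cross-entropy from logits, both computed in a numerically stable way.

```python
import numpy as np

def softmax_cross_entropy(logits, labels):
    """Cross-entropy over the last axis with one-hot or soft labels."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))
    return -np.sum(labels * log_probs, axis=-1)

def sigmoid_cross_entropy(logits, labels):
    """Elementwise binary cross-entropy computed from logits."""
    # log(sigmoid(x)) = -log(1 + exp(-x)), evaluated stably with logaddexp.
    log_p = -np.logaddexp(0.0, -logits)
    log_not_p = -np.logaddexp(0.0, logits)
    return -labels * log_p - (1.0 - labels) * log_not_p
```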
supervised_chi_loss(angles_sin_cos, unnormalized_angles_sin_cos, aatype, seq_mask, chi_mask, chi_angles_sin_cos, chi_weight, angle_norm_weight, eps=1e-06, **kwargs)

Implements Algorithm 27 (torsionAngleLoss)

Parameters:
  • angles_sin_cos (Tensor) – [*, N, 7, 2] predicted angles

  • unnormalized_angles_sin_cos (Tensor) – The same angles, but unnormalized

  • aatype (Tensor) – [*, N] residue type indices

  • seq_mask (Tensor) – [*, N] sequence mask

  • chi_mask (Tensor) – [*, N, 7] angle mask

  • chi_angles_sin_cos (Tensor) – [*, N, 7, 2] ground truth angles

  • chi_weight (float) – Weight for the angle component of the loss

  • angle_norm_weight (float) – Weight for the normalization component of the loss

Returns:

[*] loss tensor

Return type:

Tensor
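Algorithm 27 combines two terms. The first is a squared error on (sin, cos) pairs that, for π-periodic angles, also accepts the angle rotated by π. The second is a penalty keeping the unnormalized outputs near unit length. The NumPy sketch below assumes a per-angle `pi_periodic` mask is passed in directly; this parameter is hypothetical and stands in for a lookup by amino acid type. The function name is also hypothetical.

```python
import numpy as np

def torsion_angle_loss_sketch(pred, unnormalized, true, chi_mask, pi_periodic,
                              seq_mask, chi_weight, angle_norm_weight, eps=1e-6):
    """pred, unnormalized, true: [N, K, 2] (sin, cos); chi_mask, pi_periodic: [N, K]; seq_mask: [N]."""
    # Rotating an angle by pi negates both its sine and its cosine.
    true_alt = (1.0 - 2.0 * pi_periodic)[..., None] * true
    sq_err = np.sum((true - pred) ** 2, axis=-1)
    sq_err_alt = np.sum((true_alt - pred) ** 2, axis=-1)
    sq_err = np.minimum(sq_err, sq_err_alt)
    torsion = np.sum(chi_mask * sq_err) / (eps + np.sum(chi_mask))
    # Keep the raw (sin, cos) predictions close to unit length.
    norm = np.sqrt(np.sum(unnormalized ** 2, axis=-1) + eps)
    norm_mask = np.broadcast_to(seq_mask[:, None], norm.shape)
    angle_norm = np.sum(norm_mask * np.abs(norm - 1.0)) / (eps + np.sum(norm_mask))
    return chi_weight * torsion + angle_norm_weight * angle_norm
```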

tm_loss(logits, final_affine_tensor, backbone_rigid_tensor, backbone_rigid_mask, resolution, max_bin=31, no_bins=64, min_resolution=0.1, max_resolution=3.0, eps=1e-08, **kwargs)
torsion_angle_loss(a, a_gt, a_alt_gt)
violation_loss(violations, atom14_atom_exists, average_clashes=False, eps=1e-06, **kwargs)
Return type:

Tensor

within_residue_violations(atom14_pred_positions, atom14_atom_exists, atom14_dists_lower_bound, atom14_dists_upper_bound, tighten_bounds_for_loss=0.0, eps=1e-10)

Loss to penalize steric clashes within residues.

This loss penalizes steric violations or clashes of non-bonded atoms within a single residue. It corresponds to the intra-residue part of Jumper et al. (2021), Suppl. Sec. 1.9.11, eq. 46.

Parameters:
  • atom14_pred_positions ([*, N, 14, 3]) – Predicted positions of atoms in global prediction frame.

  • atom14_atom_exists ([*, N, 14]) – Mask denoting whether the atom at each position exists for the given amino acid type

  • atom14_dists_lower_bound ([*, N, 14, 14]) – Lower bound on allowed pairwise distances.

  • atom14_dists_upper_bound ([*, N, 14, 14]) – Upper bound on allowed pairwise distances.

  • tighten_bounds_for_loss (float) – Margin by which both bounds are tightened when computing the loss

Returns:

Dict containing:

  • 'per_atom_loss_sum' ([*, N, 14]): Sum of all clash losses per atom

  • 'per_atom_clash_mask' ([*, N, 14]): Mask denoting whether each atom clashes with any other atom
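The intra-residue term can be sketched as a two-sided flat-bottom penalty on every pair of existing atoms within a residue. Nothing is added inside [lower, upper]; outside it, the penalty grows linearly. The NumPy sketch below assumes pairwise bounds of shape [N, A, A] (A atoms per residue, 14 in atom14) and a scalar tightening margin. The function name is hypothetical.

```python
import numpy as np

def intra_residue_bound_violations(positions, exists, lower_bound, upper_bound,
                                   tighten_bounds_for_loss=0.0, eps=1e-10):
    """positions: [N, A, 3]; exists: [N, A]; lower/upper_bound: [N, A, A]."""
    n_atoms = positions.shape[1]
    # Exclude self-pairs and atoms that do not exist.
    pair_mask = (1.0 - np.eye(n_atoms))[None] * exists[:, :, None] * exists[:, None, :]
    diff = positions[:, :, None, :] - positions[:, None, :, :]
    dists = np.sqrt(eps + np.sum(diff ** 2, axis=-1))
    too_close = np.maximum(lower_bound + tighten_bounds_for_loss - dists, 0.0)
    too_far = np.maximum(dists - (upper_bound - tighten_bounds_for_loss), 0.0)
    loss = pair_mask * (too_close + too_far)
    violations = pair_mask * ((dists < lower_bound) | (dists > upper_bound))
    return {
        # Each ordered pair is counted for both of its atoms.
        "per_atom_loss_sum": np.sum(loss, axis=-2) + np.sum(loss, axis=-1),
        "per_atom_clash_mask": np.maximum(violations.max(axis=-2), violations.max(axis=-1)),
    }
```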