openfold.utils.multi_chain_permutation
Functions
| calculate_input_mask | Calculate an input mask for downstream optimal transformation computation. |
| calculate_optimal_transform | Take the C-alpha positions of the selected ground-truth anchor and of the selected predicted anchor, then calculate the optimal rotation matrix that aligns the predicted anchor to the ground-truth anchor. |
| compute_permutation_alignment | Permute the chains in the ground truth before calculating the loss, because the mapping between predicted and ground-truth chains is arbitrary. |
| compute_rmsd | Calculate the RMSD between predicted and ground-truth atom positions. |
| get_entity_2_asym_list | Generate a dictionary mapping unique entity IDs to lists of unique asymmetry IDs (asym_id) for each entity. |
| get_least_asym_entity_or_longest_length | Select an anchor: check how many subunits each sequence has and pick the least common one, breaking ties by sequence length. |
| get_optimal_transform | Obtain the transformation that optimally aligns src_atoms with tgt_atoms. |
| get_per_asym_residue_index | Retrieve which residues belong to which asym_id. |
| greedy_align | Implement Algorithm 4 in the Supplementary Information of the AlphaFold-Multimer paper: Evans, R. et al. (2022). Protein complex prediction with AlphaFold-Multimer. bioRxiv 2021.10.04.463034; doi: https://doi.org/10.1101/2021.10.04.463034 |
| kabsch_rotation | Calculate the best rotation that minimises the RMSD between P and Q. |
| merge_labels | Merge ground-truth labels according to the permutation results. |
| multi_chain_permutation_align | Compute multi-chain permutation alignment. |
| pad_features | Pad input feature tensor. |
| split_ground_truth_labels | Split ground-truth features according to chains. |
- calculate_input_mask(true_ca_masks, anchor_gt_idx, anchor_gt_residue, asym_mask, pred_ca_mask)
Calculate an input mask for downstream optimal transformation computation.
- Parameters:
true_ca_masks (List[Tensor]) – a list of masks from the ground-truth chains.
anchor_gt_idx (Tensor) – a tensor containing a single integer: the index of the selected ground-truth anchor.
anchor_gt_residue (Tensor) – a 1D tensor of residue indices belonging to the selected ground-truth anchor.
asym_mask (Tensor) – a boolean tensor indicating which positions belong to the selected predicted anchor.
pred_ca_mask (Tensor) – the C-alpha mask from the predicted structure.
- Returns:
input_mask, a boolean mask
- Return type:
Tensor
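The NumPy sketch below shows one way such a mask could be formed. It assumes 1-D masks without a batch dimension and a plain integer anchor_gt_idx. A position is kept only where both the predicted anchor (selected by asym_mask) and the ground-truth anchor (sliced by anchor_gt_residue) have a resolved C-alpha atom. The name input_mask_sketch is hypothetical, and this is not OpenFold's code.

```python
import numpy as np


def input_mask_sketch(true_ca_masks, anchor_gt_idx, anchor_gt_residue, asym_mask, pred_ca_mask):
    """Hypothetical sketch: combine predicted and ground-truth anchor masks (1-D, no batch dim)."""
    # Residues of the predicted structure that belong to the predicted anchor chain.
    anchor_pred_mask = np.asarray(pred_ca_mask)[np.asarray(asym_mask, dtype=bool)]
    # The matching residues of the selected ground-truth chain.
    anchor_true_mask = np.asarray(true_ca_masks[anchor_gt_idx])[np.asarray(anchor_gt_residue)]
    # Keep a position only if the C-alpha atom is present in both structures.
    return np.logical_and(anchor_true_mask, anchor_pred_mask)


mask = input_mask_sketch(
    true_ca_masks=[np.array([1, 1, 0, 1]), np.array([1, 0, 1])],
    anchor_gt_idx=0,
    anchor_gt_residue=np.array([0, 1, 2]),
    asym_mask=np.array([True, True, True, False, False]),
    pred_ca_mask=np.array([1, 0, 1, 1, 1]),
)
print(mask)  # [ True False False]
```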
- calculate_optimal_transform(true_ca_poses, anchor_gt_idx, anchor_gt_residue, true_ca_masks, pred_ca_mask, asym_mask, pred_ca_pos)
Takes the C-alpha positions of the selected ground-truth anchor and of the selected predicted anchor, then calculates the optimal rotation matrix that aligns the predicted anchor to the ground-truth anchor.
- Parameters:
true_ca_poses (List[Tensor]) – a list of tensors corresponding to the C-alpha positions of the ground-truth structure, one per chain; e.g. if there are 5 chains, this list has a length of 5.
anchor_gt_idx (Tensor) – a tensor containing a single integer: the index of the selected ground-truth anchor.
anchor_gt_residue (Tensor) – a 1D tensor of residue indices belonging to the selected ground-truth anchor.
true_ca_masks (List[Tensor]) – a list of masks from the ground-truth chains; e.g. it has a length of 5 if there are 5 chains in the ground-truth structure.
pred_ca_mask (Tensor) – a boolean tensor used to mask the predicted features.
asym_mask (Tensor) – a boolean tensor that masks out elements that do not belong to this asym_id.
pred_ca_pos (Tensor) – a [nres, 3] tensor of predicted C-alpha atom positions.
Process:
1) Select an anchor chain from the ground truth, denoted by anchor_gt_idx, and an anchor chain from the predicted structure. Both anchor_gt and anchor_pred have exactly the same sequence.
2) Obtain the C-alpha positions of the selected anchor_gt by slicing true_ca_poses according to anchor_gt_idx and anchor_gt_residue.
3) Calculate the optimal transformation that best aligns the C-alpha atoms of anchor_pred to those of anchor_gt, using the Kabsch algorithm: source https://en.wikipedia.org/wiki/Kabsch_algorithm
- Returns:
r, a rotation matrix recording the optimal rotation that best aligns the selected predicted anchor to the selected ground-truth anchor, and a matrix recording how the atoms should be shifted after applying r; i.e. optimal alignment requires 1) rotating and then 2) shifting the positions.
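Steps 1 and 2 amount to indexing. The NumPy sketch below assumes positions are [n, 3] arrays without a batch dimension and that anchor_gt_idx is a plain integer. It pairs the predicted anchor residues (selected with asym_mask) with the matching ground-truth residues (selected with anchor_gt_residue). Step 3 would then pass both arrays, together with the input mask, to a Kabsch solver. The helper name select_anchor_pairs_sketch is hypothetical.

```python
import numpy as np


def select_anchor_pairs_sketch(true_ca_poses, anchor_gt_idx, anchor_gt_residue, pred_ca_pos, asym_mask):
    """Hypothetical sketch of steps 1-2: return corresponding predicted / ground-truth anchor C-alphas."""
    # Step 1 (prediction side): the predicted anchor is the part of the prediction covered by asym_mask.
    anchor_pred_pos = pred_ca_pos[asym_mask]
    # Step 2: slice the selected ground-truth chain down to the residues present in the crop.
    anchor_true_pos = true_ca_poses[anchor_gt_idx][anchor_gt_residue]
    assert anchor_pred_pos.shape == anchor_true_pos.shape, "anchors must have matching residues"
    return anchor_pred_pos, anchor_true_pos


true_ca_poses = [np.arange(12.0).reshape(4, 3), np.zeros((2, 3))]
pred_ca_pos = np.arange(15.0).reshape(5, 3) + 100.0
asym_mask = np.array([False, False, True, True, False])
pred_anchor, true_anchor = select_anchor_pairs_sketch(true_ca_poses, 0, np.array([1, 3]), pred_ca_pos, asym_mask)
print(pred_anchor.shape, true_anchor.shape)  # (2, 3) (2, 3)
```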
- compute_permutation_alignment(out, features, ground_truth)
Permutes the chains in the ground truth before the loss is calculated, because the mapping between predicted and ground-truth chains is arbitrary: the model cannot be assumed to predict chains in the same order as the ground truth. This function therefore picks the permutation of ground-truth chains that best matches the predicted chains, i.e. the permutation with the lowest RMSD.
Details are described in Section 7.3 of the Supplementary Information of the AlphaFold-Multimer paper: https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
- Parameters:
out (Dict[str, Tensor]) – a dictionary of output tensors from model.forward()
features (Dict[str, Tensor]) – a dictionary of feature tensors that are used as input for model.forward()
ground_truth (List[Dict[str, Tensor]]) – a list of dictionaries of features corresponding to the chains in the ground-truth structure; e.g. it has a length of 5 if there are 5 chains in the ground-truth structure.
- Returns:
a list of tuple(int, int) that instructs how the ground-truth chains should be permuted
a dictionary recording which residues belong to which asym_id
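To make the objective concrete, the sketch below enumerates every assignment of ground-truth chains to predicted chains and keeps the one with the lowest RMSD. It assumes all chains are copies of one entity, have equal length, and already share a coordinate frame. This is not how the module works. The cost grows factorially with the number of chains, and the module instead appears to rely on anchor alignment followed by greedy_align. The name best_permutation_bruteforce is hypothetical.

```python
import itertools

import numpy as np


def best_permutation_bruteforce(pred_chains, true_chains):
    """Hypothetical sketch: exhaustive search for the lowest-RMSD chain assignment.

    pred_chains / true_chains: lists of [L, 3] C-alpha arrays, all of equal length L.
    Returns a list of (pred_index, true_index) tuples, 0-indexed.
    """
    pred = np.concatenate(pred_chains)
    best_perm, best_rmsd = None, np.inf
    for perm in itertools.permutations(range(len(true_chains))):
        true = np.concatenate([true_chains[j] for j in perm])
        rmsd = np.sqrt(np.mean(np.sum((pred - true) ** 2, axis=-1)))
        if rmsd < best_rmsd:
            best_perm, best_rmsd = perm, rmsd
    return list(enumerate(best_perm))


chain_a = np.zeros((4, 3))
chain_b = np.full((4, 3), 10.0)
# The model predicted the two copies in the opposite order.
print(best_permutation_bruteforce([chain_b, chain_a], [chain_a, chain_b]))  # [(0, 1), (1, 0)]
```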
- compute_rmsd(true_atom_pos, pred_atom_pos, atom_mask=None, eps=1e-06)
Calculates the RMSD between predicted and ground-truth atom positions, optionally restricted to the atoms selected by atom_mask.
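The NumPy sketch below computes a masked RMSD. It assumes positions of shape [n, 3] and a boolean atom_mask of shape [n]. It also assumes that eps is added inside the square root for numerical stability; the text above does not confirm this. The name rmsd_sketch is hypothetical.

```python
import numpy as np


def rmsd_sketch(true_atom_pos, pred_atom_pos, atom_mask=None, eps=1e-6):
    """Hypothetical sketch: root-mean-square deviation over (optionally masked) atoms."""
    # Squared Euclidean distance per atom.
    sq_dist = np.sum((true_atom_pos - pred_atom_pos) ** 2, axis=-1)
    if atom_mask is not None:
        sq_dist = sq_dist[np.asarray(atom_mask, dtype=bool)]
    # Assumption: eps keeps the square root (and its gradient, under autodiff) well-defined at zero.
    return float(np.sqrt(np.mean(sq_dist) + eps))


true_pos = np.zeros((2, 3))
pred_pos = np.array([[3.0, 4.0, 0.0], [0.0, 0.0, 0.0]])
print(round(rmsd_sketch(true_pos, pred_pos), 4))                                      # 3.5355
print(round(rmsd_sketch(true_pos, pred_pos, atom_mask=np.array([True, False])), 4))  # 5.0
```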
- get_entity_2_asym_list(features)
Generates a dictionary mapping unique entity IDs to lists of unique asymmetry IDs (asym_id) for each entity.
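The plain-Python sketch below builds this mapping. It assumes features holds per-residue entity_id and asym_id sequences of equal length, given here as lists rather than tensors. The name entity_2_asym_sketch is hypothetical.

```python
def entity_2_asym_sketch(features):
    """Hypothetical sketch: map each entity_id to its unique asym_ids, in order of first appearance."""
    entity_2_asym = {}
    for entity_id, asym_id in zip(features["entity_id"], features["asym_id"]):
        asym_ids = entity_2_asym.setdefault(entity_id, [])
        if asym_id not in asym_ids:
            asym_ids.append(asym_id)
    return entity_2_asym


# Homodimer of entity 1 (chains 1 and 2) plus one copy of entity 2 (chain 3).
features = {"entity_id": [1, 1, 1, 1, 2, 2], "asym_id": [1, 1, 2, 2, 3, 3]}
print(entity_2_asym_sketch(features))  # {1: [1, 2], 2: [3]}
```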
- get_least_asym_entity_or_longest_length(batch, input_asym_id)
First checks how many subunits each sequence has and selects a subunit of the least common one as the anchor; e.g. if the protein is AABBB, one of the A chains is selected as the anchor.
If there is a tie, e.g. AABB, the longest sequence is chosen and one of its subunits is selected as the anchor.
- Returns:
anchor_gt_asym_id (Tensor(int)) – the selected ground-truth asym_id
anchor_pred_asym_ids (list(Tensor(int))) – a list of all possible predicted anchor candidates
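The plain-Python sketch below applies the selection rule. It assumes the chain-to-entity assignment and the per-entity sequence lengths are already available as dictionaries; the real function takes batch and input_asym_id instead. The text above does not specify which copy becomes anchor_gt_asym_id, so the sketch simply takes the lowest asym_id. The name select_anchor_sketch is hypothetical.

```python
from collections import Counter


def select_anchor_sketch(asym_to_entity, entity_length):
    """Hypothetical sketch of the anchor rule.

    asym_to_entity: {asym_id: entity_id}; entity_length: {entity_id: sequence length}.
    Returns (anchor_gt_asym_id, anchor_pred_asym_ids).
    """
    copies = Counter(asym_to_entity.values())
    fewest = min(copies.values())
    # Entities with the fewest copies; break ties by the longest sequence.
    tied = [entity for entity, n in copies.items() if n == fewest]
    anchor_entity = max(tied, key=lambda entity: entity_length[entity])
    candidates = sorted(a for a, e in asym_to_entity.items() if e == anchor_entity)
    # Any copy of the anchor entity could serve; this sketch simply takes the first.
    return candidates[0], candidates


# AABBB: A has fewer copies, so an A chain is the anchor.
print(select_anchor_sketch({1: "A", 2: "A", 3: "B", 4: "B", 5: "B"}, {"A": 100, "B": 50}))  # (1, [1, 2])
# AABB: tie on copy number, so the longer sequence (B) wins.
print(select_anchor_sketch({1: "A", 2: "A", 3: "B", 4: "B"}, {"A": 50, "B": 120}))  # (3, [3, 4])
```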
- get_optimal_transform(src_atoms, tgt_atoms, mask=None)
Obtains the transformation that optimally aligns src_atoms with tgt_atoms, optionally restricted by mask.
- Returns:
r, a rotation matrix recording the optimal rotation that best aligns src_atoms to tgt_atoms, and a matrix recording how the atoms should be shifted after applying r; i.e. optimal alignment requires 1) rotating and then 2) shifting the positions.
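The NumPy sketch below computes the rotate-then-shift transform. It assumes [N, 3] coordinates in row-vector convention, so that aligned = src_atoms @ r + x, and a boolean mask that selects the atoms used for fitting. The rotation comes from the Kabsch algorithm referenced above. The name optimal_transform_sketch is hypothetical, and the row-vector convention is an assumption.

```python
import numpy as np


def optimal_transform_sketch(src_atoms, tgt_atoms, mask=None):
    """Hypothetical sketch: return (r, x) such that src_atoms @ r + x best matches tgt_atoms.

    src_atoms, tgt_atoms: [N, 3] arrays of corresponding points; mask: optional boolean [N].
    """
    if mask is not None:
        src_atoms, tgt_atoms = src_atoms[mask], tgt_atoms[mask]
    src_center = src_atoms.mean(axis=0)
    tgt_center = tgt_atoms.mean(axis=0)
    # Kabsch: SVD of the 3x3 covariance of the centred point sets.
    h = (src_atoms - src_center).T @ (tgt_atoms - tgt_center)
    u, _, vt = np.linalg.svd(h)
    d = np.sign(np.linalg.det(u @ vt))  # guard against returning a reflection
    r = u @ np.diag([1.0, 1.0, d]) @ vt
    # After rotating, shift so the centroids coincide.
    x = tgt_center - src_center @ r
    return r, x


rng = np.random.default_rng(0)
src = rng.normal(size=(10, 3))
theta = np.pi / 6
true_r = np.array([[np.cos(theta), -np.sin(theta), 0.0],
                   [np.sin(theta), np.cos(theta), 0.0],
                   [0.0, 0.0, 1.0]])
tgt = src @ true_r + np.array([1.0, 2.0, 3.0])
tgt[0] += 50.0  # corrupt one atom, then exclude it with the mask
mask = np.ones(10, dtype=bool)
mask[0] = False
r, x = optimal_transform_sketch(src, tgt, mask)
print(np.allclose(r, true_r), np.allclose(x, [1.0, 2.0, 3.0]))  # True True
```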
- get_per_asym_residue_index(features)
Retrieves which residues belong to which asym_id.
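The plain-Python sketch below builds the per-chain residue lists. It assumes features carries per-residue asym_id and residue_index sequences. It also assumes that asym_id 0 marks padding; the source does not say how padding is encoded. The name per_asym_residue_index_sketch is hypothetical.

```python
def per_asym_residue_index_sketch(features):
    """Hypothetical sketch: {asym_id: [residue_index, ...]} for every non-padding chain."""
    per_asym = {}
    for asym_id, residue_index in zip(features["asym_id"], features["residue_index"]):
        if asym_id == 0:  # assumption: asym_id 0 marks padding
            continue
        per_asym.setdefault(asym_id, []).append(residue_index)
    return per_asym


features = {"asym_id": [1, 1, 1, 2, 2, 0], "residue_index": [0, 1, 2, 0, 1, 0]}
print(per_asym_residue_index_sketch(features))  # {1: [0, 1, 2], 2: [0, 1]}
```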
- greedy_align(batch, per_asym_residue_index, entity_2_asym_list, pred_ca_pos, pred_ca_mask, true_ca_poses, true_ca_masks)
Implements Algorithm 4 in the Supplementary Information of the AlphaFold-Multimer paper: Evans, R. et al. (2022). Protein complex prediction with AlphaFold-Multimer. bioRxiv 2021.10.04.463034; doi: https://doi.org/10.1101/2021.10.04.463034
- Parameters:
batch (dict) – a dictionary of ground truth features
per_asym_residue_index (dict) – a dictionary recording which residues belong to which asym_id
entity_2_asym_list (dict) – a dictionary recording which asym_id(s) belong to which entity_id
pred_ca_pos (Tensor) – predicted positions of c-alpha atoms from the results of model.forward()
pred_ca_mask (Tensor) – a boolean tensor that masks pred_ca_pos
true_ca_poses (list) – a list of tensors, corresponding to the c-alpha positions of the ground truth structure. e.g. If there are 5 chains, this list will have a length of 5
true_ca_masks (list) – a list of tensors, corresponding to the masks of c-alpha positions of the ground truth structure. If there are 5 chains, this list will have a length of 5
- Returns:
A list of tuple(int, int) that specifies how the ground-truth chains should be permuted. For example, if 3 chains in the input have the same sequence, a possible return value is [(0,2),(1,1),(2,0)]. This means the 1st chain in the predicted structure should be aligned to the 3rd chain in the ground truth, the 2nd predicted chain can stay with the 2nd ground-truth chain, and the 3rd predicted chain should be aligned to the 1st ground-truth chain.
Note: the tuples in the returned list are 0-indexed, whereas asym_id starts at 1. The tuples are 0-indexed because, at the loss-calculation stage, the ground-truth atom positions (true_ca_poses) have already been split into a list of matrices. This function therefore returns indices for selecting from the list true_ca_poses, and list indices start at 0.
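The NumPy sketch below illustrates the greedy step and reproduces the example return value above. It assumes the following:
- the predicted chains have already been brought into the ground-truth frame;
- predicted and ground-truth chains are listed in the same entity order and have equal lengths;
- each predicted chain is matched, in order, to the unused same-entity ground-truth chain with the lowest C-alpha RMSD.

These are assumptions about the procedure, not a transcription of Algorithm 4. The name greedy_align_sketch is hypothetical.

```python
import numpy as np


def greedy_align_sketch(pred_chains, true_chains, chain_entity):
    """Hypothetical sketch of greedy chain assignment.

    pred_chains[i] / true_chains[j]: [L, 3] C-alpha arrays already in a common frame
    (e.g. after anchor alignment); chain_entity[k]: entity of chain k, same order for both.
    Returns 0-indexed (pred_index, true_index) tuples.
    """
    used = [False] * len(true_chains)
    align = []
    for i, pred in enumerate(pred_chains):
        best_j, best_rmsd = None, np.inf
        for j, true in enumerate(true_chains):
            # Only unused ground-truth chains of the same entity are candidates.
            if used[j] or chain_entity[j] != chain_entity[i]:
                continue
            rmsd = np.sqrt(np.mean(np.sum((pred - true) ** 2, axis=-1)))
            if rmsd < best_rmsd:
                best_j, best_rmsd = j, rmsd
        assert best_j is not None, "no unused ground-truth chain of the same entity"
        used[best_j] = True
        align.append((i, best_j))
    return align


a, b, c = (np.full((4, 3), v) for v in (0.0, 10.0, 20.0))
# Three copies of one entity, predicted in reverse order.
print(greedy_align_sketch([c, b, a], [a, b, c], chain_entity=[1, 1, 1]))  # [(0, 2), (1, 1), (2, 0)]
```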
- kabsch_rotation(P, Q)
Calculate the best rotation that minimises the RMSD between P and Q.
The optimal rotation matrix is calculated using the Kabsch algorithm: https://en.wikipedia.org/wiki/Kabsch_algorithm
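The NumPy sketch below computes a Kabsch rotation. It assumes P and Q are [N, 3] arrays of corresponding points in row-vector convention, so that P @ r approximates Q after centring. The sign correction keeps the result a proper rotation (determinant +1) even when the best orthogonal fit would be a reflection. The name kabsch_rotation_sketch is hypothetical.

```python
import numpy as np


def kabsch_rotation_sketch(P, Q):
    """Hypothetical sketch: proper rotation r minimising the RMSD between (P - mean) @ r and (Q - mean)."""
    p = P - P.mean(axis=0)
    q = Q - Q.mean(axis=0)
    u, _, vt = np.linalg.svd(p.T @ q)   # SVD of the 3x3 covariance matrix
    d = np.sign(np.linalg.det(u @ vt))  # -1 would mean the plain SVD solution is a reflection
    return u @ np.diag([1.0, 1.0, d]) @ vt


rng = np.random.default_rng(1)
P = rng.normal(size=(8, 3))
mirrored = P * np.array([1.0, 1.0, -1.0])  # mirror image: no rotation maps P onto it exactly
r = kabsch_rotation_sketch(P, mirrored)
print(round(float(np.linalg.det(r)), 6))  # 1.0
```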
- merge_labels(per_asym_residue_index, labels, align, original_nres)
Merge ground-truth labels according to the permutation results.
- Parameters:
per_asym_residue_index (Dict[int, List[int]]) – a dictionary recording which residues belong to which asym_id
labels (List[Dict]) – a list of the original ground-truth features; e.g. if there are 5 chains, labels has a length of 5
align (List[Tuple[int, int]]) – a list of tuples; each entry specifies which ground-truth label is assigned to the corresponding asym
original_nres (int) – the number of residues specified by crop_size in config.py
- Returns:
A new dictionary of permuted ground-truth features
Modified from UniFold: https://github.com/dptech-corp/Uni-Fold/blob/b1c89a2cebd4e4ee4c47b4e443f92beeb9138fbb/unifold/losses/chain_align.py#L176C1-L176C1
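The NumPy sketch below performs the merge under three assumptions: each label dict holds arrays whose first axis is the residue axis; per_asym_residue_index maps 1-based asym_ids to integer positions within each chain's labels; and the merged result is zero-padded at the end up to original_nres. It follows the indexing convention noted under greedy_align (0-indexed tuples, 1-based asym_id). The name merge_labels_sketch is hypothetical.

```python
import numpy as np


def merge_labels_sketch(per_asym_residue_index, labels, align, original_nres):
    """Hypothetical sketch: reorder ground-truth chain labels to follow the predicted chain order."""
    merged = {}
    for key in labels[0]:
        pieces = []
        for pred_idx, gt_idx in sorted(align):
            residues = per_asym_residue_index[pred_idx + 1]  # asym_id is 1-based
            pieces.append(labels[gt_idx][key][residues])
        value = np.concatenate(pieces, axis=0)
        # Zero-pad the residue axis back to the cropped length.
        pad = original_nres - value.shape[0]
        if pad > 0:
            value = np.pad(value, [(0, pad)] + [(0, 0)] * (value.ndim - 1))
        merged[key] = value
    return merged


labels = [{"x": np.array([0, 1, 2])}, {"x": np.array([10, 11, 12])}]
per_asym_residue_index = {1: np.array([0, 1]), 2: np.array([1, 2])}
merged = merge_labels_sketch(per_asym_residue_index, labels, align=[(0, 1), (1, 0)], original_nres=6)
print(merged["x"].tolist())  # [10, 11, 1, 2, 0, 0]
```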
- multi_chain_permutation_align(out, features, ground_truth)
Compute multi-chain permutation alignment.
- Parameters:
out (Dict[str, Tensor]) – a dictionary of output tensors from model.forward()
features (Dict[str, Tensor]) – a dictionary of feature tensors that are used as input for model.forward()
ground_truth (List[Dict[str, Tensor]]) – a list of dictionaries of features corresponding to the chains in the ground-truth structure; e.g. it has a length of 5 if there are 5 chains in the ground-truth structure.
- Returns:
features – a dictionary with updated ground-truth feature tensors, ready for downstream loss calculations.
- pad_features(feature_tensor, nres_pad, pad_dim)
Pad input feature tensor. Padding values are 0 and are placed after the true feature values along pad_dim.
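The NumPy sketch below pads a feature array. It assumes nres_pad is the number of zero entries to append along pad_dim; the text above does not spell out the parameter's exact meaning. The name pad_features_sketch is hypothetical. With PyTorch tensors, concatenating a zero tensor would give the same result.

```python
import numpy as np


def pad_features_sketch(feature_tensor, nres_pad, pad_dim):
    """Hypothetical sketch: append nres_pad zeros along pad_dim, after the true values."""
    pad_width = [(0, 0)] * feature_tensor.ndim
    pad_width[pad_dim] = (0, nres_pad)  # (before, after): pad only at the end
    return np.pad(feature_tensor, pad_width)


feat = np.ones((2, 3))
padded = pad_features_sketch(feat, nres_pad=2, pad_dim=0)
print(padded.shape)  # (4, 3)
```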
- split_ground_truth_labels(gt_features)
Splits ground-truth features according to chains.
- Parameters:
gt_features – a dictionary from an iteration over the PyTorch Dataset, as returned by the upstream DataLoader.iter() method. In the DataLoader pipeline, the tensors of all ground-truth chains are concatenated so that the input format stays the same as in the monomer data pipeline. This function is therefore needed to 1) detect the number of chains, i.e. unique(asym_id), and 2) split the concatenated tensors back into individual ones corresponding to individual asym_ids.
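The NumPy sketch below carries out both steps. It assumes every entry of gt_features is an array whose first axis is the concatenated residue axis; the real pipeline may also carry a batch dimension. The name split_ground_truth_labels_sketch is hypothetical.

```python
import numpy as np


def split_ground_truth_labels_sketch(gt_features):
    """Hypothetical sketch: split concatenated per-residue features into one dict per asym_id."""
    asym_id = np.asarray(gt_features["asym_id"])
    chains = []
    # 1) Detect the chains; np.unique returns the asym_ids in sorted order.
    for cur_asym_id in np.unique(asym_id):
        residue_mask = asym_id == cur_asym_id
        # 2) Slice every feature down to this chain's residues.
        chains.append({k: np.asarray(v)[residue_mask] for k, v in gt_features.items()})
    return chains


gt_features = {"asym_id": np.array([1, 1, 2]), "aatype": np.array([5, 6, 7])}
chains = split_ground_truth_labels_sketch(gt_features)
print(len(chains), chains[1]["aatype"].tolist())  # 2 [7]
```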