openfold.utils.rigid_utils
Classes
| Rigid | A class representing a rigid transformation. Little more than a wrapper around two objects: a Rotation object and a [*, 3] translation. Designed to behave approximately like a single torch tensor with the shape of the shared batch dimensions of its component parts. |
| Rotation | A 3D rotation. |
Functions
| quat_multiply | Multiply a quaternion by another quaternion. |
| quat_multiply_by_vec | Multiply a quaternion by a pure-vector quaternion. |
| quat_to_rot | Converts a quaternion to a rotation matrix. |
| rot_matmul | Performs matrix multiplication of two rotation matrix tensors. |
| | Applies a rotation to a vector. |
- class Rigid(rots, trans)
A class representing a rigid transformation. Little more than a wrapper around two objects: a Rotation object and a [*, 3] translation. Designed to behave approximately like a single torch tensor with the shape of the shared batch dimensions of its component parts.
- apply(pts)
Applies the transformation to a coordinate tensor.
- apply_rot_fn(fn)
Applies a Rotation -> Rotation function to the stored rotation object.
- apply_trans_fn(fn)
Applies a Tensor -> Tensor function to the stored translation.
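- static cat(ts, dim)
Concatenates transformations along an existing batch dimension. Analogous to torch.cat().
- compose(r)
Composes the current rigid object with another. The result applies r first and then the current transformation.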
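Composing a transformation (R1, t1) with (R2, t2) yields rotation R1 R2 and translation R1 t2 + t1, so a point x maps to R1 (R2 x + t2) + t1. The sketch below checks that rule on single, unbatched transformations stored as nested lists. It is a plain-Python illustration, not rigid_utils code; mat_vec, mat_mat, apply_rigid, and compose are hypothetical helpers.

```python
# Hypothetical, unbatched plain-Python sketch; not the rigid_utils API.
# A transform is a (rot, trans) pair: rot is a 3x3 nested list, trans a 3-list.


def mat_vec(r, v):
    """3x3 matrix times 3-vector."""
    return [sum(r[i][k] * v[k] for k in range(3)) for i in range(3)]


def mat_mat(a, b):
    """3x3 matrix times 3x3 matrix."""
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)]
            for i in range(3)]


def apply_rigid(t, p):
    """Map point p to R p + t."""
    rot, trans = t
    return [a + b for a, b in zip(mat_vec(rot, p), trans)]


def compose(t1, t2):
    """Transform that applies t2 first and then t1: (R1 R2, R1 t2 + t1)."""
    r1, x1 = t1
    r2, x2 = t2
    return mat_mat(r1, r2), [a + b for a, b in zip(mat_vec(r1, x2), x1)]


rz90 = [[0, -1, 0], [1, 0, 0], [0, 0, 1]]  # 90-degree rotation about z
t1 = (rz90, [1, 0, 0])
t2 = (rz90, [0, 2, 0])
p = [1, 0, 0]
print(apply_rigid(compose(t1, t2), p))      # [-2, 0, 0]
print(apply_rigid(t1, apply_rigid(t2, p)))  # [-2, 0, 0]
```

Both lines print the same point, which is the property the composition rule guarantees.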
- compose_q_update_vec(q_update_vec)
Composes the transformation with a quaternion update vector of shape [*, 6]. The first three of the final six columns are the x, y, and z values of a quaternion of the form (1, x, y, z); the last three are a 3D translation.
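A minimal sketch of this update for a single transformation, assuming it follows the AlphaFold 2 backbone update: the first three values are completed to the quaternion (1, x, y, z), which right-multiplies the current rotation and is then renormalized, and the last three are a translation in the local frame that is rotated by the current rotation before being added. The (w, x, y, z) quaternion order and the helpers quat_mul, quat_to_mat, and q_update are assumptions of this plain-Python sketch, not the rigid_utils API.

```python
import math

# Hypothetical, unbatched plain-Python sketch; not the rigid_utils API.


def quat_mul(a, b):
    """Hamilton product of (w, x, y, z) quaternions."""
    w1, x1, y1, z1 = a
    w2, x2, y2, z2 = b
    return (
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
    )


def quat_to_mat(q):
    """Rotation matrix of a unit (w, x, y, z) quaternion."""
    w, x, y, z = q
    return [
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    ]


def q_update(quat, trans, update6):
    """Apply a 6-value update (qx, qy, qz, tx, ty, tz) to (quat, trans)."""
    q_vec, t_vec = update6[:3], update6[3:]
    # Rotation: right-multiply by (1, qx, qy, qz), then renormalize.
    new_q = quat_mul(quat, (1.0, *q_vec))
    norm = math.sqrt(sum(c * c for c in new_q))
    new_q = tuple(c / norm for c in new_q)
    # Translation: rotate the local update by the *current* rotation, then add.
    rot = quat_to_mat(quat)
    new_t = [trans[i] + sum(rot[i][k] * t_vec[k] for k in range(3))
             for i in range(3)]
    return new_q, new_t


# From the identity, a pure translation update simply moves the origin.
print(q_update((1.0, 0.0, 0.0, 0.0), [0.0, 0.0, 0.0],
               (0.0, 0.0, 0.0, 1.0, 2.0, 3.0)))
# ((1.0, 0.0, 0.0, 0.0), [1.0, 2.0, 3.0])
```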
- cuda()
Moves the transformation object to GPU memory.
- Returns:
A version of the transformation on GPU
- Return type:
Rigid
- static from_3_points(p_neg_x_axis, origin, p_xy_plane, eps=1e-08)
Implements Algorithm 21 of the AlphaFold 2 supplementary information. Constructs transformations from sets of 3 points using the Gram-Schmidt process.
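For a single frame, the Gram-Schmidt construction can be sketched as follows: the first axis points from p_neg_x_axis towards origin, the second is the part of p_xy_plane - origin orthogonal to the first, the third is their cross product, and origin becomes the translation. This unbatched plain-Python sketch assumes the axes form the columns of the rotation matrix and that eps is added inside the square roots; frame_from_3_points is a hypothetical name.

```python
import math


def frame_from_3_points(p_neg_x_axis, origin, p_xy_plane, eps=1e-8):
    """Gram-Schmidt frame from three points (hypothetical, unbatched sketch).

    Returns (rot, trans): rot is a 3x3 nested list whose columns are the
    frame axes e0, e1, e2, and trans is the origin.
    """
    # e0: unit vector from p_neg_x_axis towards the origin.
    e0 = [o - p for o, p in zip(origin, p_neg_x_axis)]
    n0 = math.sqrt(sum(c * c for c in e0) + eps)
    e0 = [c / n0 for c in e0]
    # e1: (p_xy_plane - origin) with its e0 component removed, then normalized.
    e1 = [q - o for q, o in zip(p_xy_plane, origin)]
    dot = sum(a * b for a, b in zip(e0, e1))
    e1 = [b - a * dot for a, b in zip(e0, e1)]
    n1 = math.sqrt(sum(c * c for c in e1) + eps)
    e1 = [c / n1 for c in e1]
    # e2: cross product e0 x e1 completes a right-handed frame.
    e2 = [
        e0[1] * e1[2] - e0[2] * e1[1],
        e0[2] * e1[0] - e0[0] * e1[2],
        e0[0] * e1[1] - e0[1] * e1[0],
    ]
    rot = [[e0[i], e1[i], e2[i]] for i in range(3)]
    return rot, list(origin)


rot, trans = frame_from_3_points([1.0, 2.0, 3.0], [2.0, 2.0, 3.5], [1.0, 5.0, 0.0])
print([[round(v, 3) for v in row] for row in rot], trans)
```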
- static from_tensor_4x4(t)
Constructs a transformation from a homogeneous transformation tensor.
- static from_tensor_7(t, normalize_quats=False)
- static identity(shape, dtype=None, device=None, requires_grad=True, fmt='quat')
Constructs an identity transformation.
- Parameters:
- Returns:
The identity transformation
- Return type:
Rigid
- invert_apply(pts)
Applies the inverse of the transformation to a coordinate tensor.
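For one transformation with rotation R and translation t, apply maps a point x to R x + t, and invert_apply undoes it with R^T (x - t), because the inverse of a rotation matrix is its transpose. A minimal unbatched sketch, with hypothetical helper names:

```python
# Hypothetical, unbatched plain-Python sketch; not the rigid_utils API.


def rigid_apply(rot, trans, p):
    """x -> R x + t, with rot a 3x3 nested list and trans a 3-vector."""
    return [sum(rot[i][k] * p[k] for k in range(3)) + trans[i] for i in range(3)]


def rigid_invert_apply(rot, trans, p):
    """x -> R^T (x - t), relying on R^-1 == R^T for a rotation matrix."""
    d = [p[k] - trans[k] for k in range(3)]
    return [sum(rot[k][i] * d[k] for k in range(3)) for i in range(3)]


rz90 = [[0, -1, 0], [1, 0, 0], [0, 0, 1]]  # 90-degree rotation about z
t = [1, 2, 3]
moved = rigid_apply(rz90, t, [4, 5, 6])
print(moved)                               # [-4, 6, 9]
print(rigid_invert_apply(rz90, t, moved))  # [4, 5, 6]
```

Using the transpose avoids a general matrix inversion, which is only valid because R is orthogonal.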
- static make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20)
Returns a transformation object from reference coordinates.
Note that this method does not account for symmetries. If you provide the atom positions in a non-standard way, the N atom will end up at [-0.527250, -1.359329, 0.0] instead of [-0.527250, 1.359329, 0.0]. You need to handle such cases in your own code.
- Parameters:
- Returns:
A transformation object. After the transformation is applied to the reference backbone, the resulting coordinates approximately equal the input coordinates.
- map_tensor_fn(fn)
Applies a Tensor -> Tensor function to the underlying translation and rotation tensors, mapping over the translation and rotation dimensions, respectively.
- scale_translation(trans_scale_factor)
Scales the translation by a constant factor.
- stop_rot_gradient()
Detaches the underlying rotation object.
- Returns:
A transformation object with detached rotations
- Return type:
Rigid
- to_tensor_4x4()
Converts a transformation to a homogeneous transformation tensor.
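The homogeneous layout that from_tensor_4x4 and to_tensor_4x4 convert between can be sketched as follows, assuming the standard form [[R, t], [0, 0, 0, 1]]: multiplying the 4x4 matrix by the homogeneous point [x, y, z, 1] gives [R x + t, 1]. to_4x4 and from_4x4 are hypothetical, unbatched stand-ins.

```python
# Hypothetical, unbatched plain-Python sketch; not the rigid_utils API.


def to_4x4(rot, trans):
    """Pack a 3x3 rotation and a 3-vector into [[R, t], [0, 0, 0, 1]]."""
    m = [list(rot[i]) + [trans[i]] for i in range(3)]
    m.append([0, 0, 0, 1])
    return m


def from_4x4(m):
    """Split a homogeneous 4x4 matrix back into (rot, trans)."""
    rot = [list(row[:3]) for row in m[:3]]
    trans = [row[3] for row in m[:3]]
    return rot, trans


rz90 = [[0, -1, 0], [1, 0, 0], [0, 0, 1]]
m = to_4x4(rz90, [1, 2, 3])
point = [4, 5, 6, 1]  # homogeneous coordinates
print([sum(m[i][k] * point[k] for k in range(4)) for i in range(4)])  # [-4, 6, 9, 1]
```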
- to_tensor_7()
Converts a transformation to a tensor with 7 final columns, four for the quaternion followed by three for the translation.
- unsqueeze(dim)
Analogous to torch.unsqueeze. The dimension is relative to the shared dimensions of the rotation/translation.
- property device: device
Returns the device on which the Rigid’s tensors are located.
- Returns:
The device on which the Rigid’s tensors are located
- class Rotation(rot_mats=None, quats=None, normalize_quats=True)
A 3D rotation. Depending on how the object is initialized, the rotation is represented by either a rotation matrix or a quaternion, though both formats are made available by helper functions. To simplify gradient computation, the underlying format of the rotation cannot be changed in-place. Like Rigid, the class is designed to mimic the behavior of a torch Tensor, almost as if each Rotation object were a tensor of rotations, in one format or another.
- apply(pts)
Apply the current Rotation as a rotation matrix to a set of 3D coordinates.
- static cat(rs, dim)
Concatenates rotations along one of the batch dimensions. Analogous to torch.cat().
Note that the output of this operation is always a rotation matrix, regardless of the format of the input rotations.
- compose_q(r, normalize_quats=True)
Compose the quaternions of the current Rotation object with those of another.
Depending on whether either Rotation was initialized with quaternions, this function may call torch.linalg.eigh.
- compose_q_update_vec(q_update_vec, normalize_quats=True)
Returns a new quaternion Rotation after updating the current object’s underlying rotation with a quaternion update, formatted as a [*, 3] tensor whose final three columns represent x, y, z such that (1, x, y, z) is the desired (not necessarily unit) quaternion update.
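If q is the current quaternion and (x, y, z) the update, the updated quaternion is q (1, x, y, z), optionally rescaled to unit length. Since q (1, v) = q + q (0, v), it can be computed with a quaternion-by-pure-vector product. The sketch below assumes (w, x, y, z) ordering and right-multiplication by the update; quat_mul and apply_q_update are hypothetical names.

```python
import math

# Hypothetical, unbatched plain-Python sketch; not the rigid_utils API.


def quat_mul(q1, q2):
    """Hamilton product of (w, x, y, z) quaternions."""
    w1, x1, y1, z1 = q1
    w2, x2, y2, z2 = q2
    return (
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
    )


def apply_q_update(q, update_xyz, normalize_quats=True):
    """Return q * (1, x, y, z), computed as q + q * (0, x, y, z)."""
    pure = (0.0, *update_xyz)
    new_q = tuple(a + b for a, b in zip(q, quat_mul(q, pure)))
    if normalize_quats:
        norm = math.sqrt(sum(c * c for c in new_q))
        new_q = tuple(c / norm for c in new_q)
    return new_q


# From the identity, the update (0, 0, 1) gives (1, 0, 0, 1) before
# normalization; normalized, that is a 90-degree rotation about z.
print(apply_q_update((1.0, 0.0, 0.0, 0.0), (0.0, 0.0, 1.0), normalize_quats=False))
# (1.0, 0.0, 0.0, 1.0)
```

The "not necessarily unit" wording in the description matches the unnormalized branch: without rescaling, the result generally has a norm other than 1.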
- compose_r(r)
Compose the rotation matrices of the current Rotation object with those of another.
- cuda()
Analogous to the cuda() method of torch Tensors.
- Returns:
A copy of the Rotation in CUDA memory
- Return type:
Rotation
- detach()
Returns a copy of the Rotation whose underlying Tensor has been detached from its torch graph.
- Returns:
A copy of the Rotation whose underlying Tensor has been detached from its torch graph
- Return type:
Rotation
- get_cur_rot()
Return the underlying rotation in its current form.
- Returns:
The stored rotation
- Return type:
Tensor
- get_quats()
Returns the underlying rotation as a quaternion tensor.
Depending on whether the Rotation was initialized with a quaternion, this function may call torch.linalg.eigh.
- Returns:
The rotation as a quaternion tensor.
- Return type:
Tensor
- get_rot_mats()
Returns the underlying rotation as a rotation matrix tensor.
- Returns:
The rotation as a rotation matrix tensor
- Return type:
Tensor
- static identity(shape, dtype=None, device=None, requires_grad=True, fmt='quat')
Returns an identity Rotation.
- Parameters:
shape – The “shape” of the resulting Rotation object. See the documentation for the shape property.
dtype (dtype | None) – The torch dtype for the rotation
device (device | None) – The torch device for the new rotation
requires_grad (bool) – Whether the underlying tensors in the new rotation object should require gradient computation
fmt (str) – One of “quat” or “rot_mat”. Determines the underlying format of the new object’s rotation
- Returns:
A new identity rotation
- Return type:
Rotation
- invert()
Returns the inverse of the current Rotation.
- Returns:
The inverse of the current Rotation
- Return type:
Rotation
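Inversion is cheap in either format: the inverse of a rotation matrix is its transpose, and the inverse of a quaternion is its conjugate divided by its squared norm (simply the conjugate for a unit quaternion). A minimal unbatched sketch with hypothetical helper names; it illustrates the math and does not claim to reproduce how invert is implemented.

```python
# Hypothetical, unbatched plain-Python sketch; not the rigid_utils API.


def rot_mat_inverse(r):
    """Inverse of a 3x3 rotation matrix: its transpose."""
    return [[r[j][i] for j in range(3)] for i in range(3)]


def quat_inverse(q):
    """Inverse of a (w, x, y, z) quaternion: conjugate / squared norm."""
    w, x, y, z = q
    n2 = w * w + x * x + y * y + z * z
    return (w / n2, -x / n2, -y / n2, -z / n2)


rz90 = [[0, -1, 0], [1, 0, 0], [0, 0, 1]]
print(rot_mat_inverse(rz90))               # [[0, 1, 0], [-1, 0, 0], [0, 0, 1]]
print(quat_inverse((0.5, 0.5, 0.5, 0.5)))  # (0.5, -0.5, -0.5, -0.5)
```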
- invert_apply(pts)
The inverse of the apply() method.
- map_tensor_fn(fn)
Apply a Tensor -> Tensor function to the underlying rotation tensors, mapping over the rotation dimension(s). It can be used, e.g., to sum out a one-hot batch dimension.
- to(device, dtype)
Analogous to the to() method of torch Tensors.
- unsqueeze(dim)
Analogous to torch.unsqueeze. The dimension is relative to the shape of the Rotation object.
- property device: device
The device of the underlying rotation.
- Returns:
The device of the underlying rotation
- property dtype: dtype
Returns the dtype of the underlying rotation.
- Returns:
The dtype of the underlying rotation
- property requires_grad: bool
Returns the requires_grad property of the underlying rotation.
- Returns:
The requires_grad property of the underlying tensor
- property shape: Size
Returns the virtual shape of the rotation object. This shape is defined as the batch dimensions of the underlying rotation matrix or quaternion. If the Rotation was initialized with a [10, 3, 3] rotation matrix tensor, for example, the resulting shape would be [10].
- Returns:
The virtual shape of the rotation object
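As a shape-only illustration: the virtual shape drops the trailing rotation dimensions ([*, 3, 3] for rotation matrices, [*, 4] for quaternions), so a negative dimension index given relative to that shape, as unsqueeze takes it, must be shifted past those trailing dimensions before it reaches the stored tensor. virtual_shape, tensor_dim, and TRAILING_DIMS are hypothetical, and the index mapping is an assumption drawn from the shape and unsqueeze descriptions.

```python
# Hypothetical shape-only sketch; not the rigid_utils API.
# Number of trailing, non-batch dimensions for each storage format.
TRAILING_DIMS = {"rot_mat": 2, "quat": 1}


def virtual_shape(tensor_shape, fmt):
    """Batch shape of a rotation stored as [*, 3, 3] or [*, 4]."""
    return tuple(tensor_shape[: len(tensor_shape) - TRAILING_DIMS[fmt]])


def tensor_dim(batch_dim, fmt):
    """Map a dim index relative to the virtual shape onto the stored tensor.

    Non-negative indices are unchanged; negative ones skip the trailing dims.
    """
    return batch_dim if batch_dim >= 0 else batch_dim - TRAILING_DIMS[fmt]


print(virtual_shape((10, 3, 3), "rot_mat"))  # (10,)
print(virtual_shape((2, 10, 4), "quat"))     # (2, 10)
print(tensor_dim(-1, "rot_mat"))             # -3
```

For example, unsqueezing dimension -1 of a Rotation with virtual shape (10,) should give virtual shape (10, 1), which for a [10, 3, 3] matrix tensor means inserting the new axis at tensor dimension -3.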
- identity_quats(batch_dims, dtype=None, device=None, requires_grad=True)
- identity_rot_mats(batch_dims, dtype=None, device=None, requires_grad=True)
- identity_trans(batch_dims, dtype=None, device=None, requires_grad=True)
- quat_multiply(quat1, quat2)
Multiply a quaternion by another quaternion.
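quat_multiply computes a quaternion product. Assuming it is the Hamilton product with the (w, x, y, z) ordering implied by the (1, x, y, z) form used in compose_q_update_vec, a single-quaternion version looks like this; hamilton_product is a hypothetical name. The basis-quaternion example shows that the product is not commutative.

```python
# Hypothetical, unbatched plain-Python sketch; not the rigid_utils API.


def hamilton_product(q1, q2):
    """Product q1 * q2 of quaternions given as (w, x, y, z)."""
    w1, x1, y1, z1 = q1
    w2, x2, y2, z2 = q2
    return (
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
    )


i, j = (0, 1, 0, 0), (0, 0, 1, 0)
print(hamilton_product(i, j))  # (0, 0, 0, 1), i.e. i * j = k
print(hamilton_product(j, i))  # (0, 0, 0, -1), i.e. j * i = -k
```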
- quat_multiply_by_vec(quat, vec)
Multiply a quaternion by a pure-vector quaternion.
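A pure-vector quaternion (0, vx, vy, vz) has zero real part, so every term of the Hamilton product that multiplies that real part drops out. The sketch writes out the remaining terms for a single quaternion, assuming (w, x, y, z) ordering; quat_times_vec is a hypothetical name.

```python
# Hypothetical, unbatched plain-Python sketch; not the rigid_utils API.


def quat_times_vec(q, v):
    """q * (0, vx, vy, vz) for a (w, x, y, z) quaternion q and 3-vector v."""
    w, x, y, z = q
    vx, vy, vz = v
    return (
        -x * vx - y * vy - z * vz,
        w * vx + y * vz - z * vy,
        w * vy - x * vz + z * vx,
        w * vz + x * vy - y * vx,
    )


print(quat_times_vec((1, 0, 0, 0), (1, 2, 3)))  # (0, 1, 2, 3): identity leaves v as is
print(quat_times_vec((0, 0, 0, 1), (1, 0, 0)))  # (0, 0, 1, 0): k * i = j
```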
- quat_to_rot(quat)
Converts a quaternion to a rotation matrix.
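For a unit quaternion (w, x, y, z), the standard conversion is sketched below. It assumes the input is already normalized and uses (w, x, y, z) order; quat_to_rot_mat is a hypothetical name, and the batched torch version may arrange the arithmetic differently.

```python
import math

# Hypothetical, unbatched plain-Python sketch; not the rigid_utils API.


def quat_to_rot_mat(q):
    """3x3 rotation matrix for a unit quaternion (w, x, y, z)."""
    w, x, y, z = q
    return [
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    ]


# 90 degrees about z: w = cos(45 deg), z = sin(45 deg).
half = math.radians(45)
r = quat_to_rot_mat((math.cos(half), 0.0, 0.0, math.sin(half)))
print([[round(v, 6) for v in row] for row in r])
```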
- rot_matmul(a, b)
Performs matrix multiplication of two rotation matrix tensors. Written out by hand to avoid AMP downcasting.
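A sketch of the written-out-by-hand idea: each of the nine output entries is an explicit sum of three products rather than a call to a matmul routine. The motivation assumed here is that torch.autocast casts matrix-multiply ops to lower precision, while elementwise multiplies and adds on slices such as a[..., i, k] keep their input dtype. The plain-Python version below only shows the structure of the computation; rot_matmul_by_hand is a hypothetical name.

```python
# Hypothetical, unbatched plain-Python sketch; not the rigid_utils API.


def rot_matmul_by_hand(a, b):
    """3x3 product a @ b with every entry spelled out (no matmul call)."""
    def entry(i, j):
        # Row i of a dotted with column j of b, written as three explicit terms.
        return a[i][0] * b[0][j] + a[i][1] * b[1][j] + a[i][2] * b[2][j]

    return [[entry(i, j) for j in range(3)] for i in range(3)]


rz90 = [[0, -1, 0], [1, 0, 0], [0, 0, 1]]  # 90-degree rotation about z
rz180 = rot_matmul_by_hand(rz90, rz90)
print(rz180)                             # [[-1, 0, 0], [0, -1, 0], [0, 0, 1]]
print(rot_matmul_by_hand(rz180, rz180))  # [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
```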