openfold.utils.rigid_utils

Classes

Rigid(rots, trans)

A class representing a rigid transformation. Little more than a wrapper around two objects: a Rotation object and a [*, 3] translation. Designed to behave approximately like a single torch tensor whose shape is the shared batch shape of its component parts.

Rotation([rot_mats, quats, normalize_quats])

A 3D rotation.

Functions

identity_quats(batch_dims[, dtype, device, ...])

identity_rot_mats(batch_dims[, dtype, ...])

identity_trans(batch_dims[, dtype, device, ...])

invert_quat(quat)

invert_rot_mat(rot_mat)

quat_multiply(quat1, quat2)

Multiply a quaternion by another quaternion.

quat_multiply_by_vec(quat, vec)

Multiply a quaternion by a pure-vector quaternion.

quat_to_rot(quat)

Converts a quaternion to a rotation matrix.

rot_matmul(a, b)

Performs matrix multiplication of two rotation matrix tensors.

rot_to_quat(rot)

rot_vec_mul(r, t)

Applies a rotation to a vector.

class Rigid(rots, trans)

A class representing a rigid transformation. Little more than a wrapper around two objects: a Rotation object and a [*, 3] translation. Designed to behave approximately like a single torch tensor whose shape is the shared batch shape of its component parts.

Parameters:
  • rots (Rotation) – The rotation component

  • trans (Tensor) – A [*, 3] translation tensor

apply(pts)

Applies the transformation to a coordinate tensor.

Parameters:

pts (Tensor) – A [*, 3] coordinate tensor.

Returns:

The transformed points.

Return type:

Tensor

apply_rot_fn(fn)

Applies a Rotation -> Rotation function to the stored rotation object.

Parameters:

fn (Callable[[Rotation], Rotation]) – A function of type Rotation -> Rotation

Returns:

A transformation object with a transformed rotation.

Return type:

Rigid

apply_trans_fn(fn)

Applies a Tensor -> Tensor function to the stored translation.

Parameters:

fn (Callable[[Tensor], Tensor]) – A function of type Tensor -> Tensor to be applied to the translation

Returns:

A transformation object with a transformed translation.

Return type:

Rigid

static cat(ts, dim)

Concatenates transformations along an existing batch dimension. Analogous to torch.cat().

Parameters:
  • ts (Sequence[Rigid]) – A list of Rigid objects

  • dim (int) – The dimension along which the transformations should be concatenated

Returns:

A concatenated transformation object

Return type:

Rigid

compose(r)

Composes the current rigid object with another.

Parameters:

r (Rigid) – Another Rigid object

Returns:

The composition of the two transformations

Return type:

Rigid
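The entry does not spell out the order of composition. The NumPy sketch below assumes the usual convention that composing self with r gives the transform that applies r first and self second, i.e. (R1, t1) ∘ (R2, t2) = (R1 R2, R1 t2 + t1), with rotations stored as [*, 3, 3] matrices. The helpers `apply`, `compose` and `rot_z` are hypothetical array-level stand-ins, not OpenFold code.

```python
import numpy as np

def apply(rot, trans, pts):
    # x' = R x + t, broadcast over leading batch dimensions
    return np.einsum("...ij,...j->...i", rot, pts) + trans

def compose(rot1, trans1, rot2, trans2):
    # (R1, t1) o (R2, t2) = (R1 R2, R1 t2 + t1): apply (R2, t2) first, then (R1, t1)
    rot = rot1 @ rot2
    trans = np.einsum("...ij,...j->...i", rot1, trans2) + trans1
    return rot, trans

def rot_z(theta):
    # Rotation by theta radians about the z-axis
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

r1, t1 = rot_z(0.3), np.array([1.0, 2.0, 3.0])
r2, t2 = rot_z(-1.1), np.array([0.5, 0.0, -2.0])
pts = np.random.default_rng(0).normal(size=(5, 3))

r12, t12 = compose(r1, t1, r2, t2)
# Applying the composition equals applying (r2, t2) and then (r1, t1)
same = np.allclose(apply(r12, t12, pts), apply(r1, t1, apply(r2, t2, pts)))
```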

compose_q_update_vec(q_update_vec)

Composes the transformation with a quaternion update vector of shape [*, 6]. The first three of the final 6 columns hold the x, y, and z values of a quaternion of the form (1, x, y, z); the last three hold a 3D translation.

Parameters:

q_update_vec (Tensor) – The [*, 6] quaternion update vector.

Returns:

The composed transformation.

Return type:

Rigid
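A NumPy sketch of the update described above. It assumes quaternions are stored real part first as (w, x, y, z), that the update is composed on the right (the same order as compose()), and that the translation part is expressed in the current local frame. The helper names `quat_mul`, `quat_rotate` and `compose_q_update_vec` are hypothetical; the renormalization mirrors the normalize_quats option documented for Rotation.compose_q_update_vec().

```python
import numpy as np

def quat_mul(q1, q2):
    # Hamilton product of quaternions stored as (w, x, y, z), real part first
    w1, x1, y1, z1 = np.moveaxis(q1, -1, 0)
    w2, x2, y2, z2 = np.moveaxis(q2, -1, 0)
    return np.stack([
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
    ], axis=-1)

def quat_rotate(q, v):
    # Rotate [*, 3] vectors by unit quaternions via q * (0, v) * conj(q)
    v_quat = np.concatenate([np.zeros(v.shape[:-1] + (1,)), v], axis=-1)
    q_conj = q * np.array([1.0, -1.0, -1.0, -1.0])
    return quat_mul(quat_mul(q, v_quat), q_conj)[..., 1:]

def compose_q_update_vec(quat, trans, update):
    # update is [*, 6]: (x, y, z) of the quaternion (1, x, y, z), then a translation
    q_vec, t_vec = update[..., :3], update[..., 3:]
    q_upd = np.concatenate([np.ones(q_vec.shape[:-1] + (1,)), q_vec], axis=-1)
    new_quat = quat_mul(quat, q_upd)
    # (1, x, y, z) is not a unit quaternion, so renormalize the result
    new_quat = new_quat / np.linalg.norm(new_quat, axis=-1, keepdims=True)
    # The translation update lives in the current local frame: rotate it by the old rotation
    new_trans = trans + quat_rotate(quat, t_vec)
    return new_quat, new_trans

rng = np.random.default_rng(1)
q = rng.normal(size=4)
q /= np.linalg.norm(q)
t = rng.normal(size=3)
update = np.array([0.1, -0.2, 0.05, 1.0, 2.0, 3.0])
q_new, t_new = compose_q_update_vec(q, t, update)
```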

cuda()

Moves the transformation object to GPU memory

Returns:

A version of the transformation on GPU

Return type:

Rigid

static from_3_points(p_neg_x_axis, origin, p_xy_plane, eps=1e-08)

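Implements Algorithm 21 of the AlphaFold 2 supplementary information. Constructs transformations from sets of 3 points using the Gram-Schmidt algorithm.

Parameters:
  • p_neg_x_axis (Tensor) – [*, 3] coordinates of points that end up on the negative x-axis of each frame

  • origin (Tensor) – [*, 3] coordinates used as frame origins

  • p_xy_plane (Tensor) – [*, 3] coordinates of points that end up in the xy-plane of each frame

  • eps (float) – Small epsilon value for numerical stability during normalization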

Returns:

A transformation object of shape [*]

Return type:

Rigid
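The Gram-Schmidt construction can be sketched in NumPy as follows, assuming the three orthonormal axes form the columns of the rotation matrix and the frame origin becomes the translation. It is written from the description above rather than copied from OpenFold; `from_3_points` and `to_local` are hypothetical helpers on plain arrays.

```python
import numpy as np

def from_3_points(p_neg_x_axis, origin, p_xy_plane, eps=1e-8):
    # e0 points from p_neg_x_axis towards the origin, so p_neg_x_axis lands on -x
    e0 = origin - p_neg_x_axis
    e0 = e0 / np.sqrt(np.sum(e0 * e0, axis=-1, keepdims=True) + eps)
    # e1: component of (p_xy_plane - origin) orthogonal to e0 (Gram-Schmidt step)
    e1 = p_xy_plane - origin
    e1 = e1 - e0 * np.sum(e0 * e1, axis=-1, keepdims=True)
    e1 = e1 / np.sqrt(np.sum(e1 * e1, axis=-1, keepdims=True) + eps)
    # e2 completes a right-handed orthonormal basis
    e2 = np.cross(e0, e1)
    rots = np.stack([e0, e1, e2], axis=-1)  # basis vectors as columns, [*, 3, 3]
    return rots, origin

def to_local(rots, trans, x):
    # Coordinates of x in each frame: R^T (x - t)
    return np.einsum("...ji,...j->...i", rots, x - trans)

rng = np.random.default_rng(0)
a, b, c = rng.normal(size=(3, 10, 3))
rots, trans = from_3_points(a, b, c)
```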

static from_tensor_4x4(t)

Constructs a transformation from a homogeneous transformation tensor.

Parameters:

t (Tensor) – [*, 4, 4] homogeneous transformation tensor

Returns:

A transformation object of shape [*]

Return type:

Rigid

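static from_tensor_7(t, normalize_quats=False)

Constructs a transformation from a tensor with 7 final columns, four for the quaternion followed by three for the translation (the format produced by to_tensor_7()).

Parameters:
  • t (Tensor) – A [*, 7] tensor

  • normalize_quats (bool) – Whether to normalize the quaternions

Returns:

A transformation object of shape [*]

Return type: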

Rigid

get_rots()

Getter for the rotation.

Returns:

The rotation object

Return type:

Rotation

get_trans()

Getter for the translation.

Returns:

The stored translation

Return type:

Tensor

static identity(shape, dtype=None, device=None, requires_grad=True, fmt='quat')

Constructs an identity transformation.

Parameters:
  • shape (Tuple[int]) – The desired shape

  • dtype (dtype | None) – The dtype of both internal tensors

  • device (device | None) – The device of both internal tensors

  • requires_grad (bool) – Whether grad should be enabled for the internal tensors

  • fmt (str) – One of “quat” or “rot_mat”. Determines the underlying format of the rotation

Returns:

The identity transformation

Return type:

Rigid

invert()

Inverts the transformation.

Returns:

The inverse transformation.

Return type:

Rigid

invert_apply(pts)

Applies the inverse of the transformation to a coordinate tensor.

Parameters:

pts (Tensor) – A [*, 3] coordinate tensor

Returns:

The transformed points.

Return type:

Tensor
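Inversion and inverse application both follow from x' = R x + t. The NumPy sketch below assumes rotations stored as [*, 3, 3] matrices; `apply`, `invert` and `invert_apply` are hypothetical array-level stand-ins for the Rigid methods of the same names.

```python
import numpy as np

def apply(rot, trans, pts):
    # x' = R x + t
    return np.einsum("...ij,...j->...i", rot, pts) + trans

def invert(rot, trans):
    # (R, t)^-1 = (R^T, -R^T t), since R^T = R^-1 for a rotation
    rot_inv = np.swapaxes(rot, -1, -2)
    return rot_inv, -np.einsum("...ij,...j->...i", rot_inv, trans)

def invert_apply(rot, trans, pts):
    # R^T (x - t), without materializing the inverse transform
    return np.einsum("...ji,...j->...i", rot, pts - trans)

theta = 0.7
rot = np.array([[1.0, 0.0, 0.0],
                [0.0, np.cos(theta), -np.sin(theta)],
                [0.0, np.sin(theta), np.cos(theta)]])
trans = np.array([3.0, -1.0, 2.0])
pts = np.random.default_rng(0).normal(size=(4, 3))
```

Computing R^T (x - t) directly, as invert_apply does in this sketch, avoids constructing the inverse transform first.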

static make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20)

Returns a transformation object from reference coordinates.

Note that this method does not handle symmetries. If the atom positions are provided in a non-standard way, the N atom will end up at [-0.527250, -1.359329, 0.0] instead of [-0.527250, 1.359329, 0.0]; such cases must be handled by the calling code.

Parameters:
  • n_xyz – A [*, 3] tensor of nitrogen (N) xyz coordinates.

  • ca_xyz – A [*, 3] tensor of alpha carbon (CA) xyz coordinates.

  • c_xyz – A [*, 3] tensor of carbonyl carbon (C) xyz coordinates.

  • eps – Small epsilon value for numerical stability.

Returns:

A transformation object. After applying the translation and rotation to the reference backbone, the coordinates will approximately equal the input coordinates.

Return type:

Rigid

map_tensor_fn(fn)

Apply a Tensor -> Tensor function to underlying translation and rotation tensors, mapping over the translation/rotation dimensions respectively.

Parameters:

fn (Callable[[Tensor], Tensor]) – A Tensor -> Tensor function to be mapped over the Rigid

Returns:

The transformed Rigid object

Return type:

Rigid
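Because fn only ever sees tensors of the batch shape, it can reduce or reshape batch dimensions. The NumPy sketch below assumes rotation-matrix storage and applies fn separately to each of the nine rotation components and each of the three translation components (one reading of "mapping over the translation/rotation dimensions"). It uses this to select one of K candidate frames per residue by summing out a one-hot dimension; all names are hypothetical.

```python
import numpy as np

def map_tensor_fn(rot_mats, trans, fn):
    # Flatten each [*, 3, 3] matrix to 9 components and apply fn to each [*] slice
    flat = rot_mats.reshape(rot_mats.shape[:-2] + (9,))
    new_rot = np.stack([fn(flat[..., k]) for k in range(9)], axis=-1)
    new_rot = new_rot.reshape(new_rot.shape[:-1] + (3, 3))
    # Apply fn to each of the three [*] translation components
    new_trans = np.stack([fn(trans[..., k]) for k in range(3)], axis=-1)
    return new_rot, new_trans

# Hypothetical setup: n residues, each with k candidate frames (rotations about z)
rng = np.random.default_rng(0)
n, k = 4, 3
angles = rng.uniform(0.0, np.pi, size=(n, k))
c, s = np.cos(angles), np.sin(angles)
zeros, ones = np.zeros_like(c), np.ones_like(c)
rot_mats = np.stack([c, -s, zeros, s, c, zeros, zeros, zeros, ones], axis=-1).reshape(n, k, 3, 3)
trans = rng.normal(size=(n, k, 3))
choice = np.eye(k)[rng.integers(0, k, size=n)]  # [n, k] one-hot selection

# Summing against the one-hot mask removes the k dimension
sel_rot, sel_trans = map_tensor_fn(rot_mats, trans, lambda x: np.sum(x * choice, axis=-1))
```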

scale_translation(trans_scale_factor)

Scales the translation by a constant factor.

Parameters:

trans_scale_factor (float) – The constant factor

Returns:

A transformation object with a scaled translation.

Return type:

Rigid

stop_rot_gradient()

Detaches the underlying rotation object

Returns:

A transformation object with detached rotations

Return type:

Rigid

to_tensor_4x4()

Converts a transformation to a homogeneous transformation tensor.

Returns:

A [*, 4, 4] homogeneous transformation tensor

Return type:

Tensor
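A NumPy sketch of the homogeneous layout, assuming the standard convention: rotation in the top-left 3x3 block, translation in the last column, and a bottom row of (0, 0, 0, 1). `to_tensor_4x4` and `from_tensor_4x4` here are hypothetical array-level stand-ins for the Rigid methods.

```python
import numpy as np

def to_tensor_4x4(rot, trans):
    # [*, 4, 4]: rotation in the top-left block, translation in the last column
    out = np.zeros(trans.shape[:-1] + (4, 4), dtype=trans.dtype)
    out[..., :3, :3] = rot
    out[..., :3, 3] = trans
    out[..., 3, 3] = 1.0
    return out

def from_tensor_4x4(t):
    # Inverse of the above: read off the rotation block and translation column
    return t[..., :3, :3], t[..., :3, 3]

rot = np.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])  # 90 degrees about z
trans = np.array([1.0, 2.0, 3.0])
m = to_tensor_4x4(rot, trans)
# Homogeneous coordinates: appending 1 lets a single matrix product compute R x + t
x = np.array([1.0, 0.0, 0.0])
x_h = np.append(x, 1.0)
```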

to_tensor_7()

Converts a transformation to a tensor with 7 final columns, four for the quaternion followed by three for the translation.

Returns:

A [*, 7] tensor representation of the transformation

Return type:

Tensor
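A NumPy sketch of the [*, 7] layout, assuming quaternions are stored real part first as (w, x, y, z), consistent with the (1, x, y, z) update convention documented for compose_q_update_vec(). The helper names mirror to_tensor_7() and from_tensor_7() but are hypothetical.

```python
import numpy as np

def to_tensor_7(quat, trans):
    # [*, 7]: four quaternion columns (w, x, y, z) followed by three translation columns
    return np.concatenate([quat, trans], axis=-1)

def from_tensor_7(t, normalize_quats=False):
    quat, trans = t[..., :4], t[..., 4:]
    if normalize_quats:
        # Rescale to unit norm so that the quaternion encodes a pure rotation
        quat = quat / np.linalg.norm(quat, axis=-1, keepdims=True)
    return quat, trans

t7 = np.array([2.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0])  # unnormalized identity quaternion
q, tr = from_tensor_7(t7, normalize_quats=True)
```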

unsqueeze(dim)

Analogous to torch.unsqueeze. The dimension is relative to the shared dimensions of the rotation/translation.

Parameters:

dim (int) – A positive or negative dimension index.

Returns:

The unsqueezed transformation.

Return type:

Rigid
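Because dim is relative to the shared batch shape [*], a negative index has to be shifted past the trailing data dimensions before it reaches the underlying arrays: by one for the [*, 3] translation and by two for [*, 3, 3] rotation matrices. The NumPy sketch below assumes rotation-matrix storage; `unsqueeze` and `cat` are hypothetical helpers, and `cat` concatenates along an existing batch dimension, as torch.cat() does.

```python
import numpy as np

def unsqueeze(rot_mats, trans, dim):
    # dim indexes the shared batch shape [*]; negative values must skip the
    # trailing (3, 3) rotation or (3,) translation data dimensions
    if dim >= 0:
        return np.expand_dims(rot_mats, dim), np.expand_dims(trans, dim)
    return np.expand_dims(rot_mats, dim - 2), np.expand_dims(trans, dim - 1)

def cat(rot_list, trans_list, dim):
    # Concatenate along an existing batch dimension, analogous to torch.cat
    rot_dim = dim if dim >= 0 else dim - 2
    trans_dim = dim if dim >= 0 else dim - 1
    return (np.concatenate(rot_list, axis=rot_dim),
            np.concatenate(trans_list, axis=trans_dim))

rots = np.tile(np.eye(3), (2, 5, 1, 1))  # batch shape [2, 5]
trans = np.zeros((2, 5, 3))
```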

property device: device

Returns the device on which the Rigid’s tensors are located.

Returns:

The device on which the Rigid’s tensors are located

property dtype: dtype

Returns the dtype of the Rigid tensors.

Returns:

The dtype of the Rigid tensors

property shape: Size

Returns the shape of the shared dimensions of the rotation and the translation.

Returns:

The shape of the transformation

class Rotation(rot_mats=None, quats=None, normalize_quats=True)

A 3D rotation. Depending on how the object is initialized, the rotation is represented by either a rotation matrix or a quaternion, though both formats are made available by helper functions. To simplify gradient computation, the underlying format of the rotation cannot be changed in-place. Like Rigid, the class is designed to mimic the behavior of a torch Tensor, almost as if each Rotation object were a tensor of rotations, in one format or another.

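Parameters:
  • rot_mats (Tensor | None) – A [*, 3, 3] rotation matrix tensor. Mutually exclusive with quats

  • quats (Tensor | None) – A [*, 4] quaternion tensor. Mutually exclusive with rot_mats

  • normalize_quats (bool) – If quats is specified, whether to normalize them
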
apply(pts)

Apply the current Rotation as a rotation matrix to a set of 3D coordinates.

Parameters:

pts (Tensor) – A [*, 3] set of points

Returns:

[*, 3] rotated points

Return type:

Tensor

static cat(rs, dim)

Concatenates rotations along one of the batch dimensions. Analogous to torch.cat().

Note that the output of this operation is always a rotation matrix, regardless of the format of input rotations.

Parameters:
  • rs (Sequence[Rotation]) – A list of rotation objects

  • dim (int) – The dimension along which the rotations should be concatenated

Returns:

A concatenated Rotation object in rotation matrix format

Return type:

Rotation

compose_q(r, normalize_quats=True)

Compose the quaternions of the current Rotation object with those of another.

Depending on whether either Rotation was initialized with quaternions, this function may call torch.linalg.eigh.

Parameters:
  • r (Rotation) – An update rotation object

  • normalize_quats (bool) – Whether to normalize the output quaternions

Returns:

An updated rotation object

Return type:

Rotation

compose_q_update_vec(q_update_vec, normalize_quats=True)

Returns a new quaternion Rotation after updating the current object’s underlying rotation with a quaternion update, formatted as a [*, 3] tensor whose final three columns represent x, y, z such that (1, x, y, z) is the desired (not necessarily unit) quaternion update.

Parameters:
  • q_update_vec (Tensor) – A [*, 3] quaternion update tensor

  • normalize_quats (bool) – Whether to normalize the output quaternion

Returns:

An updated Rotation

Return type:

Rotation

compose_r(r)

Compose the rotation matrices of the current Rotation object with those of another.

Parameters:

r (Rotation) – An update rotation object

Returns:

An updated rotation object

Return type:

Rotation

cuda()

Analogous to the cuda() method of torch Tensors

Returns:

A copy of the Rotation in CUDA memory

Return type:

Rotation

detach()

Returns a copy of the Rotation whose underlying Tensor has been detached from its torch graph.

Returns:

A copy of the Rotation whose underlying Tensor has been detached from its torch graph

Return type:

Rotation

get_cur_rot()

Returns the underlying rotation in its current form (a rotation matrix or a quaternion tensor).

Returns:

The stored rotation

Return type:

Tensor

get_quats()

Returns the underlying rotation as a quaternion tensor.

Depending on whether the Rotation was initialized with a quaternion, this function may call torch.linalg.eigh.

Returns:

The rotation as a quaternion tensor.

Return type:

Tensor

get_rot_mats()

Returns the underlying rotation as a rotation matrix tensor.

Returns:

The rotation as a rotation matrix tensor

Return type:

Tensor

static identity(shape, dtype=None, device=None, requires_grad=True, fmt='quat')

Returns an identity Rotation.

Parameters:
  • shape – The “shape” of the resulting Rotation object. See documentation for the shape property

  • dtype (dtype | None) – The torch dtype for the rotation

  • device (device | None) – The torch device for the new rotation

  • requires_grad (bool) – Whether the underlying tensors in the new rotation object should require gradient computation

  • fmt (str) – One of “quat” or “rot_mat”. Determines the underlying format of the new object’s rotation

Returns:

A new identity rotation

Return type:

Rotation

invert()

Returns the inverse of the current Rotation.

Returns:

The inverse of the current Rotation

Return type:

Rotation

invert_apply(pts)

The inverse of the apply() method.

Parameters:

pts (Tensor) – A [*, 3] set of points

Returns:

[*, 3] inverse-rotated points

Return type:

Tensor

map_tensor_fn(fn)

Apply a Tensor -> Tensor function to underlying rotation tensors, mapping over the rotation dimension(s). Can be used e.g. to sum out a one-hot batch dimension.

Parameters:

fn (Callable[[Tensor], Tensor]) – A Tensor -> Tensor function to be mapped over the Rotation

Returns:

The transformed Rotation object

Return type:

Rotation

to(device, dtype)

Analogous to the to() method of torch Tensors

Parameters:
  • device (device | None) – A torch device

  • dtype (dtype | None) – A torch dtype

Returns:

A copy of the Rotation using the new device and dtype

Return type:

Rotation

unsqueeze(dim)

Analogous to torch.unsqueeze. The dimension is relative to the shape of the Rotation object.

Parameters:

dim (int) – A positive or negative dimension index.

Returns:

The unsqueezed Rotation.

Return type:

Rotation

property device: device

The device of the underlying rotation

Returns:

The device of the underlying rotation

property dtype: dtype

Returns the dtype of the underlying rotation.

Returns:

The dtype of the underlying rotation

property requires_grad: bool

Returns the requires_grad property of the underlying rotation

Returns:

The requires_grad property of the underlying tensor

property shape: Size

Returns the virtual shape of the rotation object. This shape is defined as the batch dimensions of the underlying rotation matrix or quaternion. If the Rotation was initialized with a [10, 3, 3] rotation matrix tensor, for example, the resulting shape would be [10].

Returns:

The virtual shape of the rotation object

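identity_quats(batch_dims, dtype=None, device=None, requires_grad=True)

Returns a tensor of identity quaternions, (1, 0, 0, 0), with shape [*batch_dims, 4].

Parameters:
  • batch_dims (Tuple[int]) – The batch shape

  • dtype (dtype | None) – The dtype of the returned tensor

  • device (device | None) – The device of the returned tensor

  • requires_grad (bool) – Whether the returned tensor should require gradient computation

Return type: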

Tensor

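identity_rot_mats(batch_dims, dtype=None, device=None, requires_grad=True)

Returns a tensor of 3x3 identity matrices with shape [*batch_dims, 3, 3].

Parameters:
  • batch_dims (Tuple[int]) – The batch shape

  • dtype (dtype | None) – The dtype of the returned tensor

  • device (device | None) – The device of the returned tensor

  • requires_grad (bool) – Whether the returned tensor should require gradient computation

Return type: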

Tensor

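identity_trans(batch_dims, dtype=None, device=None, requires_grad=True)

Returns a tensor of zero translations with shape [*batch_dims, 3].

Parameters:
  • batch_dims (Tuple[int]) – The batch shape

  • dtype (dtype | None) – The dtype of the returned tensor

  • device (device | None) – The device of the returned tensor

  • requires_grad (bool) – Whether the returned tensor should require gradient computation

Return type: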

Tensor
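The three helpers above build identity components for a given batch shape. A NumPy sketch, assuming quaternions are stored real part first so that the identity quaternion is (1, 0, 0, 0); the dtype, device and requires_grad arguments are omitted because NumPy arrays have no direct equivalents. These functions are hypothetical stand-ins, not OpenFold's code.

```python
import numpy as np

def identity_quats(batch_dims):
    # (1, 0, 0, 0): the quaternion for a zero-angle rotation, real part first
    quats = np.zeros(tuple(batch_dims) + (4,))
    quats[..., 0] = 1.0
    return quats

def identity_rot_mats(batch_dims):
    # 3x3 identity broadcast over the batch dimensions (copied to get a writable array)
    return np.broadcast_to(np.eye(3), tuple(batch_dims) + (3, 3)).copy()

def identity_trans(batch_dims):
    # Zero translation for every batch element
    return np.zeros(tuple(batch_dims) + (3,))
```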

invert_quat(quat)

Inverts a [*, 4] quaternion tensor.

Parameters:

quat (Tensor)

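invert_rot_mat(rot_mat)

Inverts a [*, 3, 3] rotation matrix tensor. Because rotation matrices are orthogonal, the inverse is the transpose over the last two dimensions.

Parameters: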

rot_mat (Tensor)
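A NumPy sketch of both inverses, assuming (w, x, y, z) quaternion order: a quaternion's inverse is its conjugate divided by its squared norm, and a rotation matrix's inverse is its transpose. The function names mirror the entries above, but these are hypothetical reimplementations.

```python
import numpy as np

def invert_quat(quat):
    # q^-1 = conj(q) / |q|^2; for unit quaternions this reduces to the conjugate
    conj = quat * np.array([1.0, -1.0, -1.0, -1.0])
    return conj / np.sum(quat * quat, axis=-1, keepdims=True)

def invert_rot_mat(rot_mat):
    # Rotation matrices are orthogonal, so the inverse is the transpose of the last two dims
    return np.swapaxes(rot_mat, -1, -2)

theta = 0.4
rot = np.array([[np.cos(theta), -np.sin(theta), 0.0],
                [np.sin(theta), np.cos(theta), 0.0],
                [0.0, 0.0, 1.0]])
```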

quat_multiply(quat1, quat2)

Multiply a quaternion by another quaternion.

quat_multiply_by_vec(quat, vec)

Multiply a quaternion by a pure-vector quaternion.
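Quaternion multiplication is the Hamilton product. A NumPy sketch, assuming (w, x, y, z) order and treating a pure-vector quaternion as (0, x, y, z) built from a [*, 3] vector argument. The usage lines at the end compute both sides of quat_multiply(q, (1, v)) = q + quat_multiply_by_vec(q, v), the identity behind the (1, x, y, z) updates of compose_q_update_vec(). The names are hypothetical reimplementations.

```python
import numpy as np

def quat_multiply(q1, q2):
    # Hamilton product of [*, 4] quaternions in (w, x, y, z) order
    w1, x1, y1, z1 = np.moveaxis(q1, -1, 0)
    w2, x2, y2, z2 = np.moveaxis(q2, -1, 0)
    return np.stack([
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
    ], axis=-1)

def quat_multiply_by_vec(quat, vec):
    # Treat a [*, 3] vector as the pure quaternion (0, x, y, z)
    zeros = np.zeros(vec.shape[:-1] + (1,))
    return quat_multiply(quat, np.concatenate([zeros, vec], axis=-1))

rng = np.random.default_rng(0)
q = rng.normal(size=4)
q /= np.linalg.norm(q)
v = np.array([0.1, -0.2, 0.3])
lhs = quat_multiply(q, np.concatenate([[1.0], v]))
rhs = q + quat_multiply_by_vec(q, v)
```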

quat_to_rot(quat)

Converts a quaternion to a rotation matrix.

Parameters:

quat (Tensor) – [*, 4] quaternions

Returns:

[*, 3, 3] rotation matrices

Return type:

Tensor
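The standard conversion from a unit quaternion (w, x, y, z) to a rotation matrix, as a NumPy sketch. It assumes real-part-first order and unit-norm input; OpenFold may use an algebraically equivalent form, and `quat_to_rot` here is a hypothetical reimplementation.

```python
import numpy as np

def quat_to_rot(quat):
    # Assumes unit quaternions in (w, x, y, z) order; returns [*, 3, 3] matrices
    w, x, y, z = np.moveaxis(quat, -1, 0)
    rows = [
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    ]
    return np.stack([np.stack(r, axis=-1) for r in rows], axis=-2)

s = np.sqrt(0.5)
rot_90z = quat_to_rot(np.array([s, 0.0, 0.0, s]))  # 90 degrees about z

rng = np.random.default_rng(0)
qs = rng.normal(size=(6, 4))
qs /= np.linalg.norm(qs, axis=-1, keepdims=True)
rots = quat_to_rot(qs)
```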

rot_matmul(a, b)

Performs matrix multiplication of two rotation matrix tensors. Written out by hand to avoid AMP downcasting.

Parameters:
  • a (Tensor) – [*, 3, 3] left multiplicand

  • b (Tensor) – [*, 3, 3] right multiplicand

Returns:

The product ab

Return type:

Tensor

rot_to_quat(rot)

Converts [*, 3, 3] rotation matrices to [*, 4] quaternions.

Parameters:

rot (Tensor)
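The Rotation entries note that quaternion conversion may call torch.linalg.eigh, which suggests an eigenvector method. The NumPy sketch below assumes such a method: it builds a symmetric 4x4 matrix from R whose eigenvector with the largest eigenvalue is the quaternion in (w, x, y, z) order, defined only up to sign. Whether OpenFold uses exactly this matrix is an assumption; `rot_to_quat` here is hypothetical.

```python
import numpy as np

def rot_to_quat(rot):
    # Entries of the [*, 3, 3] rotation matrix, named by (row, column) axis
    xx, xy, xz = rot[..., 0, 0], rot[..., 0, 1], rot[..., 0, 2]
    yx, yy, yz = rot[..., 1, 0], rot[..., 1, 1], rot[..., 1, 2]
    zx, zy, zz = rot[..., 2, 0], rot[..., 2, 1], rot[..., 2, 2]
    # Symmetric 4x4 matrix whose top eigenvector (eigenvalue 1) is the quaternion
    k = np.stack([
        np.stack([xx + yy + zz, zy - yz, xz - zx, yx - xy], axis=-1),
        np.stack([zy - yz, xx - yy - zz, xy + yx, xz + zx], axis=-1),
        np.stack([xz - zx, xy + yx, yy - xx - zz, yz + zy], axis=-1),
        np.stack([yx - xy, xz + zx, yz + zy, zz - xx - yy], axis=-1),
    ], axis=-2) / 3.0
    # eigh sorts eigenvalues in ascending order, so the last eigenvector is the top one
    _, vecs = np.linalg.eigh(k)
    return vecs[..., -1]

theta = 0.8
c, s = np.cos(theta), np.sin(theta)
rot_x = np.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])
q = rot_to_quat(rot_x)
q = q * np.sign(q[0])  # q and -q encode the same rotation; fix the sign for comparison
```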

rot_vec_mul(r, t)

Applies a rotation to a vector. Written out by hand to avoid AMP downcasting.

Parameters:
  • r (Tensor) – [*, 3, 3] rotation matrices

  • t (Tensor) – [*, 3] coordinate tensors

Returns:

[*, 3] rotated coordinates

Return type:

Tensor
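In PyTorch, matmul is among the ops that autocast may run in reduced precision, whereas spelling a product out as elementwise multiplies and adds keeps it in the input dtype. The NumPy sketch below only shows the shape of that hand-written computation (NumPy has no autocast) and checks it against the library products; the function names mirror the entries above but are hypothetical reimplementations.

```python
import numpy as np

def rot_vec_mul(r, t):
    # Each output coordinate is an explicit sum of products over the batch dims
    x, y, z = t[..., 0], t[..., 1], t[..., 2]
    return np.stack([
        r[..., 0, 0] * x + r[..., 0, 1] * y + r[..., 0, 2] * z,
        r[..., 1, 0] * x + r[..., 1, 1] * y + r[..., 1, 2] * z,
        r[..., 2, 0] * x + r[..., 2, 1] * y + r[..., 2, 2] * z,
    ], axis=-1)

def rot_matmul(a, b):
    # Entry (i, j) of ab is sum_k a[i, k] * b[k, j], spelled out per entry
    rows = [
        np.stack([sum(a[..., i, k] * b[..., k, j] for k in range(3)) for j in range(3)], axis=-1)
        for i in range(3)
    ]
    return np.stack(rows, axis=-2)

rng = np.random.default_rng(0)
a = rng.normal(size=(5, 3, 3))
b = rng.normal(size=(5, 3, 3))
t = rng.normal(size=(5, 3))
```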