openfold.utils.tensor_utils

Functions

add(m1, m2, inplace)

batched_gather(data, inds[, dim, no_batch_dims])

dict_map(fn, dic, leaf_type)

dict_multimap(fn, dicts)

flatten_final_dims(t, no_dims)

masked_mean(mask, value, dim[, eps])

maybe_to(x, dtype)

one_hot(x, v_bins)

permute_final_dims(tensor, inds)

pts_to_distogram(pts[, min_bin, max_bin, ...])

tree_map(fn, tree, leaf_type)

add(m1, m2, inplace)
batched_gather(data, inds, dim=0, no_batch_dims=0)
dict_map(fn, dic, leaf_type)
dict_multimap(fn, dicts)
flatten_final_dims(t, no_dims)
Parameters:
t (torch.Tensor)
no_dims (int)
masked_mean(mask, value, dim, eps=0.0001)
maybe_to(x, dtype)
one_hot(x, v_bins)
permute_final_dims(tensor, inds)
Parameters:
tensor (torch.Tensor)
inds (List[int])
pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64)
tensor_tree_map(fn, tree, *, leaf_type=<class 'torch.Tensor'>)
tree_map(fn, tree, leaf_type)