openfold.utils.import_weights

Classes

Param(param[, param_type, stacked, swap])

ParamType(value)

Enumeration of parameter types used when importing weights.

Functions

assign(translation_dict, orig_weights)

convert_deprecated_v1_keys(state_dict)

Update older OpenFold model weight names to match the current model code.

generate_translation_dict(model, version[, ...])

import_jax_weights_(model, npz_path[, version])

import_openfold_weights_(model, state_dict)

Import model weights.

process_translation_dict(d[, top_layer])

stacked(param_dict_list[, out])

Stack a list of (nested) Param dicts.

class Param(param: Union[torch.Tensor, List[torch.Tensor]], param_type: openfold.utils.import_weights.ParamType = ParamType.Other, stacked: bool = False, swap: bool = False)
Parameters:
param: Tensor | List[Tensor]
param_type: ParamType = ParamType.Other
stacked: bool = False
swap: bool = False
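A minimal sketch of a record with the same four fields as Param, assuming a plain dataclass. ParamTypeSketch and ParamSketch are hypothetical stand-ins, and tensors are modelled as nested Python lists so the example runs without PyTorch. This page does not document what swap does, so the sketch only stores it.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Any, List, Union


class ParamTypeSketch(Enum):
    # Hypothetical stand-in for ParamType; only the default member is modelled.
    Other = "other"


@dataclass
class ParamSketch:
    # Mirrors the documented fields of Param. A "tensor" is a nested list here
    # so the example runs without PyTorch.
    param: Union[Any, List[Any]]
    param_type: ParamTypeSketch = ParamTypeSketch.Other
    stacked: bool = False
    swap: bool = False  # semantics not documented on this page; stored only


p = ParamSketch(param=[[1.0, 2.0], [3.0, 4.0]])
print(p.param_type, p.stacked, p.swap)
# ParamTypeSketch.Other False False
```

A dataclass provides keyword construction and defaults directly from the annotations; whether OpenFold itself defines Param this way is not stated here.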
class ParamType(value)

Bases: Enum

Enumeration of parameter types used when importing weights.

LinearBiasMHA = functools.partial(<function ParamType.<lambda>>)
LinearBiasMultimer = functools.partial(<function ParamType.<lambda>>)
LinearMHAOutputWeight = functools.partial(<function ParamType.<lambda>>)
LinearWeight = functools.partial(<function ParamType.<lambda>>)
LinearWeightMHA = functools.partial(<function ParamType.<lambda>>)
LinearWeightMultimer = functools.partial(<function ParamType.<lambda>>)
LinearWeightOPM = functools.partial(<function ParamType.<lambda>>)
Other = functools.partial(<function ParamType.<lambda>>)
assign(translation_dict, orig_weights)
convert_deprecated_v1_keys(state_dict)

Update older OpenFold model weight names to match the current model code.
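Renaming old weight names to match refactored code can be expressed as a substring rewrite over the state dict's keys. The sketch below assumes that approach; RENAMES and its entries are hypothetical placeholders, not OpenFold's actual key mapping, and the values (tensors in practice) are passed through untouched.

```python
from typing import Any, Dict, Mapping

# Hypothetical old -> new name fragments; the real mapping lives in OpenFold.
RENAMES: Dict[str, str] = {
    "old_block.": "new_block.",
    ".linear_1.": ".linear_in.",
}


def convert_keys_sketch(state_dict: Mapping[str, Any]) -> Dict[str, Any]:
    converted: Dict[str, Any] = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in RENAMES.items():
            # Rewrite every occurrence of an outdated name fragment.
            new_key = new_key.replace(old, new)
        converted[new_key] = value  # weights themselves are not modified
    return converted


old = {"old_block.linear_1.weight": 1, "other.bias": 2}
print(list(convert_keys_sketch(old)))
# ['new_block.linear_in.weight', 'other.bias']
```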

generate_translation_dict(model, version, is_multimer=False)
import_jax_weights_(model, npz_path, version='model_1')
import_openfold_weights_(model, state_dict)

Import model weights. Several parts of the model were refactored when support for Multimer was added, so the state dicts of older models are translated to match the refactored model code.
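Once old names have been translated, the resulting keys still have to line up with the model's parameters. A minimal sketch of that check, assuming the same missing/unexpected-key semantics as PyTorch's load_state_dict(strict=True); check_keys_sketch and the key names are hypothetical, and import_openfold_weights_ may handle mismatches differently.

```python
from typing import Iterable, List, Mapping, Tuple


def check_keys_sketch(
    model_keys: Iterable[str], state_dict: Mapping[str, object]
) -> Tuple[List[str], List[str]]:
    expected = set(model_keys)
    provided = set(state_dict)
    missing = sorted(expected - provided)     # model params with no weight
    unexpected = sorted(provided - expected)  # weights the model has no slot for
    return missing, unexpected


model_keys = ["evoformer.linear.weight", "evoformer.linear.bias"]
translated = {"evoformer.linear.weight": 0, "legacy.scale": 1}
print(check_keys_sketch(model_keys, translated))
# (['evoformer.linear.bias'], ['legacy.scale'])
```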

process_translation_dict(d, top_layer=True)
stacked(param_dict_list, out=None)
Parameters:

param_dict_list – A list of (nested) Param dicts to stack. The structure of each dict must be identical (down to the ParamTypes of “parallel” Params). There must be at least one dict in the list.
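The requirements above (identical structure, matching ParamTypes for parallel Params, at least one dict) can be sketched as a recursive merge. This is a minimal illustration, assuming each stacked leaf collects the parallel params into a list and is marked stacked=True; ParamSketch and stacked_sketch are hypothetical, param_type is a plain string, and the out argument is omitted.

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class ParamSketch:
    # Hypothetical stand-in for Param; param_type is a plain string here.
    param: Any
    param_type: str = "Other"
    stacked: bool = False


def stacked_sketch(dicts: List[Dict[str, Any]]) -> Dict[str, Any]:
    if not dicts:
        raise ValueError("need at least one dict to stack")
    if not all(isinstance(d, dict) for d in dicts):
        raise ValueError("mixed dicts and params at the same position")
    template = dicts[0]
    for d in dicts[1:]:
        if d.keys() != template.keys():
            raise ValueError("all dicts must have identical structure")
    out: Dict[str, Any] = {}
    for key, first in template.items():
        values = [d[key] for d in dicts]
        if isinstance(first, dict):
            out[key] = stacked_sketch(values)  # recurse into nested dicts
        else:
            # "Parallel" params at this position must share a ParamType.
            if any(v.param_type != first.param_type for v in values):
                raise ValueError(f"mismatched param_type under {key!r}")
            out[key] = ParamSketch(
                param=[v.param for v in values],
                param_type=first.param_type,
                stacked=True,
            )
    return out


layers = [
    {"attn": {"w": ParamSketch([1, 2], "LinearWeight")}},
    {"attn": {"w": ParamSketch([3, 4], "LinearWeight")}},
]
print(stacked_sketch(layers)["attn"]["w"])
# ParamSketch(param=[[1, 2], [3, 4]], param_type='LinearWeight', stacked=True)
```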