openfold.utils.import_weights
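Classes
- Param
- ParamType

Functions
- assign
- convert_deprecated_v1_keys: Update older OpenFold model weight names to match the current model code.
- generate_translation_dict
- import_jax_weights_
- import_openfold_weights_: Import model weights.
- process_translation_dict
- stacked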
- class Param(param: Union[torch.Tensor, List[torch.Tensor]], param_type: openfold.utils.import_weights.ParamType = ParamType.Other, stacked: bool = False, swap: bool = False)
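Stripped of the generated default repr, the signature reads as a plain record: a target tensor (or a list of tensors), a ParamType describing how to reshape the incoming weight, and two flags. The sketch below mirrors those fields and defaults with a dataclass. It assumes that `stacked=True` means `param` holds one target per layer; `ParamStub`, `ParamTypeStub` and `targets()` are hypothetical names, and strings stand in for tensors.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Any, List, Union


class ParamTypeStub(Enum):
    """Hypothetical stand-in for ParamType; the real members wrap callables."""
    LinearWeight = "LinearWeight"
    Other = "Other"


@dataclass
class ParamStub:
    """Hypothetical mirror of Param's fields and their defaults."""
    param: Union[Any, List[Any]]  # one target, or one target per layer if stacked
    param_type: ParamTypeStub = ParamTypeStub.Other
    stacked: bool = False
    swap: bool = False

    def targets(self) -> List[Any]:
        # Assumption: a stacked Param holds a list with one target per layer.
        return list(self.param) if self.stacked else [self.param]


# Usage: strings stand in for tensors.
single = ParamStub("w")
layers = ParamStub(["w0", "w1"], ParamTypeStub.LinearWeight, stacked=True)
```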
- class ParamType(value)
Bases: Enum
An enumeration. Each member's value is a functools.partial wrapping a lambda, as the member listing shows.
- LinearBiasMHA = functools.partial(<function ParamType.<lambda>>)
- LinearBiasMultimer = functools.partial(<function ParamType.<lambda>>)
- LinearMHAOutputWeight = functools.partial(<function ParamType.<lambda>>)
- LinearWeight = functools.partial(<function ParamType.<lambda>>)
- LinearWeightMHA = functools.partial(<function ParamType.<lambda>>)
- LinearWeightMultimer = functools.partial(<function ParamType.<lambda>>)
- LinearWeightOPM = functools.partial(<function ParamType.<lambda>>)
- Other = functools.partial(<function ParamType.<lambda>>)
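The member values listed above are functools.partial objects wrapping lambdas. The wrapping matters: when Enum collects members it skips descriptors, and plain functions (lambdas included) are descriptors, so a bare lambda in the class body would become a method rather than a member. The sketch below shows the same pattern with a hypothetical non-descriptor wrapper class, `_Transform`, operating on pure-Python lists. It assumes each member reshapes an incoming JAX weight into the layout the PyTorch module expects; the transpose for `LinearWeight` is an illustrative guess, motivated by Haiku-style linear layers storing weights as (input, output) while `torch.nn.Linear` stores (out_features, in_features). `ParamTypeSketch` and `transformation` are hypothetical names.

```python
from enum import Enum
from typing import Callable, List

Matrix = List[List[float]]


class _Transform:
    """Hypothetical callable wrapper. It defines no __get__, so Enum stores it
    as a member value instead of binding it as a method."""

    def __init__(self, fn: Callable[[Matrix], Matrix]) -> None:
        self.fn = fn

    def __call__(self, w: Matrix) -> Matrix:
        return self.fn(w)


def _transpose(w: Matrix) -> Matrix:
    # (in, out) -> (out, in), the layout torch.nn.Linear uses for its weight
    return [list(row) for row in zip(*w)]


class ParamTypeSketch(Enum):
    LinearWeight = _Transform(_transpose)
    Other = _Transform(lambda w: w)  # identity: copy the weight unchanged

    @property
    def transformation(self) -> Callable[[Matrix], Matrix]:
        return self.value
```

A dedicated class with no `__get__` makes the member-versus-method outcome explicit rather than relying on how `functools.partial` behaves as a class attribute.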
- assign(translation_dict, orig_weights)
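The reference documents `assign` only by its signature. The sketch below is a hypothetical, pure-Python illustration of what a step with this signature plausibly does, under several stated assumptions: `translation_dict` is flat (one source key per Param), `stacked` marks a source array whose leading axis holds one slice per target, and `swap` exchanges the two halves of the leading dimension during the copy. Lists play the role of tensors, `target[:] = w` plays the role of an in-place `Tensor.copy_`, and `ParamSketch` / `assign_sketch` are hypothetical names.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Union

Vector = List[float]


def _identity(w: Vector) -> Vector:
    return w


@dataclass
class ParamSketch:
    """Hypothetical stand-in for Param; lists play the role of tensors."""
    param: Union[Vector, List[Vector]]       # target buffer(s), filled in place
    transform: Callable[[Vector], Vector] = _identity
    stacked: bool = False                    # source has one slice per target
    swap: bool = False                       # exchange the two halves on copy


def assign_sketch(translation_dict: Dict[str, ParamSketch],
                  orig_weights: Dict[str, Any]) -> None:
    """Copy each source weight into its target in place (hypothetical)."""
    for key, p in translation_dict.items():
        source = orig_weights[key]
        if p.stacked:
            sources, targets = list(source), list(p.param)
        else:
            sources, targets = [source], [p.param]
        if len(sources) != len(targets):
            raise ValueError(f"{key}: {len(sources)} slices for {len(targets)} targets")
        for target, w in zip(targets, sources):
            w = list(p.transform(w))
            if len(w) != len(target):
                raise ValueError(f"{key}: size {len(w)} != {len(target)}")
            if p.swap:
                half = len(target) // 2
                w = w[half:] + w[:half]
            target[:] = w  # in-place copy, standing in for Tensor.copy_
```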
- convert_deprecated_v1_keys(state_dict)
Update older OpenFold model weight names to match the current model code.
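Updating weight names usually amounts to rewriting each key with a table of substitutions while leaving the tensors untouched. A minimal sketch, assuming regex-based rules applied in order; the rule table, the key names, and `convert_keys_sketch` are hypothetical placeholders, not OpenFold's actual renames.

```python
import re
from typing import Dict, List, Mapping, Tuple, TypeVar

V = TypeVar("V")

# Hypothetical rename rules, applied in order: (regex on the old key, replacement).
RENAME_RULES: List[Tuple[str, str]] = [
    (r"\.legacy_block\.", ".block."),  # e.g. a submodule that was renamed
    (r"^core\.", ""),                  # e.g. a wrapper module that was removed
]


def convert_keys_sketch(state_dict: Mapping[str, V],
                        rules: List[Tuple[str, str]] = RENAME_RULES) -> Dict[str, V]:
    """Return a new dict with renamed keys; the values are passed through."""
    converted: Dict[str, V] = {}
    for key, value in state_dict.items():
        new_key = key
        for pattern, replacement in rules:
            new_key = re.sub(pattern, replacement, new_key)
        if new_key in converted:
            raise KeyError(f"two old keys map to {new_key!r}")
        converted[new_key] = value
    return converted


old = {"core.legacy_block.weight": 1, "head.bias": 2}
new = convert_keys_sketch(old)
```

Raising on collisions guards against two old names silently collapsing onto one new name.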
- generate_translation_dict(model, version, is_multimer=False)
- import_jax_weights_(model, npz_path, version='model_1')
- import_openfold_weights_(model, state_dict)
Import model weights from an OpenFold state dict. Several parts of the model were refactored when support for Multimer was added, so the state dicts of older models are translated to match the refactored model code.
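The translate-then-load flow described above can be pictured as: rename old keys, then check the result against the keys the current model expects, which is the check `torch.nn.Module.load_state_dict` performs by default (`strict=True`). The plain-Python sketch below assumes a simple old-to-new key map; `load_translated_sketch` and all key names are hypothetical.

```python
from typing import Dict, Iterable, Mapping


def load_translated_sketch(expected_keys: Iterable[str],
                           state_dict: Mapping[str, object],
                           renames: Mapping[str, str]) -> Dict[str, object]:
    """Translate old keys, then verify them against the current model's keys,
    mirroring a strict load (hypothetical helper)."""
    translated = {renames.get(k, k): v for k, v in state_dict.items()}
    expected = set(expected_keys)
    missing = sorted(expected - set(translated))
    unexpected = sorted(set(translated) - expected)
    if missing or unexpected:
        raise KeyError(f"missing={missing} unexpected={unexpected}")
    return translated


# Usage: an "old" checkpoint whose first key was renamed during a refactor.
loaded = load_translated_sketch(
    expected_keys=["a.w", "b.w"],
    state_dict={"old_a.w": 1, "b.w": 2},
    renames={"old_a.w": "a.w"},
)
```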
- process_translation_dict(d, top_layer=True)
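Judging by its name and the `top_layer` flag, `process_translation_dict` likely flattens a nested translation dict into the flat keys stored in a JAX `.npz` checkpoint. The sketch below assumes AlphaFold's convention of keying each array by its Haiku module path, with `//` separating the module scope from the parameter name. It is not OpenFold's implementation: `flatten_sketch` is a hypothetical name and threads a `scope` string through the recursion instead of a `top_layer` flag.

```python
from typing import Any, Dict, Mapping


def flatten_sketch(d: Mapping[str, Any], scope: str = "") -> Dict[str, Any]:
    """Flatten nested dicts into 'module/path//param_name' keys."""
    flat: Dict[str, Any] = {}
    for name, value in d.items():
        if isinstance(value, Mapping):
            # Nested dict: extend the module scope with "/".
            child_scope = f"{scope}/{name}" if scope else name
            flat.update(flatten_sketch(value, child_scope))
        else:
            # Leaf: "//" separates the module scope from the parameter name.
            flat[f"{scope}//{name}"] = value
    return flat


nested = {"alphafold": {"evoformer": {"preprocess_1d": {"weights": 1, "bias": 2}}}}
flat = flatten_sketch(nested)
```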
- stacked(param_dict_list, out=None)
Parameters:
param_dict_list – A list of (nested) Param dicts to stack. The structure of each dict must be identical (down to the ParamTypes of “parallel” Params). There must be at least one dict in the list.
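The constraints above suggest stacking proceeds leaf by leaf: walk the first dict as a template, gather the matching ("parallel") entry from every dict, and merge each group of Params into one Param whose `param` is a list, flagged as stacked. A plausible motivation, stated as an assumption, is that JAX checkpoints store the weights of repeated blocks as a single array with a leading layer axis. In this sketch `ParamStub` and `stacked_sketch` are hypothetical, ParamTypes are plain strings, and the function returns a fresh dict instead of filling an `out` accumulator.

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class ParamStub:
    """Hypothetical stand-in for Param (param_type kept as a plain string)."""
    param: Any
    param_type: str = "Other"
    stacked: bool = False


def stacked_sketch(param_dict_list: List[Dict[str, Any]]) -> Dict[str, Any]:
    if not param_dict_list:
        raise ValueError("need at least one dict")
    template = param_dict_list[0]  # every dict must share this structure
    out: Dict[str, Any] = {}
    for key, first in template.items():
        parallel = [d[key] for d in param_dict_list]  # same slot in every dict
        if isinstance(first, dict):
            out[key] = stacked_sketch(parallel)  # recurse into nested dicts
        else:
            if any(p.param_type != first.param_type for p in parallel):
                raise ValueError(f"ParamType mismatch at {key!r}")
            out[key] = ParamStub([p.param for p in parallel],
                                 first.param_type, stacked=True)
    return out


blocks = [
    {"attn": {"w": ParamStub("w0", "LinearWeight")}},
    {"attn": {"w": ParamStub("w1", "LinearWeight")}},
]
merged = stacked_sketch(blocks)
```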