recwizard.BaseModule#
The BaseModule class adds a few features to the Huggingface PreTrainedModel class.
- class recwizard.module_utils.BaseModule(config: BaseConfig | None = None, **kwargs)[source]#
The base class for any module that is based on nn.Module, supporting partial load/save and state dict mapping.
- __init__(config: BaseConfig | None = None, **kwargs)[source]#
- Parameters:
config – config for PreTrainedModel
- get_tokenizer()[source]#
Serves as a default tokenizer getter for the module. Will load the tokenizer from self.model_name_or_path if it has been set.
Warning
This function will be called when a module is passed to a BasePipeline instance without providing the corresponding tokenizer.
Tip
If your tokenizer cannot be initialized by the default implementation, we recommend you initialize and pass the tokenizer manually; alternatively, you can override this function (it accepts no arguments).
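The override pattern can be sketched in plain Python (a hypothetical illustration; MyTokenizer and the toy vocabulary are stand-ins, not part of recwizard):

```python
# Hypothetical sketch: overriding get_tokenizer when the default lookup
# via self.model_name_or_path does not apply.
class MyTokenizer:
    """Stand-in tokenizer for illustration."""
    def __init__(self, vocab):
        self.vocab = vocab

class MyModule:  # imagine this subclasses recwizard's BaseModule
    model_name_or_path = None  # the default tokenizer lookup would fail here

    def get_tokenizer(self):
        # Note: the override takes no arguments, matching the base signature.
        return MyTokenizer(vocab=["<unk>", "hello", "world"])

tokenizer = MyModule().get_tokenizer()
```

Because a BasePipeline falls back to get_tokenizer() when no tokenizer is passed, overriding it keeps the module usable without extra wiring.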
- prepare_weight(weight: array, name: str)[source]#
This function wraps a weight passed for initializing a module parameter and saves its shape and dtype as an entry in BaseConfig.WEIGHT_DIMENSIONS. So when
from_pretrained()
is called without passing the weight (i.e. when weight is None), the corresponding parameter is initialized from the saved shape and dtype before load_state_dict()
is called to copy the module’s weights.
- Parameters:
weight – the weight that will be copied to a parameter of the module
name – a name to distinguish the weight in the module config
- Returns:
a PyTorch tensor with the same value as weight if weight is not None; otherwise, a zero tensor with the saved shape and dtype from the module config.
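The contract can be mimicked in plain Python (an illustrative sketch only; the real method operates on torch tensors and stores metadata in BaseConfig.WEIGHT_DIMENSIONS):

```python
# Plain-Python sketch of prepare_weight's contract: record shape/dtype when a
# weight is given, and fall back to a zero placeholder when it is None.
WEIGHT_DIMENSIONS = {}  # stands in for BaseConfig.WEIGHT_DIMENSIONS

def prepare_weight(weight, name):
    if weight is not None:
        # Save the metadata so a later from_pretrained() call can allocate
        # the parameter without the original weight being passed again.
        WEIGHT_DIMENSIONS[name] = (len(weight), "float32")
        return list(weight)
    # weight is None: allocate zeros from the saved shape.
    size, _dtype = WEIGHT_DIMENSIONS[name]
    return [0.0] * size

first = prepare_weight([1.0, 2.0, 3.0], "embedding")  # records the shape
second = prepare_weight(None, "embedding")            # zeros of the same shape
```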
- response(**kwargs)#
The main function for the module to generate a response given an input.
Note
Please refer to our tutorial for implementation guidance: Overview
- Parameters:
raw_input (str) – the text input
tokenizer (PreTrainedTokenizer) – the tokenizer used to tokenize the input
return_dict (bool) – if set to True, will return a dict of outputs instead of a single output
**kwargs – the keyword arguments that will be passed to forward()
- Returns:
By default, a single output will be returned. If return_dict is set to True, a dict of outputs will be returned.
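The return_dict behavior can be illustrated with a toy module (a sketch; ToyModule and its forward logic are invented stand-ins, not recwizard code):

```python
# Sketch of the response() contract: tokenize the input, run forward(), and
# return either a single output or a dict depending on return_dict.
class ToyModule:
    def forward(self, tokens, **kwargs):
        return {"output": " ".join(tokens).upper(), "n_tokens": len(tokens)}

    def response(self, raw_input, tokenizer=None, return_dict=False, **kwargs):
        tokenize = tokenizer or str.split  # fall back to whitespace splitting
        outputs = self.forward(tokenize(raw_input), **kwargs)
        # return_dict=True exposes every field; otherwise just the main output.
        return outputs if return_dict else outputs["output"]

module = ToyModule()
single = module.response("hello world")
full = module.response("hello world", return_dict=True)
```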
- forward(*inputs, **kwargs)[source]#
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- state_dict(*args, destination=None, prefix='', keep_vars=False)[source]#
Note
Adapted from pytorch’s original implementation to support self.LOAD_SAVE_IGNORES: when a module matches any pattern in self.LOAD_SAVE_IGNORES, it will be ignored in the returned state dict.
Returns a dictionary containing references to the whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.
Note
The returned object is a shallow copy. It contains references to the module’s parameters and buffers.
Warning
Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.
Warning
Please avoid the use of argument destination as it is not designed for end-users.
- Parameters:
destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.
prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.
keep_vars (bool, optional) – by default the Tensors returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.
- Returns:
a dictionary containing a whole state of the module
- Return type:
dict
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
- load_state_dict(state_dict: OrderedDict[str, torch.TensorType], strict: bool = True, allow_unexpected=False, LOAD_PREFIX: str | None = '', LOAD_MAPPINGS: dict | None = None, LOAD_IGNORES: Iterable[str] | None = ())[source]#
Note
Adapted from pytorch’s original implementation to support partial loading and state dict mapping.
Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
- Parameters:
allow_unexpected – if set to True, will not complain about unexpected key(s) in state_dict
state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True
LOAD_PREFIX – when set, the state dict will be loaded from under the prefix LOAD_PREFIX
LOAD_MAPPINGS – When set, for each key, val pair in the dict, any parameter in state_dict that has key as a prefix will be mapped to one that has val in place of key as a new prefix. (This parameter is useful when the submodules/parameters have different names from their counterparts in the state dict.)
LOAD_IGNORES – When set, the submodules in the state_dict that match any pattern in LOAD_IGNORES will be ignored in loading. Note that self.LOAD_SAVE_IGNORES will be merged into LOAD_IGNORES. (This option can be useful when you use a frozen pre-trained model inside your model to avoid duplicate loading and saving)
- Returns:
missing_keys is a list of str containing the missing keys
unexpected_keys is a list of str containing the unexpected keys
- Return type:
NamedTuple
withmissing_keys
andunexpected_keys
fields
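The return contract can be sketched with plain dicts (an illustrative stand-in for the documented behavior; check_keys and LoadResult are invented names, not recwizard's implementation):

```python
# Sketch of load_state_dict's return contract: a named tuple of missing and
# unexpected keys, with strict/allow_unexpected controlling whether to raise.
from collections import namedtuple

LoadResult = namedtuple("LoadResult", ["missing_keys", "unexpected_keys"])

def check_keys(module_keys, state_dict, strict=True, allow_unexpected=False):
    missing = [k for k in module_keys if k not in state_dict]
    unexpected = [k for k in state_dict if k not in module_keys]
    if strict and missing:
        raise RuntimeError(f"Missing key(s): {missing}")
    if strict and unexpected and not allow_unexpected:
        raise RuntimeError(f"Unexpected key(s): {unexpected}")
    return LoadResult(missing, unexpected)

result = check_keys(
    ["encoder.weight", "head.bias"],
    {"encoder.weight": 1, "extra.bias": 2},
    strict=False,
)
```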
- static map_parameters(name_mappings, state_dict, remove_origin=True)[source]#
Map parameters in state_dict to new names according to name_mappings.
- Parameters:
name_mappings (dict) – a dict that maps the original parameter name to the new name.
state_dict (dict) – the state dict to be mapped.
remove_origin (bool) – whether to remove the original parameter.
- Returns:
the mapped state dict.
- Return type:
dict
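The documented prefix-mapping behavior can be sketched with plain dicts (an illustrative stand-in; the real static method operates on tensor state dicts):

```python
# Sketch of map_parameters: any key whose prefix appears in name_mappings is
# copied under the new prefix; remove_origin controls whether the old key stays.
def map_parameters(name_mappings, state_dict, remove_origin=True):
    result = dict(state_dict)
    for old, new in name_mappings.items():
        for key in list(state_dict):
            if key.startswith(old):
                result[new + key[len(old):]] = state_dict[key]
                if remove_origin:
                    result.pop(key, None)
    return result

renamed = map_parameters({"rnn": "gru"}, {"rnn.weight": 1, "head.bias": 2})
kept = map_parameters({"rnn": "gru"}, {"rnn.weight": 1}, remove_origin=False)
```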
- static remove_ignores(load_ignores: Iterable, state_dict)[source]#
Remove parameters from state_dict according to load_ignores.
- Parameters:
load_ignores (list) – a list of prefixes to be ignored.
state_dict (dict) – the state dict to operate on.
- Returns:
the state dict after removing ignored parameters.
- Return type:
dict
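The documented filtering can be sketched with plain dicts (an illustrative stand-in; the real static method operates on tensor state dicts):

```python
# Sketch of remove_ignores: drop every entry whose key starts with one of
# the ignored prefixes, returning the filtered state dict.
def remove_ignores(load_ignores, state_dict):
    return {
        key: value
        for key, value in state_dict.items()
        if not any(key.startswith(prefix) for prefix in load_ignores)
    }

filtered = remove_ignores(["frozen_lm"], {"frozen_lm.weight": 1, "head.bias": 2})
```

This is the same filtering that self.LOAD_SAVE_IGNORES applies, which is why a frozen pre-trained submodule can be excluded from both loading and saving.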
- load_checkpoint(checkpoint, verbose=True, strict=True, **kwargs)[source]#
Load a checkpoint from a file.
- Parameters:
checkpoint (str) – the path to the checkpoint file.
verbose (bool) – whether to print the message when loading the checkpoint.
strict (bool) – the strict argument passed to load_state_dict.