recwizard.BaseModule#

The BaseModule class adds a few features to the Hugging Face PreTrainedModel class, most notably partial load/save and state dict mapping.

class recwizard.module_utils.BaseModule(config: BaseConfig | None = None, **kwargs)[source]#

The base class for any module that is based on nn.Module. It supports partial load/save and state dict mapping.

__init__(config: BaseConfig | None = None, **kwargs)[source]#
Parameters:

config – config for PreTrainedModel

get_tokenizer()[source]#

Serves as a default tokenizer getter for the module. Will load the tokenizer from self.model_name_or_path if it has been set.

Warning

This function will be called when a module is passed to a BasePipeline instance without providing the corresponding tokenizer.

Tip

If your tokenizer cannot be initialized by the default implementation, we recommend initializing and passing the tokenizer manually, or overriding this function (it accepts no arguments).
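If you do override it, a minimal sketch might look like the following; the class name and tokenizer checkpoint are illustrative, not part of the library:

>>> # xdoctest: +SKIP("illustrative only")
>>> from transformers import AutoTokenizer
>>> from recwizard.module_utils import BaseModule
>>> class MyTextModule(BaseModule):
...     def get_tokenizer(self):
...         # Hypothetical override: always load a fixed tokenizer instead of
...         # relying on self.model_name_or_path.
...         return AutoTokenizer.from_pretrained("bert-base-uncased")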

prepare_weight(weight: array, name: str)[source]#

This function wraps a weight passed for initializing a module parameter and saves its shape and dtype in BaseConfig.WEIGHT_DIMENSIONS. When from_pretrained() is later called without the weight (i.e. when weight is None), the corresponding parameter is initialized from the saved shape and dtype before load_state_dict() copies the module’s weights.

Parameters:
  • weight – the weight that will be copied to a parameter of the module

  • name – a name to distinguish the weight in the module config

Returns:

a pytorch tensor with the same value as weight if weight is not None, otherwise a zero tensor with the saved shape and dtype in the module config.
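A minimal sketch of how prepare_weight might be used in a subclass’s __init__; the module name, argument, and embedding layer are illustrative assumptions:

>>> # xdoctest: +SKIP("illustrative only")
>>> import torch.nn as nn
>>> from recwizard.module_utils import BaseModule
>>> class MyEmbeddingModule(BaseModule):
...     def __init__(self, config=None, entity_embeddings=None, **kwargs):
...         super().__init__(config, **kwargs)
...         # Record the shape and dtype in the config so from_pretrained() can
...         # rebuild the parameter when entity_embeddings is None (the actual
...         # values are then filled in by load_state_dict()).
...         weight = self.prepare_weight(entity_embeddings, "entity_embeddings")
...         self.embedding = nn.Embedding.from_pretrained(weight)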

response(**kwargs)#

The main function for the module to generate a response given an input.

Note

Please refer to our tutorial for implementation guidance: Overview

Parameters:
  • raw_input (str) – the text input

  • tokenizer (PreTrainedTokenizer) – the tokenizer used to tokenize the input

  • return_dict (bool) – if set to True, will return a dict of outputs instead of a single output

  • **kwargs – the keyword arguments that will be passed to forward()

Returns:

By default, a single output will be returned. If return_dict is set to True, a dict of outputs will be returned.
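A usage sketch; the module instance and input text are illustrative:

>>> # xdoctest: +SKIP("undefined vars")
>>> outputs = module.response(
...     raw_input="I loved Inception. Any similar movies?",
...     tokenizer=module.get_tokenizer(),
...     return_dict=True,  # return a dict of outputs instead of a single output
... )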

forward(*inputs, **kwargs)[source]#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

state_dict(*args, destination=None, prefix='', keep_vars=False)[source]#

Note

Adapted from PyTorch’s original implementation to support self.LOAD_SAVE_IGNORES: when a module matches any pattern in self.LOAD_SAVE_IGNORES, it is omitted from the returned state dict.


Returns a dictionary containing references to the whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.

Note

The returned object is a shallow copy. It contains references to the module’s parameters and buffers.

Warning

Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.

Warning

Please avoid the use of argument destination as it is not designed for end-users.

Parameters:
  • destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.

  • prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.

  • keep_vars (bool, optional) – by default the Tensor s returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.

Returns:

a dictionary containing a whole state of the module

Return type:

dict

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']

load_state_dict(state_dict: OrderedDict[str, torch.TensorType], strict: bool = True, allow_unexpected=False, LOAD_PREFIX: str | None = '', LOAD_MAPPINGS: dict | None = None, LOAD_IGNORES: Iterable[str] | None = ())[source]#

Note

Adapted from PyTorch’s original implementation to support partial loading and state dict mapping.


Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.

Parameters:
  • allow_unexpected – if set to True, unexpected key(s) in state_dict will not raise an error

  • state_dict (dict) – a dict containing parameters and persistent buffers.

  • strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True

  • LOAD_PREFIX – When set, the state dict will be loaded from under the prefix LOAD_PREFIX

  • LOAD_MAPPINGS – When set, for each (key, val) pair in the dict, any parameter in state_dict whose name has key as a prefix is mapped to a name with val in place of key. (This parameter is useful when the submodules/parameters have different names from their counterparts in the state dict.)

  • LOAD_IGNORES – When set, the submodules in state_dict that match any pattern in LOAD_IGNORES will be ignored during loading. Note that self.LOAD_SAVE_IGNORES is merged into LOAD_IGNORES. (This option can be useful when you use a frozen pre-trained model inside your model, to avoid duplicate loading and saving.)

Returns:

  • missing_keys is a list of str containing the missing keys

  • unexpected_keys is a list of str containing the unexpected keys

Return type:

NamedTuple with missing_keys and unexpected_keys fields
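A sketch of partial loading with key remapping; the checkpoint path, prefix mapping, and ignore pattern are illustrative assumptions:

>>> # xdoctest: +SKIP("undefined vars")
>>> import torch
>>> state_dict = torch.load("checkpoint.pt", map_location="cpu")
>>> module.load_state_dict(
...     state_dict,
...     strict=False,
...     LOAD_MAPPINGS={"encoder.": "text_encoder."},  # "encoder.*" keys become "text_encoder.*"
...     LOAD_IGNORES=("frozen_bert",),  # skip state dict entries matching this pattern
... )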

static map_parameters(name_mappings, state_dict, remove_origin=True)[source]#

Map parameters in state_dict to another name according to name_mappings.

Parameters:
  • name_mappings (dict) – a dict that maps the original parameter name to the new name.

  • state_dict (dict) – the state dict to be mapped.

  • remove_origin (bool) – whether to remove the original parameter.

Returns:

the mapped state dict.

Return type:

dict
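For instance, assuming the names in name_mappings are matched as key prefixes (as with LOAD_MAPPINGS above), usage might look like:

>>> # xdoctest: +SKIP("undefined vars")
>>> sd = {"encoder.weight": w, "encoder.bias": b}
>>> mapped = BaseModule.map_parameters({"encoder": "text_encoder"}, sd)
>>> # mapped contains "text_encoder.weight" and "text_encoder.bias"; the original
>>> # "encoder.*" entries are removed because remove_origin defaults to True.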

static remove_ignores(load_ignores: Iterable, state_dict)[source]#

Remove parameters in state_dict according to load_ignores.

Parameters:
  • load_ignores (list) – a list of prefixes to be ignored.

  • state_dict (dict) – the state dict to operate on.

Returns:

the state dict after removing ignored parameters.

Return type:

dict
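An illustrative sketch; the key names are hypothetical:

>>> # xdoctest: +SKIP("undefined vars")
>>> sd = {"bert.embeddings.weight": w, "classifier.weight": v}
>>> filtered = BaseModule.remove_ignores(["bert"], sd)
>>> # filtered keeps "classifier.weight" only; entries under the ignored
>>> # prefix "bert" are removed.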

load_checkpoint(checkpoint, verbose=True, strict=True, **kwargs)[source]#

Load a checkpoint from a file.

Parameters:
  • checkpoint (str) – the path to the checkpoint file.

  • verbose (bool) – whether to print the message when loading the checkpoint.

  • strict (bool) – the strict argument passed to load_state_dict.
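Usage might look like the following; the checkpoint path is illustrative:

>>> # xdoctest: +SKIP("undefined vars")
>>> module.load_checkpoint("path/to/checkpoint.pt", verbose=True, strict=False)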

classmethod from_pretrained(pretrained_model_name_or_path: str | PathLike | None, *model_args, **kwargs)[source]#

Saves pretrained_model_name_or_path as a member variable of the model, so that when get_tokenizer() is called, the tokenizer can be initialized from model_name_or_path.
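For example (the class name and repository are illustrative):

>>> # xdoctest: +SKIP("undefined vars")
>>> module = MyTextModule.from_pretrained("recwizard/example-module")
>>> tokenizer = module.get_tokenizer()  # resolved from the saved model_name_or_path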