import copy
import logging
import os
from collections import OrderedDict
from typing import Iterable, List, Optional, Type, Union
import numpy as np
import torch
from torch.nn.modules.module import _IncompatibleKeys
from transformers import PreTrainedModel, PreTrainedTokenizer, AutoTokenizer
from .modules.monitor import monitor
from .configuration_utils import BaseConfig
class BaseModule(PreTrainedModel):
"""
    The base class for any module based on ``nn.Module``. Supports partial load/save and state-dict mapping.
"""
LOAD_SAVE_IGNORES = set()
"""
    The set of submodule names that will be skipped in both `state_dict()` and `load_state_dict()`.
    e.g. when a frozen `PreTrainedModel` is used as an encoder, we can set ``LOAD_SAVE_IGNORES={"encoder"}``
    to skip saving and loading the encoder.
    .. tip::
        We suggest using it when your module has a frozen part, in order to reduce the checkpoint size.
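    Example (a minimal sketch; ``MyModel`` and the submodule name are illustrative assumptions)::
        >>> # xdoctest: +SKIP("illustrative sketch")
        >>> class MyModel(BaseModule):
        ...     LOAD_SAVE_IGNORES = {"encoder"}  # assumed name of a frozen submodule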
"""
tokenizer_class: Type[PreTrainedTokenizer] = AutoTokenizer
"""
The tokenizer class that should be used for this module.
"""
config_class = BaseConfig
    def __init__(self, config: Optional[BaseConfig] = None, **kwargs):
"""
Args:
config: config for PreTrainedModel
"""
super().__init__(config, **kwargs)
self.model_name_or_path = None
    def get_tokenizer(self):
"""
Serves as a default tokenizer getter for the module.
Will load the tokenizer from `self.model_name_or_path` if it has been set.
.. warning::
This function will be called when a module is passed to a `BasePipeline` instance
without providing the corresponding tokenizer.
        .. tip::
            If your tokenizer cannot be initialized by the default implementation,
            we recommend initializing and passing the tokenizer manually;
            alternatively, you can override this function (**it accepts no arguments**).
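        Example (a minimal sketch; the checkpoint path is an illustrative assumption)::
            >>> # xdoctest: +SKIP("illustrative sketch")
            >>> model = MyModule.from_pretrained("path/to/checkpoint")
            >>> tokenizer = model.get_tokenizer()  # loads from model.model_name_or_path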
"""
        if self.model_name_or_path is not None:
            try:
                return self.tokenizer_class.from_pretrained(self.model_name_or_path)
            except Exception as e:
                raise EnvironmentError(
                    f'Please make sure "{self.model_name_or_path}" is a repo or a path '
                    f'where the tokenizer is saved') from e
        else:
            return None
    def prepare_weight(self, weight: Optional[np.ndarray], name: str):
"""
        This function wraps a weight passed for initializing a module parameter and saves its `shape` and `dtype`
        in `BaseConfig.WEIGHT_DIMENSIONS`.
So when :func:`~from_pretrained` is called without passing the weight (i.e. when `weight` is None),
the corresponding parameter is initialized from the saved `shape` and `dtype`
before :func:`~load_state_dict` is called to copy the module's weights.
Args:
weight: the weight that will be copied to a parameter of the module
name: a name to distinguish the weight in the module config
Returns:
            a PyTorch tensor with the same values as `weight` if `weight` is not None; otherwise a zero tensor with
            the `shape` and `dtype` saved in the module config.
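        Example (a minimal sketch; ``self.embedding`` is an illustrative parameter name)::
            >>> # xdoctest: +SKIP("illustrative sketch")
            >>> weight = np.ones((10, 4), dtype=np.float32)
            >>> self.embedding = torch.nn.Parameter(self.prepare_weight(weight, "embedding"))
            >>> # on a later ``from_pretrained`` call, ``prepare_weight(None, "embedding")``
            >>> # rebuilds a zero tensor of the saved shape and dtype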
"""
        if weight is None:
            # the dtype was stored via str(weight.dtype), e.g. "torch.float32";
            # resolve it back to a torch dtype attribute instead of calling eval()
            dtype = getattr(torch, self.config.WEIGHT_DIMENSIONS[name + ".dtype"].split(".")[-1])
            return torch.zeros(self.config.WEIGHT_DIMENSIONS[name + ".shape"], dtype=dtype)
else:
weight = torch.as_tensor(weight)
self.config.WEIGHT_DIMENSIONS[name + ".shape"] = list(weight.shape)
self.config.WEIGHT_DIMENSIONS[name + ".dtype"] = str(weight.dtype)
return weight
@monitor
def response(self, raw_input, tokenizer, return_dict=False, **kwargs):
r"""
The main function for the module to generate a response given an input.
.. note::
Please refer to our tutorial for implementation guidance: :doc:`/development/overview`
Args:
raw_input (str): the text input
tokenizer (PreTrainedTokenizer): the tokenizer used to tokenize the input
return_dict (bool): if set to True, will return a dict of outputs instead of a single output
**kwargs: the keyword arguments that will be passed to :func:`~forward`
Returns:
By default, a single output will be returned. If `return_dict` is set to True, a dict of outputs will be returned.
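        Example (a minimal sketch; the input text is an illustrative assumption)::
            >>> # xdoctest: +SKIP("illustrative sketch")
            >>> out = module.response("Hello!", tokenizer, return_dict=True)
            >>> out["output"]  # the decoded response text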
"""
tokenized_input = tokenizer(raw_input, return_tensors="pt")
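        # `forward` is implemented by the subclass; its return value is passed to
        # `batch_decode` below, so it is expected to be decodable token ids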
logits = self.forward(tokenized_input, **kwargs)
output = tokenizer.batch_decode(logits, skip_special_tokens=True)
if return_dict:
return {
"input": raw_input,
"tokenized_input": tokenized_input,
"logits": logits,
"output": output
}
else:
return output
    def forward(self, *inputs, **kwargs):
raise NotImplementedError
    def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
r"""
.. note::
            Adapted from PyTorch's original implementation to support :attr:`self.LOAD_SAVE_IGNORES`:
            when a submodule matches any pattern in :attr:`self.LOAD_SAVE_IGNORES`, it is omitted from the returned state dict.
--------------------------------
Returns a dictionary containing references to the whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are
included. Keys are corresponding parameter and buffer names.
Parameters and buffers set to ``None`` are not included.
.. note::
The returned object is a shallow copy. It contains references
to the module's parameters and buffers.
.. warning::
Currently ``state_dict()`` also accepts positional arguments for
``destination``, ``prefix`` and ``keep_vars`` in order. However,
this is being deprecated and keyword arguments will be enforced in
future releases.
.. warning::
Please avoid the use of argument ``destination`` as it is not
designed for end-users.
Args:
destination (dict, optional): If provided, the state of module will
be updated into the dict and the same object is returned.
Otherwise, an ``OrderedDict`` will be created and returned.
Default: ``None``.
prefix (str, optional): a prefix added to parameter and buffer
names to compose the keys in state_dict. Default: ``''``.
keep_vars (bool, optional): by default the :class:`~torch.Tensor` s
returned in the state dict are detached from autograd. If it's
set to ``True``, detaching will not be performed.
Default: ``False``.
Returns:
dict:
a dictionary containing a whole state of the module
Example::
>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
"""
if destination is None:
destination = OrderedDict()
destination._metadata = OrderedDict()
local_metadata = dict(version=self._version)
if hasattr(destination, "_metadata"):
destination._metadata[prefix[:-1]] = local_metadata
self._save_to_state_dict(destination, prefix, keep_vars)
for name, module in self._modules.items():
if name in self.LOAD_SAVE_IGNORES:
logging.debug(f"Ignored state dict for {prefix + name}")
continue
module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
# remove chained destination
destination = self.remove_ignores(self.LOAD_SAVE_IGNORES, destination)
for hook in self._state_dict_hooks.values():
hook_result = hook(self, destination, prefix, local_metadata)
if hook_result is not None:
destination = hook_result
return destination
    def load_state_dict(self,
                        state_dict: 'OrderedDict[str, torch.Tensor]',
                        strict: bool = True,
                        allow_unexpected: bool = False,
                        LOAD_PREFIX: Optional[str] = '',
                        LOAD_MAPPINGS: Optional[dict] = None,
                        LOAD_IGNORES: Iterable[str] = tuple(),
                        ):
r"""
.. note::
Adapted from pytorch's original implementation to support partial loading and state dict mapping.
--------------------------------
Copies parameters and buffers from :attr:`state_dict` into
this module and its descendants. If :attr:`strict` is ``True``, then
the keys of :attr:`state_dict` must exactly match the keys returned
by this module's :meth:`~torch.nn.Module.state_dict` function.
        Args:
            state_dict (dict): a dict containing parameters and
                persistent buffers.
            strict (bool, optional): whether to strictly enforce that the keys
                in :attr:`state_dict` match the keys returned by this module's
                :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
            allow_unexpected (bool, optional): if set to ``True``, unexpected keys in
                :attr:`state_dict` will not be reported as errors. Default: ``False``
            LOAD_PREFIX: when set, parameters are looked up in the state dict under the prefix
                ``LOAD_PREFIX`` (a non-empty prefix should end with ``'.'``).
            LOAD_MAPPINGS: when set, for each `key`, `val` pair in the dict, any parameter in the state dict
                whose name starts with `key` will be mapped to a name with `val` in place of `key` as its new prefix.
                (This parameter is useful when the submodules/parameters have different names from their counterparts
                in the state dict.)
            LOAD_IGNORES: when set, entries in the state dict that match any pattern in `LOAD_IGNORES`
                will be ignored during loading. Note that `self.LOAD_SAVE_IGNORES` is merged into `LOAD_IGNORES`.
                (This option can be useful when you use a frozen pre-trained model inside your model,
                to avoid duplicate loading and saving.)
Returns:
``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
* **missing_keys** is a list of str containing the missing keys
* **unexpected_keys** is a list of str containing the unexpected keys
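        Example (a minimal sketch; the prefixes ``"bert."`` and ``"decoder"`` are illustrative assumptions)::
            >>> # xdoctest: +SKIP("illustrative sketch")
            >>> # map checkpoint keys under "bert." onto the local submodule "encoder.",
            >>> # and skip the frozen "decoder" weights entirely
            >>> module.load_state_dict(state_dict, strict=False,
            ...                        LOAD_MAPPINGS={"bert.": "encoder."},
            ...                        LOAD_IGNORES=("decoder",))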
"""
LOAD_IGNORES = self.LOAD_SAVE_IGNORES | set(LOAD_IGNORES)
if LOAD_MAPPINGS is None:
LOAD_MAPPINGS = dict()
state_dict = self.map_parameters(LOAD_MAPPINGS, state_dict)
state_dict = self.remove_ignores(LOAD_IGNORES, state_dict)
missing_keys: List[str] = []
unexpected_keys: List[str] = []
error_msgs: List[str] = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
# mypy isn't aware that "_metadata" exists in state_dict
state_dict._metadata = metadata # type: ignore[attr-defined]
def load(module, prefix=''):
if prefix.strip('.') in LOAD_IGNORES: # skip chained LOAD_IGNORE. e.g. processor.encoder
return
if module.__dict__.get(
"LOAD_IGNORES") and "" in module.LOAD_IGNORES: # skip the whole module when specified
return
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if module.__dict__.get(
"LOAD_IGNORES") and name in module.LOAD_IGNORES: # skip submodules when specified
continue
if child is not None:
load(child, prefix + name + '.')
return
load(self, prefix=LOAD_PREFIX)
del load
if not allow_unexpected and len(unexpected_keys) > 0:
error_msgs.insert(
0, 'Unexpected key(s) in state_dict: {}. '.format(
', '.join('"{}"'.format(k) for k in unexpected_keys)))
if len(missing_keys) > 0:
error_msgs.insert(
0, 'Missing key(s) in state_dict: {}. '.format(
', '.join('"{}"'.format(k) for k in missing_keys)))
if len(error_msgs) > 0:
if strict:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
self.__class__.__name__, "\n\t".join(error_msgs)))
            else:  # non-strict mode: log the errors for debugging instead of raising
for msg in error_msgs:
logging.debug(msg)
return _IncompatibleKeys(missing_keys, unexpected_keys)
    @staticmethod
def map_parameters(name_mappings, state_dict, remove_origin=True):
"""
Map parameters in state_dict to another name according to ``name_mappings``.
Args:
            name_mappings (dict): a dict that maps an original parameter-name prefix to its new prefix.
            state_dict (dict): the state dict to be mapped.
            remove_origin (bool): whether to drop the original key after mapping. Default: ``True``.
Returns:
dict: the mapped state dict.
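        Example (a minimal sketch; the key names are illustrative assumptions)::
            >>> # xdoctest: +SKIP("illustrative sketch")
            >>> sd = {"bert.layer.weight": torch.ones(2)}
            >>> BaseModule.map_parameters({"bert.": "encoder."}, sd)
            {'encoder.layer.weight': tensor([1., 1.])}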
"""
new_dict = {}
for key, value in list(state_dict.items()):
has_mapping = False
for prefix, mapped_prefix in name_mappings.items():
if key.startswith(prefix) and prefix != mapped_prefix:
new_key = mapped_prefix + key[len(prefix):]
new_dict[new_key] = copy.deepcopy(value)
if not remove_origin:
new_dict[key] = value
has_mapping = True
break
if not has_mapping:
new_dict[key] = value
return new_dict
    @staticmethod
def remove_ignores(load_ignores: Iterable, state_dict):
"""
Remove parameters in state_dict according to ``load_ignores``.
Args:
load_ignores (list): a list of prefixes to be ignored.
state_dict (dict): the state dict to operate on.
Returns:
dict: the state dict after removing ignored parameters.
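        Example (a minimal sketch; the key names are illustrative assumptions; note the dict is modified in place)::
            >>> # xdoctest: +SKIP("illustrative sketch")
            >>> sd = {"encoder.weight": torch.ones(2), "head.weight": torch.ones(2)}
            >>> BaseModule.remove_ignores(["encoder"], sd)
            {'head.weight': tensor([1., 1.])}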
"""
used_prefix = set()
for key in list(state_dict.keys()):
for prefix in load_ignores:
if key.startswith(prefix):
used_prefix.add(prefix)
state_dict.pop(key)
break
for prefix in used_prefix:
logging.debug(f"Ignored state dict for {prefix}")
return state_dict
    def load_checkpoint(self, checkpoint, verbose=True, strict=True, **kwargs):
"""
Load a checkpoint from a file.
Args:
checkpoint (str): the path to the checkpoint file.
            verbose (bool): whether to log a message after the checkpoint is loaded.
strict (bool): the strict argument passed to `load_state_dict`.
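        Example (a minimal sketch; the checkpoint path is an illustrative assumption)::
            >>> # xdoctest: +SKIP("illustrative sketch")
            >>> module.load_checkpoint("checkpoints/model.pt", strict=False)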
"""
state_dict = torch.load(checkpoint, map_location=self.device)
if "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
self.load_state_dict(state_dict, strict, **kwargs)
if verbose:
logging.info(f"{self.__class__.__name__} loaded checkpoint '{checkpoint}'")
    @classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
*model_args,
**kwargs,
):
"""
        Saves `pretrained_model_name_or_path` as a member variable of the model, so that when
        `get_tokenizer` is called, the tokenizer can be initialized from `model_name_or_path`.
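        Example (a minimal sketch; the repo name is an illustrative assumption)::
            >>> # xdoctest: +SKIP("illustrative sketch")
            >>> model = MyModule.from_pretrained("bert-base-uncased")
            >>> model.model_name_or_path
            'bert-base-uncased'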
"""
model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
model.model_name_or_path = pretrained_model_name_or_path
return model
# def bind_weights(module1: torch.nn.Module, module2: torch.nn.Module):
# """
# Bind the weights of two PyTorch modules so they share the same memory space.
#
# Args:
# module1 (torch.nn.Module): the first module.
# module2 (torch.nn.Module): the second module.
# """
# for param1, param2 in zip(module1.parameters(), module2.parameters()):
# param2.data.set_(param1.data)
# param2.requires_grad = param1.requires_grad