recwizard.BasePipeline
The BasePipeline class adds a few features to the Hugging Face PreTrainedModel class.
- class recwizard.model_utils.BasePipeline(config: BaseConfig | None = None, rec_module: BaseModule | None = None, gen_module: BaseModule | None = None, rec_tokenizer: BaseTokenizer | None = None, gen_tokenizer: BaseModule | None = None, **kwargs)
The base model for recwizard.
- __init__(config: BaseConfig | None = None, rec_module: BaseModule | None = None, gen_module: BaseModule | None = None, rec_tokenizer: BaseTokenizer | None = None, gen_tokenizer: BaseModule | None = None, **kwargs)
Initialize the model with the recommender/generator modules and the corresponding tokenizers.
If a tokenizer is not passed, the model tries to obtain it by calling get_tokenizer on the corresponding module.
- Parameters:
config (BaseConfig) – The config object, as used in transformers.PreTrainedModel.
rec_module (BaseModule) – The recommender module.
rec_tokenizer (BaseTokenizer) – The tokenizer for the recommender module.
gen_module (BaseModule) – The generator module.
gen_tokenizer (BaseTokenizer) – The tokenizer for the generator module.
**kwargs – The other keyword arguments used for transformers.PreTrainedModel.
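The tokenizer fallback described above can be sketched in plain Python. This is a minimal illustration, not recwizard's implementation: StubModule, StubTokenizer, and SketchPipeline are hypothetical stand-ins, and only the get_tokenizer call is taken from the description.

```python
class StubTokenizer:
    """Hypothetical stand-in for a BaseTokenizer."""

    def __init__(self, name):
        self.name = name


class StubModule:
    """Hypothetical stand-in for a BaseModule that can supply its own tokenizer."""

    def __init__(self, name):
        self.name = name

    def get_tokenizer(self):
        return StubTokenizer(f"{self.name}-tokenizer")


class SketchPipeline:
    """Illustrative only: mirrors the documented fallback, not the real class."""

    def __init__(self, rec_module=None, gen_module=None,
                 rec_tokenizer=None, gen_tokenizer=None):
        self.rec_module = rec_module
        self.gen_module = gen_module
        # Use the tokenizer that was passed in; otherwise ask the module for one.
        self.rec_tokenizer = self._resolve(rec_tokenizer, rec_module)
        self.gen_tokenizer = self._resolve(gen_tokenizer, gen_module)

    @staticmethod
    def _resolve(tokenizer, module):
        if tokenizer is not None:
            return tokenizer
        if module is not None:
            return module.get_tokenizer()
        return None


pipeline = SketchPipeline(
    rec_module=StubModule("rec"),
    gen_module=StubModule("gen"),
    gen_tokenizer=StubTokenizer("custom"),
)
print(pipeline.rec_tokenizer.name)  # tokenizer obtained from the module
print(pipeline.gen_tokenizer.name)  # tokenizer passed explicitly
```

Passing a tokenizer explicitly is useful when the module's default tokenizer needs to be swapped out; otherwise the modules alone are enough.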
- response(**kwargs)
The main function for the model to generate a response given a query.
- Parameters:
query (str) – The formatted dialogue history. See the format at First Example.
return_dict (bool) – If set to True, returns a dict of outputs instead of a single output.
rec_args (dict) – The arguments passed to the recommender module.
gen_args (dict) – The arguments passed to the generator module.
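To show how these four parameters might fit together, here is a toy sketch. Everything in it is an assumption made for illustration: ToyPipeline, its helper methods, the topk and prefix options, and the keys of the returned dict are hypothetical, and the query is a plain placeholder rather than the real dialogue format. How a concrete pipeline combines its recommender and generator depends on the subclass.

```python
class ToyPipeline:
    """Hypothetical stand-in that mimics the documented response() parameters."""

    CATALOG = ["Inception", "Arrival", "Alien"]

    def _recommend(self, query, topk=1):
        # Stand-in for the recommender module: return the first topk items.
        return self.CATALOG[:topk]

    def _generate(self, query, items, prefix="You might like"):
        # Stand-in for the generator module: wrap the items in a sentence.
        return f"{prefix}: {', '.join(items)}"

    def response(self, query, return_dict=False, rec_args=None, gen_args=None):
        rec_args = rec_args or {}
        gen_args = gen_args or {}
        items = self._recommend(query, **rec_args)       # rec_args go to the recommender
        text = self._generate(query, items, **gen_args)  # gen_args go to the generator
        if return_dict:
            # Assumed keys, chosen for this sketch only.
            return {"recommendations": items, "output": text}
        return text


toy = ToyPipeline()
query = "placeholder dialogue history"  # not the real query format

single = toy.response(query, rec_args={"topk": 2})
detailed = toy.response(query, return_dict=True, gen_args={"prefix": "Try"})
print(single)
print(detailed)
```

With return_dict left at its default, the caller gets only the final output; setting it to True exposes the intermediate results as well.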
- forward(*inputs, **kwargs)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
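The difference between calling the instance and calling forward directly can be seen with a small PyTorch module. This is a generic torch.nn.Module sketch, assuming PyTorch is installed; Doubler is a hypothetical module and is unrelated to recwizard's own classes.

```python
import torch
from torch import nn


class Doubler(nn.Module):
    """Tiny module used only to demonstrate hook behaviour."""

    def forward(self, x):
        return x * 2


calls = []
model = Doubler()

# A forward hook receives (module, args, output); returning None keeps the output unchanged.
handle = model.register_forward_hook(
    lambda module, args, output: calls.append("hook ran")
)

x = torch.ones(2)
via_call = model(x)             # goes through __call__, so the hook runs
via_forward = model.forward(x)  # bypasses __call__, so the hook is skipped

handle.remove()
print(calls)
```

Both calls compute the same tensor, but only the first one triggers the registered hook, which is why the instance should be called rather than forward.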