SwitchDecode

class recwizard.pipelines.switch_decode.configuration_switch_decode.SwitchDecodeConfig(hidden_size=256, context_size=256, max_seq_length=40, rec_pattern: str = '<movie>', resp_prompt='System:', **kwargs)
__init__(hidden_size=256, context_size=256, max_seq_length=40, rec_pattern: str = '<movie>', resp_prompt='System:', **kwargs)
Parameters:
  • WEIGHT_DIMENSIONS (dict, optional) – The dimensions and dtypes of the module parameters, used to initialize any parameters that are not explicitly specified when the module is initialized. Defaults to None. See also recwizard.module_utils.BaseModule.prepare_weight().

  • **kwargs – Additional parameters, passed on to PretrainedConfig.__init__.
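The WEIGHT_DIMENSIONS fallback can be pictured with a minimal pure-Python sketch. It is not the BaseModule.prepare_weight() implementation: the helper name prepare_weights, the dict layout (parameter name mapped to a (shape, dtype) pair) and zero-filling of missing parameters are all assumptions made for illustration.

```python
from typing import Any, Dict, Optional, Tuple

# Hypothetical layout: parameter name -> (shape, dtype name).
WeightSpec = Dict[str, Tuple[Tuple[int, ...], str]]

# Hypothetical mapping from dtype names to Python scalar types.
_CASTS = {"float32": float, "int64": int}


def _filled(shape: Tuple[int, ...], value: Any) -> Any:
    """Build a nested list of the given shape filled with `value`."""
    if not shape:
        return value
    return [_filled(shape[1:], value) for _ in range(shape[0])]


def prepare_weights(weight_dimensions: Optional[WeightSpec],
                    provided: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: keep explicitly provided parameters and create
    zero-filled placeholders for the rest from their (shape, dtype) specs."""
    params = dict(provided)
    for name, (shape, dtype) in (weight_dimensions or {}).items():
        if name not in params:
            params[name] = _filled(shape, _CASTS[dtype](0))
    return params


spec = {"item_embeddings": ((3, 2), "float32"), "bias": ((2,), "float32")}
params = prepare_weights(spec, {"bias": [0.5, -0.5]})
print(params["item_embeddings"])  # [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
print(params["bias"])             # [0.5, -0.5]
```

Explicitly passed values always win; the spec is consulted only for names that are missing, which matches the "when they are not explicitly specified" wording above.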

class recwizard.pipelines.switch_decode.modeling_switch_decode.SwitchDecodePipeline(config, temperature=1, sample_movies=False, **kwargs)
__init__(config, temperature=1, sample_movies=False, **kwargs)

Initializes the pipeline with the recommender and generator modules and their tokenizers.

If the tokenizers are not passed, the pipeline tries to obtain them by calling get_tokenizer on the corresponding modules.

Parameters:
  • config (BaseConfig) – The configuration object, as used by transformers.PreTrainedModel.

  • rec_module (BaseModule) – The recommender module.

  • rec_tokenizer (BaseTokenizer) – The tokenizer for the recommender module.

  • gen_module (BaseModule) – The generator module.

  • gen_tokenizer (BaseTokenizer) – The tokenizer for the generator module.

  • **kwargs – Additional keyword arguments for transformers.PreTrainedModel.
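The tokenizer fallback described above amounts to "use what was passed, otherwise ask the module". The sketch below is hypothetical: StubModule and resolve_tokenizers are illustrative names, and the string returned by get_tokenizer stands in for a real BaseTokenizer. Only the get_tokenizer method name comes from the text above.

```python
from typing import Any, Optional, Tuple


class StubModule:
    """Hypothetical stand-in for a BaseModule that knows its tokenizer."""

    def __init__(self, tokenizer_name: str) -> None:
        self._tokenizer_name = tokenizer_name

    def get_tokenizer(self) -> str:
        return self._tokenizer_name


def resolve_tokenizers(rec_module: Any, gen_module: Any,
                       rec_tokenizer: Optional[Any] = None,
                       gen_tokenizer: Optional[Any] = None) -> Tuple[Any, Any]:
    """Prefer explicitly passed tokenizers; otherwise ask each module."""
    if rec_tokenizer is None:
        rec_tokenizer = rec_module.get_tokenizer()
    if gen_tokenizer is None:
        gen_tokenizer = gen_module.get_tokenizer()
    return rec_tokenizer, gen_tokenizer


rec, gen = StubModule("rec-tok"), StubModule("gen-tok")
print(resolve_tokenizers(rec, gen))                          # ('rec-tok', 'gen-tok')
print(resolve_tokenizers(rec, gen, gen_tokenizer="custom"))  # ('rec-tok', 'custom')
```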

forward(rec_inputs, gen_inputs, forbid_movies)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
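The note above is about hooks. The following is a simplified pure-Python model of that behavior, not PyTorch's actual implementation: MiniModule and Doubler are hypothetical, and register_forward_hook only mirrors the name of the torch.nn.Module method. Calling the instance runs the registered hooks, while calling forward() directly bypasses them.

```python
from typing import Any, Callable, List

Hook = Callable[["MiniModule", Any, Any], None]


class MiniModule:
    """Simplified model: __call__ runs forward() and then the registered
    forward hooks; calling forward() directly skips the hooks."""

    def __init__(self) -> None:
        self._forward_hooks: List[Hook] = []

    def register_forward_hook(self, hook: Hook) -> None:
        self._forward_hooks.append(hook)

    def forward(self, x: Any) -> Any:
        raise NotImplementedError  # subclasses override this

    def __call__(self, x: Any) -> Any:
        out = self.forward(x)
        for hook in self._forward_hooks:
            hook(self, x, out)
        return out


class Doubler(MiniModule):
    def forward(self, x: int) -> int:
        return 2 * x


seen: List[int] = []
m = Doubler()
m.register_forward_hook(lambda mod, inp, out: seen.append(out))
m(3)          # hook runs, so seen becomes [6]
m.forward(4)  # hook is skipped, so seen stays [6]
print(seen)   # [6]
```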

replace_movie_with_words(tokens)

For each ID in tokens that corresponds to a movie, substitutes the sequence of tokens that make up that movie's name.

Parameters:
  • tokens – The sequence of token IDs to process.

Returns: the modified sequence.
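A minimal sketch of this substitution, assuming a vocabulary layout where word IDs come first and each ID from VOCAB_SIZE upward stands for one movie. VOCAB_SIZE, MOVIE_NAME_TOKENS and this layout are hypothetical; the real pipeline's ID scheme is not described here.

```python
from typing import Dict, List

# Hypothetical layout: IDs below VOCAB_SIZE are ordinary word tokens;
# ID VOCAB_SIZE + k stands for movie k.
VOCAB_SIZE = 5
MOVIE_NAME_TOKENS: Dict[int, List[int]] = {
    0: [1, 2],  # movie 0 is spelled with word tokens 1 and 2
    1: [3],     # movie 1 is spelled with word token 3
}


def replace_movie_with_words(tokens: List[int]) -> List[int]:
    """Expand every movie ID in `tokens` into the word tokens of its name;
    ordinary word tokens are copied unchanged."""
    out: List[int] = []
    for tok in tokens:
        if tok >= VOCAB_SIZE:
            out.extend(MOVIE_NAME_TOKENS[tok - VOCAB_SIZE])
        else:
            out.append(tok)
    return out


print(replace_movie_with_words([4, 5, 0, 6]))  # [4, 1, 2, 0, 3]
```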

response(query: str, beam_size=10, forbid_movies=None, temperature=1, **kwargs)

The main function for the model to generate a response given a query.

Parameters:
  • query (str) – The formatted dialogue history. See First Example for the format.

  • return_dict (bool) – If set to True, returns a dict of outputs instead of a single output.

  • rec_args (dict) – The arguments passed to the recommender module.

  • gen_args (dict) – The arguments passed to the generator module.
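The temperature and forbid_movies arguments suggest how a single decoding step might be scored. The sketch below assumes (this page does not state it) that a "switch" probability mixes a word distribution from the generator with a movie distribution from the recommender, that temperature divides the logits before the softmax, and that forbidden movies receive zero probability. switch_distribution and switch_prob are hypothetical names, and beam search (beam_size) is not shown.

```python
import math
from typing import List, Optional, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Temperature-scaled softmax; a lower temperature sharpens the result."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def switch_distribution(word_logits: Sequence[float],
                        movie_logits: Sequence[float],
                        switch_prob: float,
                        temperature: float = 1.0,
                        forbid_movies: Optional[Sequence[int]] = None) -> List[float]:
    """Hypothetical switching step: emit a movie with probability
    `switch_prob`, otherwise a word. Returns one distribution over all
    words followed by all movies (assumes not every movie is forbidden)."""
    movie_logits = list(movie_logits)
    for idx in forbid_movies or []:
        movie_logits[idx] = float("-inf")  # exp(-inf) == 0, so zero mass
    p_words = softmax(word_logits, temperature)
    p_movies = softmax(movie_logits, temperature)
    return ([(1 - switch_prob) * p for p in p_words]
            + [switch_prob * p for p in p_movies])


dist = switch_distribution([1.0, 1.0], [2.0, 0.0, 2.0],
                           switch_prob=0.5, forbid_movies=[2])
print(dist[:2])  # [0.25, 0.25]
print(dist[4])   # 0.0
```

Keeping the two softmaxes separate means the switch probability alone decides how much mass goes to movies versus words, and forbidding a movie redistributes its share only among the remaining movies.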