3. NEW Generator#

A NEW recommender can be used on its own as follows:

3.1. Raw text input and output#

System: Hello!<sep>
User: Hi. I like horror movies, such as <entity>The Shining (1980)</entity> and <entity>Annabelle (2014)</entity>.
Would you please recommend me some other movies?

3.2. Tensor input and output#

# TODO

3.3. Implementation of NEW generator#

3.3.1. Generator Configuration: NEWGenConfig

from recwizard.configuration_utils import BaseConfig

class NEWGenConfig(BaseConfig):
    """Configuration class to store the
    configuration of the NEW Generator.

    Args:
        base_model: name of the pretrained GPT-2 style checkpoint the
            generator is built on.
        n_items: number of distinct entities (movies) the generator can
            embed; must match the tokenizer's ``id2entity`` table.
            ``None`` means "not set yet".
        max_gen_len: maximum number of new tokens to generate.
    """

    def __init__(
        self, base_model: str = "microsoft/DialoGPT-small", n_items: int = None, max_gen_len: int = 100, **kwargs
    ):
        # Forward any extra keyword arguments to the base configuration.
        super().__init__(**kwargs)

        self.base_model = base_model
        self.n_items = n_items
        self.max_gen_len = max_gen_len

# use it!
# n_items=8 matches the 8-entry id2entity table used by the tokenizer below.
config = NEWGenConfig(base_model='microsoft/DialoGPT-small', n_items=8)

3.3.2. Generator Tokenizer: NEWGenTokenizer

from recwizard.tokenizer_utils import BaseTokenizer


class NEWGenTokenizer(BaseTokenizer):
    """Tokenizer for the NEW generator.

    A thin wrapper over ``BaseTokenizer`` that always produces padded,
    truncated PyTorch tensors, regardless of what the caller passes.
    """

    def __call__(self, *args, **kwargs):
        """Tokenize input text into batched token-id tensors."""
        # Force tensor output on every call; these overrides win over
        # any caller-supplied values for the same keys.
        forced_options = dict(return_tensors="pt", padding=True, truncation=True)
        kwargs.update(forced_options)
        return super().__call__(*args, **kwargs)

# use it!
# NOTE(review): GPT2Tokenizer is assumed to come from `transformers`
# (its import is not shown in this snippet) — confirm.
word_tokenizer = GPT2Tokenizer.from_pretrained("microsoft/DialoGPT-small")
# Reuse the EOS token as the padding token so batched padding works.
word_tokenizer.pad_token = word_tokenizer.eos_token
tokenizer = NEWGenTokenizer(
    tokenizers=[word_tokenizer],
    # Maps entity ids 0..7 to movie-title strings; the table size (8)
    # must match the n_items passed to NEWGenConfig.
    id2entity={
        0: "21 Bridges (2019)",
        1: "The Shining (1980)",
        2: "Annabelle (2014)",
        3: "The Conjuring (2013)",
        4: "The Exorcist (1973)",
        5: "The Conjuring 2 (2016)",
        6: "The Nun (2018)",
        7: "X men (2019)",
    },
)

3.3.3. Generator Module: NEWGen

import torch

from recwizard.module_utils import BaseModule
from recwizard.utility.utils import WrapSingleInput
from transformers import GPT2LMHeadModel


class NEWGen(BaseModule):
    """NEW is a module that implements the NEW generator.

    It prepends the mean embedding of the entities mentioned in the
    dialogue to the GPT-2 word embeddings of the context, then lets
    GPT-2 generate a response conditioned on both.
    """

    config_class = NEWGenConfig
    tokenizer_class = NEWGenTokenizer

    def __init__(self, config: NEWGenConfig, **kwargs):
        """Build the generator from its configuration.

        Args:
            config: holds the GPT-2 checkpoint name (``base_model``),
                the number of entities (``n_items``), and the generation
                length cap (``max_gen_len``).
        """
        super().__init__(config, **kwargs)

        # Pretrained GPT-2 language model used for text generation.
        self.gpt2_model = GPT2LMHeadModel.from_pretrained(config.base_model)
        # One learnable vector per entity, sized to GPT-2's hidden width
        # so it can be concatenated with the word embeddings.
        self.entity_embeds = torch.nn.Embedding(
            config.n_items, self.gpt2_model.config.n_embd
        )
        self.max_gen_len = config.max_gen_len

    def generate(self, context, entities, attention_mask, **kwargs):
        """Generate token ids conditioned on the context and entities.

        Args:
            context: dict with ``input_ids`` and ``attention_mask``
                tensors for the tokenized dialogue context.
            entities: entity-id tensor; padded slots carry the pad
                entity id (assumed shape (batch, n_entities) — TODO confirm).
            attention_mask: boolean mask over ``entities``, True for
                real (non-pad) entities.

        Returns:
            The underlying GPT-2 ``generate`` output (a structured
            result, since ``return_dict_in_generate=True``).
        """

        embeds = self.entity_embeds(entities)
        # BUGFIX: zero out pad-entity embeddings before summing. The
        # original summed over *all* entity slots while the denominator
        # counted only the valid ones, so pad embeddings leaked into
        # the "average" entity vector.
        masked_embeds = embeds * attention_mask.unsqueeze(-1).to(embeds.dtype)
        # Mean over valid entities; 1e-8 guards against a zero count.
        avg_embeds = masked_embeds.sum(dim=1, keepdim=True) / (
            attention_mask.sum(dim=-1, keepdim=True) + 1e-8
        )
        text_embeds = self.gpt2_model.transformer.wte(context["input_ids"])
        # Prepend the entity summary vector as one extra "token".
        inputs_embeds = torch.cat([avg_embeds, text_embeds], dim=1)
        # Extend the attention mask with a column of ones for that token.
        attention_mask = torch.cat(
            [
                torch.ones(*avg_embeds.shape[:2]).to(avg_embeds.device),
                context["attention_mask"],
            ],
            dim=1,
        )

        return self.gpt2_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            max_new_tokens=self.max_gen_len,
            return_dict_in_generate=True,
            **kwargs,
        )

    @WrapSingleInput
    def response(self, raw_input, tokenizer, return_dict=False):
        """Generate a text response for a raw dialogue string.

        Args:
            raw_input: dialogue context as raw text.
            tokenizer: a ``NEWGenTokenizer`` producing ``input_ids``,
                ``attention_mask`` and ``entities`` tensors.
            return_dict: if True, also return the input and the raw
                generation output.

        Returns:
            The decoded response strings, or a dict with keys
            ``output``, ``input`` and ``generated`` when ``return_dict``.
        """

        inputs = tokenizer(raw_input)
        context = {
            "input_ids": inputs["input_ids"].to(self.device),
            "attention_mask": inputs["attention_mask"].to(self.device),
        }

        # Entity ids extracted by the tokenizer; mask out padding slots.
        entities = inputs["entities"].to(self.device)
        attention_mask = entities != tokenizer.pad_entity_id

        # generate response
        generated = self.generate(
            context=context, entities=entities, attention_mask=attention_mask
        )
        output = tokenizer.batch_decode(generated.sequences)

        # return the output
        if return_dict:
            return {"output": output, "input": raw_input, "generated": generated}
        return output

# use it!
model = NEWGen(config)

# Multi-turn dialogue context: turns are separated by "<sep>" and
# mentioned entities are wrapped in <entity>...</entity> tags so the
# tokenizer can extract them.
query = (
    "System: Hello!"
    "<sep>User: Hi. I like horror movies, such as <entity>The Shining (1980)</entity> and <entity>Annabelle (2014)</entity>."
    "Would you please recommend me some other movies?"
)

resp = model.response(raw_input=query, tokenizer=tokenizer, return_dict=True)