Source code for recwizard.modules.kbrd.modeling_kbrd_rec

from typing import Union, List
from transformers.utils import ModelOutput

import torch

from recwizard.module_utils import BaseModule
from recwizard.utility import WrapSingleInput
from .configuration_kbrd_rec import KBRDRecConfig
from .tokenizer_kbrd_rec import KBRDRecTokenizer
from .entity_attention_encoder import KBRD

from recwizard.modules.monitor import monitor


[docs]class KBRDRec(BaseModule): """KBRDRec is a module that implements the KBRD recommender.""" config_class = KBRDRecConfig tokenizer_class = KBRDRecTokenizer LOAD_SAVE_IGNORES = {"encoder.text_encoder", "decoder"} def __init__( self, config: KBRDRecConfig, edge_index=None, edge_type=None, item_index=None, **kwargs, ): """Initialize the KBRDRec module. Args: config (KBRDRecConfig): The configuration of the KBRDRec module. edge_index (torch.LongTensor): The edges (node pairs) of the knowledge graph in KBRD, shape: (2, num_edges). edge_type (torch.LongTensor): The edge type of the knowledge graph in KBRD, shape: (num_edges,). item_index (torch.LongTensor): The recommendation item indices used in KBRD model, shape: (num_items,). """ super().__init__(config, **kwargs) # prepare weights edge_index = self.prepare_weight(edge_index, "edge_index") edge_type = self.prepare_weight(edge_type, "edge_type") item_index = self.prepare_weight(item_index, "item_index") # include the item_index into state_dict self.item_index = torch.nn.Parameter(item_index, requires_grad=False) # then load KBRD recommender self.model = KBRD( n_entity=config.n_entity, n_relation=config.n_relation, sub_n_relation=config.sub_n_relation, dim=config.dim, edge_idx=edge_index, edge_type=edge_type, num_bases=config.num_bases, ) # set the criterion self.criterion = torch.nn.CrossEntropyLoss() def forward( self, input_ids: torch.LongTensor, attention_mask: torch.BoolTensor, labels: torch.LongTensor = None, ): """Forward pass of the KBRDRec module. Args: input_ids (torch.LongTensor): The input ids of the input conversation contexts, shape: (batch_size, seq_len). attention_mask (torch.BoolTensor): The attention mask of the input, shape: (batch_size, seq_len). labels (torch.LongTensor): The labels of the converastions, optional. Returns: ModelOutput: The output of the model, containing the logits and the loss. """ scores = self.model(input_ids, attention_mask) loss = None if labels is None else self.criterion(scores, labels) return ModelOutput({"rec_logits": scores, "rec_loss": loss}) @WrapSingleInput @monitor def response( self, raw_input: Union[List[str], str], tokenizer: KBRDRecTokenizer, return_dict=False, topk=3, ): """Generate the textual response or results dict given the input_ids. Args: raw_input (Union[List[str], str]): The input conversation contexts. tokenizer (KBRDRecTokenizer): The tokenizer of the model. return_dict (bool): Whether to return the output as a dictionary. topk (int): The number of movies to recommend. Returns: Union[List[str], dict]: The dictionary of the model output with `logits`, `movieIds` and textual response if `return_dict` is `True`, else the textual model response only. """ entities = tokenizer(raw_input)["entities"].to(self.device) inputs = { "input_ids": entities, "attention_mask": entities != tokenizer.pad_entity_id, } logits = self.forward(**inputs)["rec_logits"] # mask all non-movie indices offset = 0 * logits.clone() + float("-inf") offset[:, self.item_index] = 0 logits += offset movieIds = logits.topk(k=topk, dim=1).indices.tolist() output = tokenizer.batch_decode(movieIds) if return_dict: return { "logits": logits, "movieIds": movieIds, "output": output, } return output