Source code for recwizard.modules.unicrs.tokenizer_unicrs_rec

from typing import Dict, List, Callable, Union
from transformers import AutoTokenizer, BatchEncoding

from recwizard.utility.utils import WrapSingleInput, loadJsonFileFromDataset
from recwizard.tokenizer_utils import BaseTokenizer

gpt2_special_tokens_dict = {
    'pad_token': '<pad>',
    'additional_special_tokens': ['<movie>'],
}

prompt_special_tokens_dict = {
    'additional_special_tokens': ['<movie>'],
}


[docs]class UnicrsRecTokenizer(BaseTokenizer):
[docs] def __init__( self, context_tokenizer: str = "microsoft/DialoGPT-small", prompt_tokenizer: str = "roberta-base", context_max_length: int = 200, prompt_max_length: int = 200, pad_entity_id: int = 31161, entity2id: Dict[str, int] = None, **kwargs, ): context_tokenizer = AutoTokenizer.from_pretrained(context_tokenizer, truncation_side='left') prompt_tokenizer = AutoTokenizer.from_pretrained(prompt_tokenizer, truncation_side='left') context_tokenizer.add_special_tokens(gpt2_special_tokens_dict) prompt_tokenizer.add_special_tokens(prompt_special_tokens_dict) tokenizers = [ context_tokenizer, prompt_tokenizer ] super().__init__(tokenizers=tokenizers, entity2id=entity2id, pad_entity_id=pad_entity_id, **kwargs) self.context_max_length = context_max_length self.prompt_max_length = prompt_max_length
[docs] def __call__(self, *args, **kwargs): kwargs.update(return_tensors='pt', padding=True, truncation=True) return super().__call__(*args, **kwargs)
[docs] @classmethod def load_from_dataset(cls, dataset='redial_unicrs', **kwargs): entityName2id = loadJsonFileFromDataset(dataset, 'entityName2id.json') entity2id = loadJsonFileFromDataset(dataset, 'entity2id.json') return cls(entity2id=entityName2id, pad_entity_id=max(entity2id.values()) + 1, **kwargs)
[docs] @staticmethod def mergeEncoding(encodings: List[BatchEncoding]) -> BatchEncoding: res = encodings[0].copy() if len(encodings) == 0: return res res.data = { 'context': encodings[0], 'prompt': encodings[1] } return res
[docs] def encodes(self, encode_funcs: List[Callable], texts: List[Union[str, List[str]]], *args, **kwargs) -> List[ BatchEncoding]: kwargs1 = kwargs.copy() kwargs1.update( max_length=self.context_max_length ) kwargs2 = kwargs.copy() kwargs2.update( max_length=self.prompt_max_length ) return [ encode_funcs[0](texts[0], *args, **kwargs1), encode_funcs[1](texts[1], *args, **kwargs2), ]
@WrapSingleInput def decode( self, ids, *args, **kwargs, ) -> List[str]: return [self.id2entity[id] for id in ids]