Source code for recwizard.tokenizer_utils

import json
import os
import re
from typing import Optional, Tuple, Dict, List, Union, Callable

import torch
from transformers import PreTrainedTokenizer, AutoTokenizer
from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase
from huggingface_hub import hf_hub_download

from recwizard.utility import SEP_TOKEN, BOS_TOKEN, EOS_TOKEN, ENTITY_PATTERN, loadJsonFileFromDataset, pad_and_stack

Entity = int


class BaseTokenizer(PreTrainedTokenizer):
    init_kwargs_file = 'tokenizer_kwargs.json'
    pattern_sep_token = re.compile(SEP_TOKEN)
    pattern_bos_token = re.compile(BOS_TOKEN)
    pattern_eos_token = re.compile(EOS_TOKEN)
    def __init__(self,
                 entity2id: Dict[str, int] = None,
                 id2entity: Dict[int, str] = None,
                 pad_entity_id: int = None,
                 tokenizers: Union[List[PreTrainedTokenizerBase], PreTrainedTokenizerBase] = None,
                 **kwargs):
        """
        Args:
            entity2id (Dict[str, int]): a dict mapping entity name to entity id.
                If not provided, it will be generated from id2entity.
            id2entity (Dict[int, str]): a dict mapping entity id to entity name.
                If not provided, it will be generated from entity2id.
            pad_entity_id (int): the id of the padding entity.
                If not provided, it will be the maximum entity id + 1.
            tokenizers (List[PreTrainedTokenizerBase]): a list of tokenizers to be used.
            **kwargs: other arguments for PreTrainedTokenizer
        """
        super().__init__(**kwargs)
        if entity2id is None and id2entity is not None:
            entity2id = {v: k for k, v in id2entity.items()}
        if id2entity is None and entity2id is not None:
            id2entity = {v: k for k, v in entity2id.items()}
        self.entity2id = entity2id
        self.id2entity = id2entity
        self.pad_entity_id = None
        if self.entity2id is not None:
            self.pad_entity_id = pad_entity_id if pad_entity_id is not None else max(self.entity2id.values()) + 1
        if tokenizers is None:
            self.tokenizers = None
        else:
            self.tokenizers = tokenizers if isinstance(tokenizers, list) else [tokenizers]
        self.entity_pattern = re.compile(ENTITY_PATTERN)
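    # Usage sketch (illustrative, not part of the original source): constructing a tokenizer
    # directly from an entity vocabulary plus a Hugging Face tokenizer. The entity names and
    # checkpoint below are hypothetical placeholders; real keys come from entity2id.json.
    #
    #   entity2id = {"Inception": 0, "Titanic": 1}
    #   tokenizer = BaseTokenizer(
    #       entity2id=entity2id,
    #       tokenizers=AutoTokenizer.from_pretrained("bert-base-uncased"),
    #   )
    #   # pad_entity_id defaults to max(entity2id.values()) + 1, i.e. 2 here.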
    @classmethod
    def load_from_dataset(cls, dataset='redial_unicrs', **kwargs):
        """
        Initialize the tokenizer from the dataset. By default, it loads the entity2id mapping from the dataset.

        Args:
            dataset: the dataset name
            **kwargs: the other arguments for initialization

        Returns:
            (BaseTokenizer): the initialized tokenizer
        """
        entity2id = loadJsonFileFromDataset(dataset, 'entity2id.json')
        return cls(entity2id=entity2id, **kwargs)
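    # Usage sketch (illustrative): load the entity vocabulary shipped with a dataset and pair it
    # with an off-the-shelf tokenizer. The checkpoint name is a hypothetical example.
    #
    #   tokenizer = BaseTokenizer.load_from_dataset(
    #       dataset="redial_unicrs",
    #       tokenizers=AutoTokenizer.from_pretrained("bert-base-uncased"),
    #   )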
    @property
    def unk_token(self) -> str:
        """
        Override this property if you want to change the unk_token.
        """
        return "<unk>"
    @property
    def unk_token_id(self) -> Optional[int]:
        """
        Override this property if you want to change the unk_token_id.
        """
        return -1
    @property
    def vocab_size(self) -> int:
        try:
            return len(self.entity2id)
        except TypeError:
            return len(self.tokenizers[0])

    def _convert_token_to_id(self, token: str) -> int:
        return self.entity2id[token] if token in self.entity2id else self.unk_token_id

    def _convert_id_to_token(self, index: int) -> str:
        return self.id2entity[index] if index in self.id2entity else self.unk_token
    @staticmethod
    def mergeEncoding(encodings: List[BatchEncoding]) -> BatchEncoding:
        """
        Merge a list of encodings into one encoding.
        Assumes each encoding has the same attributes other than data.
        """
        if len(encodings) == 0:
            return BatchEncoding({})
        res = encodings[0].copy()
        for encoding in encodings[1:]:
            res.data.update(encoding.data)
        return res
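    # Illustration (hypothetical values): merging is a dict update, so if two encodings share a
    # key (e.g. "input_ids"), the later tokenizer's output overwrites the earlier one.
    #
    #   merged = BaseTokenizer.mergeEncoding([BatchEncoding({"input_ids": [[1, 2]]}),
    #                                         BatchEncoding({"entities": [[7]]})])
    #   # merged.data == {"input_ids": [[1, 2]], "entities": [[7]]}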
    def replace_special_tokens(self, text: str) -> List[str]:
        """
        Replace the cls token, sep token and eos token for each tokenizer.

        Args:
            text: the text to be processed

        Returns:
            (List[str]): a list of texts, one for each tokenizer
        """

        def replace_for_tokenizer(text: str, tokenizer: PreTrainedTokenizerBase):
            text = self.replace_bos_token(text, tokenizer)
            text = self.replace_sep_token(text, tokenizer)
            text = self.replace_eos_token(text, tokenizer)
            return text

        return [replace_for_tokenizer(text, tokenizer) for tokenizer in self.tokenizers]
    def replace_sep_token(self, text, tokenizer: PreTrainedTokenizerBase):
        token = str(tokenizer._sep_token or tokenizer.eos_token)
        text = self.pattern_sep_token.sub(token, text)
        return text

    def replace_bos_token(self, text, tokenizer: PreTrainedTokenizerBase):
        token = str(tokenizer._cls_token or '')
        text = self.pattern_bos_token.sub(token, text)
        return text

    def replace_eos_token(self, text, tokenizer: PreTrainedTokenizerBase):
        token = str(tokenizer._eos_token or ' ')
        text = self.pattern_eos_token.sub(token, text)
        return text
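    # Illustration (follows from the fallbacks above; tokenizer names are examples): with a BERT
    # tokenizer the placeholders map to its own special tokens, roughly SEP_TOKEN -> "[SEP]",
    # BOS_TOKEN -> "[CLS]", EOS_TOKEN -> " " (BERT defines no eos token); with a GPT-2 tokenizer,
    # SEP_TOKEN and EOS_TOKEN both fall back to "<|endoftext|>" and BOS_TOKEN becomes "".
    #
    #   texts = tokenizer.replace_special_tokens(f"{BOS_TOKEN}Hi!{SEP_TOKEN}Hello!{EOS_TOKEN}")
    #   # -> one rewritten string per tokenizer in tokenizer.tokenizers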
    def encodes(self, encode_funcs: List[Callable], texts: List[Union[str, List[str]]],
                *args, **kwargs) -> List[BatchEncoding]:
        """
        Apply the encoding functions from the different tokenizers. It is used by both
        `encode_plus` and `batch_encode_plus`. Override this method if you want to call
        different tokenizers with different arguments.

        Args:
            encode_funcs: the encoding functions from `self.tokenizers`
            texts: the processed text for each encoding function
            **kwargs: other arguments passed to each encoding function

        Returns:
            a list of BatchEncoding, with the same length as the number of tokenizers
        """
        assert len(encode_funcs) == len(texts)
        return [func(text, *args, **kwargs) for text, func in zip(texts, encode_funcs)]
    def preprocess(self, text: str) -> str:
        """
        Override this function to preprocess the text. It is used by both `encode_plus` and `batch_encode_plus`.

        Args:
            text: the text to be preprocessed

        Returns:
            the processed text
        """
        return text
    def batch_encode_plus(
            self, batch_text_or_text_pairs: List[str], *args, **kwargs
    ) -> BatchEncoding:
        """
        Overrides the `batch_encode_plus` function from `PreTrainedTokenizer` to support entity processing.
        """
        # preprocess
        batch_text = map(self.preprocess, batch_text_or_text_pairs)
        if self.entity2id:
            # process entities
            processed_results = [self.process_entities(text) for text in batch_text]
            processed_text, batch_entities = map(list, zip(*processed_results))
            if kwargs.get('padding') == True:
                batch_entities = pad_and_stack(
                    [torch.tensor(entities, dtype=torch.long) for entities in batch_entities],
                    pad_value=self.pad_entity_id
                )
                if kwargs.get('return_tensors') is None:
                    batch_entities = batch_entities.tolist()
            if self.tokenizers is None:
                return BatchEncoding({
                    'raw_text': processed_text,
                    'entities': batch_entities
                })
        else:
            processed_text = batch_text
        # replace special tokens for each tokenizer
        texts = [list(text) for text in zip(*[self.replace_special_tokens(s) for s in processed_text])]
        # call the batch_encode_plus of each tokenizer on its corresponding list of texts
        encodings = self.encodes([tokenizer.batch_encode_plus for tokenizer in self.tokenizers], texts,
                                 *args, **kwargs)
        encodings = self.mergeEncoding(encodings)
        # add entities to the encodings
        if self.entity2id:
            encodings.data['entities'] = batch_entities
        return encodings
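    # Usage sketch (hypothetical inputs; the <entity>...</entity> tag syntax shown here is
    # illustrative -- the actual format is whatever ENTITY_PATTERN in recwizard.utility matches):
    #
    #   batch = tokenizer.batch_encode_plus(
    #       ["I loved <entity>Inception</entity>", "Any comedy suggestions?"],
    #       padding=True, return_tensors="pt",
    #   )
    #   # batch["input_ids"] etc. come from the wrapped tokenizer(s);
    #   # batch["entities"] holds the entity ids, padded with pad_entity_id.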
    def encode_plus(
            self, text: str, *args, **kwargs
    ) -> BatchEncoding:
        """
        Overrides the `encode_plus` function from `PreTrainedTokenizer` to support entity processing.
        """
        text = self.preprocess(text)
        processed_text, entities = self.process_entities(text)
        if self.tokenizers is None:
            return BatchEncoding({
                'raw_text': processed_text,
                'entities': entities
            })
        texts = self.replace_special_tokens(processed_text)
        encodings = self.encodes([tokenizer.encode_plus for tokenizer in self.tokenizers], texts, *args, **kwargs)
        encodings = self.mergeEncoding(encodings)
        if kwargs.get('return_tensors') == 'pt':
            entities = torch.tensor(entities, dtype=torch.long)
        if self.entity2id:
            encodings.data['entities'] = entities
        return encodings
    def process_entities(self, text: str) -> Tuple[str, List[Entity]]:
        """
        Process the entities in the text: extract the entity ids and remove the entity tags.
        """
        entities = []
        for m in reversed(list(self.entity_pattern.finditer(text))):  # iterate in reverse so earlier spans stay valid
            start, end = m.span()
            entity_name = m.group(1).strip(" ")
            text = text[:start] + entity_name + text[end:]
            if self.entity2id is not None and entity_name in self.entity2id:
                entities.append(self.entity2id[entity_name])
        entities = entities[::-1]
        return text, entities
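    # Illustration (hypothetical entity vocabulary; the tag syntax is whatever ENTITY_PATTERN in
    # recwizard.utility matches, shown here as <entity>...</entity>):
    #
    #   text, entities = tokenizer.process_entities("I loved <entity>Inception</entity>!")
    #   # text     -> "I loved Inception!"
    #   # entities -> [entity2id["Inception"]] if "Inception" is in entity2id, else []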
    def decode(self, *args, **kwargs) -> str:
        """
        Overrides the `decode` function from `PreTrainedTokenizer`. By default, it calls the
        `decode` function of the first tokenizer.
        """
        return self.tokenizers[0].decode(*args, **kwargs)
    def get_init_kwargs(self):
        """
        The kwargs for initialization. Override this function to declare the necessary initialization
        kwargs (they will be saved when the tokenizer is saved or pushed to the Hugging Face model hub).

        See also: :meth:`~save_vocabulary`
        """
        return {
            'entity2id': self.entity2id,
            'pad_entity_id': self.pad_entity_id
        }
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        This method is overridden to save the initialization kwargs to the model directory.
        """
        filename_prefix = filename_prefix + "-" if filename_prefix else ""
        vocab_path = os.path.join(save_directory, filename_prefix + self.init_kwargs_file)
        with open(vocab_path, 'w') as f:
            json.dump(self.get_init_kwargs(), f)
        return (str(vocab_path),)
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """
        Load the tokenizer from a pretrained model or a local directory. It loads the initialization
        kwargs from the 'tokenizer_kwargs.json' file before initializing the tokenizer.
        """
        try:
            path = hf_hub_download(pretrained_model_name_or_path, cls.init_kwargs_file)
        except Exception:
            path = os.path.join(pretrained_model_name_or_path, cls.init_kwargs_file)
        with open(path, 'r') as f:
            init_kwargs = json.load(f)
        kwargs.update(init_kwargs)
        return cls(*args, **kwargs)
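# Round-trip sketch (illustrative; the directory name is arbitrary): save_pretrained() writes
# 'tokenizer_kwargs.json' via save_vocabulary(), and from_pretrained() reads it back, whether the
# files live on the Hugging Face Hub or in a local directory.
#
#   tokenizer.save_pretrained("./my_tokenizer")
#   reloaded = BaseTokenizer.from_pretrained("./my_tokenizer")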