Source code for recwizard.modules.redial.tokenizer_rnn

from typing import List

from nltk import load, NLTKWordTokenizer
from tokenizers import Tokenizer, NormalizedString, PreTokenizedString
from tokenizers.models import WordLevel
from tokenizers.normalizers import Lowercase
from tokenizers.pre_tokenizers import PreTokenizer
from transformers import PreTrainedTokenizerFast


class NLTKTokenizer:
    """Custom pre-tokenizer that splits text into sentences and then words with NLTK."""

    def __init__(self, language="english"):
        # Requires the NLTK "punkt" model: nltk.download('punkt')
        self.tokenizer = load(f"tokenizers/punkt/{language}.pickle")
        self.word_tokenizer = NLTKWordTokenizer()

    def word_tokenize(self, i: int, normalized_string: NormalizedString) -> List[NormalizedString]:
        return [normalized_string[s:t] for s, t in self.word_tokenizer.span_tokenize(str(normalized_string))]

    def nltk_split(self, i: int, normalized_string: NormalizedString) -> List[NormalizedString]:
        # Split into sentences first, then into words within each sentence.
        sentences = [normalized_string[s:t] for s, t in self.tokenizer.span_tokenize(str(normalized_string))]
        tokenized = []
        for sentence in sentences:
            tokenized += self.word_tokenize(i, sentence)
        return tokenized

    def pre_tokenize(self, pretok: PreTokenizedString):
        # Call `split` on the PreTokenizedString to split using `self.nltk_split`
        pretok.split(self.nltk_split)
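

# Illustrative sketch (not part of the original module): NLTKTokenizer follows the
# custom pre-tokenizer protocol of the `tokenizers` library, so it can be exercised on
# its own by wrapping it with PreTokenizer.custom. This assumes the NLTK "punkt" model
# has been downloaded; the sample sentence is hypothetical.
#
#     pretok = PreTokenizer.custom(NLTKTokenizer())
#     print(pretok.pre_tokenize_str("I eat apples. I eat rice."))
#     # expected: a list of (word, (start, end)) pairs from sentence + word splitting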


# Module-level cache of tokenizers, keyed by name.
tokenizers = {}


def get_tokenizer(name="redial"):
    return tokenizers[name]


def RnnTokenizer(vocab, name="redial"):
    """
    Return a tokenizer for RNN models from the given vocabulary.

    Args:
        vocab (List[str]): list of words
        name (str): name of the tokenizer, used to cache the tokenizer

    Returns:
        PreTrainedTokenizerFast
    """
    if tokenizers.get(name):
        return tokenizers[name]
    word2id = {word: i for i, word in enumerate(vocab)}
    tokenizer = Tokenizer(WordLevel(unk_token='<unk>', vocab=word2id))
    tokenizer.normalizer = Lowercase()
    wrapped_tokenizer = PreTrainedTokenizerFast(
        tokenizer_object=tokenizer,
        unk_token="<unk>",
        pad_token="<pad>",
        bos_token="<s>",
        eos_token="</s>",
    )
    # Attach the NLTK-based pre-tokenizer to the underlying (Rust) tokenizer.
    wrapped_tokenizer.backend_tokenizer.pre_tokenizer = PreTokenizer.custom(NLTKTokenizer())
    tokenizers[name] = wrapped_tokenizer
    return wrapped_tokenizer
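

# Usage sketch (illustrative, not in the original module): tokenizers built by
# RnnTokenizer are cached by name, so later lookups return the same instance. The
# vocabulary and name below are hypothetical.
#
#     demo_tok = RnnTokenizer(["hello", "world", "<pad>", "<s>", "</s>", "<unk>"], name="demo")
#     assert get_tokenizer("demo") is demo_tok
#     assert RnnTokenizer([], name="demo") is demo_tok  # cache hit; vocab is ignored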


if __name__ == "__main__":
    tokenizer = RnnTokenizer(["i", "eat", "apple", "rice"] + ['<pad>', '<s>', '</s>', '<unk>', '\n'])
    encoding = tokenizer.encode(
        ["I eat apple today.", "I eat rice today."],
        padding=True,
        truncation=True,
        return_token_type_ids=False,
        return_tensors='pt',
    )
    print(encoding)
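
    # Decoding sketch (added illustration, not in the original script): assuming the
    # encoding above is a 1 x seq_len tensor of token ids, map the ids back to tokens
    # to inspect the lowercasing, NLTK pre-tokenization, and <unk> handling.
    print(tokenizer.convert_ids_to_tokens(encoding[0].tolist()))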