"""Source code for recwizard.modules.llm.tokenizer_llama."""
from typing import Optional
from typing_extensions import override
from transformers import AutoTokenizer
from recwizard.tokenizer_utils import BaseTokenizer
from recwizard.utility import SEP_TOKEN
class LlamaTokenizer(BaseTokenizer):
    """Tokenizer for the generator based on Meta's Llama-2 chat models.

    Wraps the Hugging Face ``meta-llama/Llama-2-7b-chat-hf`` tokenizer and
    converts a ``SEP_TOKEN``-delimited dialog into the Llama-2 chat prompt
    format (``<s>[INST] ... [/INST]`` for user turns, ``...</s>`` for
    system turns).
    """

    def __init__(self, **kwargs):
        """Initialize the tokenizer.

        Args:
            **kwargs: Extra keyword arguments (currently unused; accepted
                for interface compatibility with other tokenizers).
        """
        super().__init__(tokenizers=[
            AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf"),
        ])
        # Llama-2 ships without a dedicated pad token; reuse EOS so that
        # batched padding works.
        self.tokenizers[0].pad_token = self.tokenizers[0].eos_token
        self.tokenizers[0].pad_token_id = self.tokenizers[0].eos_token_id

    @property
    def eos_token(self) -> str:
        """End-of-sequence token string of the wrapped tokenizer."""
        return self.tokenizers[0].eos_token

    @property
    def eos_token_id(self) -> int:
        """End-of-sequence token id of the wrapped tokenizer."""
        return self.tokenizers[0].eos_token_id

    @override
    def preprocess(self, raw_input: str, **kwargs) -> str:
        """Convert a ``SEP_TOKEN``-joined dialog into a Llama-2 chat prompt.

        Args:
            raw_input (str): Dialog turns joined by ``SEP_TOKEN``; each turn
                is expected to begin with ``"User:"`` or ``"System:"``.

        Returns:
            str: The concatenated prompt in Llama-2 chat format.
        """
        texts = raw_input.split(SEP_TOKEN)
        user = 'User:'
        system = 'System:'
        context = ''
        for text in texts:
            if text.startswith(system):
                # System (assistant) turn: drop the role prefix and close
                # the segment with the EOS marker.
                context += text[len(system):].strip(' ') + '</s>'
            else:
                # Any other turn is treated as a user turn.
                # NOTE(review): the first len('User:') characters are sliced
                # off unconditionally, so a turn lacking the "User:" prefix
                # silently loses its first 5 characters — confirm upstream
                # always supplies the prefix.
                context += '<s>' + '[INST]' + text[len(user):].strip(' ') + '[/INST]'
        return context