import json
import pickle as pkl
import numpy as np
import torch
from nltk import word_tokenize
import re
from typing import Dict, List, Callable, Union
from transformers import AutoTokenizer, BatchEncoding
from recwizard.utility.utils import WrapSingleInput, loadJsonFileFromDataset
from recwizard.tokenizer_utils import BaseTokenizer
from .utils import seed_everything
seed_everything(42)
class KGSFRecTokenizer(BaseTokenizer):
    def __init__(
        self,
        max_count: int = 5,
        max_c_length: int = 256,
        max_r_length: int = 30,
        n_entity: int = 64368,
        batch_size: int = 1,
        padding_idx: int = 0,
        entity2entityId: Dict[str, int] = None,
        word2index: Dict[str, int] = None,
        key2index: Dict[str, int] = None,
        entity_ids: List = None,
        id2name: Dict[str, str] = None,
        id2entity: Dict[int, str] = None,
        entity2id: Dict[str, int] = None,
        **kwargs,
    ):
        # Fall back to the bundled KGSF data files when the mappings are not passed in explicitly.
        if entity2entityId is None:
            self.entity2entityId = pkl.load(open('recwizard/modules/kgsf/data/entity2entityId.pkl', 'rb'))
        else:
            self.entity2entityId = entity2entityId
        self.entity_max = len(self.entity2entityId)
        if word2index is None:
            self.word2index = json.load(open('recwizard/modules/kgsf/data/word2index_redial.json', encoding='utf-8'))
        else:
            self.word2index = word2index
        if key2index is None:
            self.key2index = json.load(open('recwizard/modules/kgsf/data/key2index_3rd.json', encoding='utf-8'))
        else:
            self.key2index = key2index
        if entity_ids is None:
            self.entity_ids = pkl.load(open('recwizard/modules/kgsf/data/movie_ids.pkl', 'rb'))
        else:
            self.entity_ids = entity_ids
        if id2entity is None:
            self.id2entity = pkl.load(open('recwizard/modules/kgsf/data/id2entity.pkl', 'rb'))
        else:
            self.id2entity = id2entity
        self.id2entity = {int(k): str(v) for k, v in self.id2entity.items()}
        if id2name is None:
            self.id2name = json.load(open('recwizard/modules/kgsf/data/id2name.jsonl', encoding='utf-8'))
        else:
            self.id2name = id2name
        if entity2id is None:
            self.entity2id = {v: k for k, v in self.id2entity.items()}
        else:
            self.entity2id = entity2id
        self.entityId2entity = {v: k for k, v in self.entity2entityId.items()}

        # From config:
        self.max_count = max_count
        self.max_c_length = max_c_length
        self.max_r_length = max_r_length
        self.n_entity = n_entity
        self.batch_size = batch_size
        self.pad_entity_id = padding_idx
        self.names2id = {v: k for k, v in self.id2name.items()}
        super().__init__(entity2id=self.entity2id, pad_entity_id=self.pad_entity_id, **kwargs)
    def get_init_kwargs(self):
        """
        The kwargs used for initialization. They are saved along with the tokenizer when you save it
        or push it to the Hugging Face model hub, so the tokenizer can be rebuilt without the local
        data files.
        """
        return {
            'entity2entityId': self.entity2entityId,
            'word2index': self.word2index,
            'key2index': self.key2index,
            'entity_ids': self.entity_ids,
            'id2entity': self.id2entity,
            'id2name': self.id2name,
        }
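    # A minimal round-trip sketch, assuming BaseTokenizer follows the Hugging Face
    # save_pretrained / from_pretrained convention (an assumption, not verified here).
    # The dict above is what gets serialized, so a restored tokenizer does not need the
    # local pickle/json data files:
    #
    #   tokenizer = KGSFRecTokenizer()
    #   tokenizer.save_pretrained('kgsf_rec_tokenizer')
    #   restored = KGSFRecTokenizer.from_pretrained('kgsf_rec_tokenizer')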
    def padding_w2v(self, sentence, max_length, pad=0, end=2, unk=3):
        """
        Convert a token list into fixed-length index vectors.

        sentence: e.g. ['Okay', ',', 'have', 'you', 'seen', '@136983', '?']
        max_length: e.g. 30 for responses / 256 for contexts

        Returns (vector, length, concept_mask, dbpedia_mask), each padded or truncated to max_length.
        """
        vector = []
        concept_mask = []
        dbpedia_mask = []
        for word in sentence:
            vector.append(self.word2index.get(word, unk))
            concept_mask.append(self.key2index.get(word.lower(), 0))
            if '@' in word:
                # Tokens like '@136983' are movie mentions; map them to their DBpedia entity id.
                try:
                    entity = self.id2entity[int(word[1:])]
                    entity_id = self.entity2entityId[entity]
                except (KeyError, ValueError):
                    entity_id = self.entity_max
                dbpedia_mask.append(entity_id)
            else:
                dbpedia_mask.append(self.entity_max)
        vector.append(end)
        concept_mask.append(0)
        dbpedia_mask.append(self.entity_max)

        if len(vector) > max_length:
            return vector[:max_length], max_length, concept_mask[:max_length], dbpedia_mask[:max_length]
        else:
            length = len(vector)
            padding = max_length - length
            return vector + padding * [pad], length, \
                concept_mask + padding * [0], dbpedia_mask + padding * [self.entity_max]
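    # A minimal illustration (the indices are hypothetical, they depend on the loaded vocabularies):
    #
    #   vec, length, concept_mask, dbpedia_mask = tokenizer.padding_w2v(['have', 'you', 'seen', '@136983', '?'], 30)
    #   # vec          -> 5 word indices plus the end token (2), padded with 0 up to length 30
    #   # concept_mask -> a concept-key index per token, 0 where there is no match
    #   # dbpedia_mask -> the entity id at the '@136983' position, entity_max elsewhere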
    def padding_context(self, contexts, pad=0):
        """
        Flatten a list of dialog turns into a single token sequence and build its masks.

        contexts: e.g. [['Hello'], ['hi', 'how', 'are', 'u'], ..., ['something', 'like', 'yes', 'have', 'watched', '@140066', '?']]
        """
        contexts_com = []
        # Join the most recent max_count turns with the '_split_' separator; the final turn
        # is appended last.
        for sen in contexts[-self.max_count:-1]:
            contexts_com.extend(sen)
            contexts_com.append('_split_')
        contexts_com.extend(contexts[-1])
        # Only the masks are used downstream; the padded word vector itself is discarded here.
        vec, v_l, concept_mask, dbpedia_mask = self.padding_w2v(contexts_com, self.max_c_length)
        return concept_mask, dbpedia_mask
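    # A minimal illustration, using turns shaped like `test_conv` in the __main__ block below:
    #
    #   concept_mask, dbpedia_mask = tokenizer.padding_context(test_conv)
    #   # Both masks have length max_c_length (256); positions holding '@<id>' movie mentions
    #   # carry the corresponding entity id in dbpedia_mask, other positions hold entity_max.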
    def _names_to_id(self, input_name):
        """Look up a movie name (optionally with a '(YYYY)' year suffix) and return its ID, or None."""
        processed_input_name = input_name.strip().lower()
        processed_input_name = re.sub(r'\(\d{4}\)', '', processed_input_name)
        for name, movie_id in self.names2id.items():
            processed_name = name.strip().lower()
            if processed_input_name in processed_name:
                return movie_id
        return None
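    # A minimal illustration (the returned value depends on the loaded id2name mapping):
    #
    #   tokenizer._names_to_id('Despicable Me 3 (2017)')
    #   # -> the matching ID string if some known title contains the (year-stripped) query, otherwise None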
    def detect_movie(self, sentence):
        """Tokenize a raw sentence, replacing <entity>...</entity> movie mentions with '@<id>' tokens."""
        # Match <entity>...</entity> spans as single tokens; otherwise split into words and punctuation.
        pattern = r"<entity>.*?</entity>|\w+|[.,!?;]"
        tokens = re.findall(pattern, sentence)
        # Replace movie names with their '@<id>' form and collect the corresponding entity ids.
        movie_rec_trans = []
        for i, token in enumerate(tokens):
            if token.startswith('<entity>') and token.endswith('</entity>'):
                movie_name = token[len('<entity>'):-len('</entity>')]
                movie_id = self._names_to_id(movie_name)
                if movie_id is not None:
                    tokens[i] = f"@{movie_id}"
                    try:
                        entity = self.id2entity[int(movie_id)]
                        entity_id = self.entity2entityId[entity]
                        movie_rec_trans.append(entity_id)
                    except (KeyError, ValueError):
                        pass
        return tokens, movie_rec_trans
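    # A minimal illustration (ids are resolved through id2name / id2entity at runtime, so they are
    # only indicated schematically here):
    #
    #   tokens, seed_entities = tokenizer.detect_movie('have you seen <entity>Despicable Me 3 (2017)</entity> ?')
    #   # tokens        -> ['have', 'you', 'seen', '@<id>', '?'] when the title is found
    #   # seed_entities -> the matching entity2entityId values, later used as seed entities in encode()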
    def encode(self, user_input, user_context=None, entity=None, system_response=None, movie=0):
        """
        user_input: e.g. "Hi, can you recommend a movie for me?"
        user_context: e.g. [['Hello'], ['hi', 'how', 'are', 'u']]  TODO: should the '_split_' separator be handled here?
        entity: entity ids for movies mentioned in user_context, default []
        system_response: e.g. ['Great', '.', 'How', 'are', 'you', 'this', 'morning', '?']
        movie: the movie mentioned in system_response, given as an ID (default 0 means no movie).
            TODO: with multiple movies the sample would have to be repeated; how should the tokenizer handle that?
        """
        if user_context is None:
            # Single raw utterance: detect movie mentions and use them as the seed entities.
            user_context, entity = self.detect_movie(user_input)
            user_context = [user_context]
            entity = entity[::-1]
        entity = torch.tensor(np.array(entity)).unsqueeze(0)
        concept_mask, dbpedia_mask = self.padding_context(user_context)
        # Multi-hot vector over all entities for the DBpedia-side input.
        db_vec = np.zeros(self.n_entity)
        for db in dbpedia_mask:
            if db != 0:
                db_vec[db] = 1
        concept_mask = torch.tensor(np.array(concept_mask)).unsqueeze(0)
        db_vec = torch.tensor(np.array(db_vec)).unsqueeze(0)
        if system_response is not None:
            response, r_length, _, _ = self.padding_w2v(system_response, self.max_r_length)
            response = torch.tensor(np.array(response)).unsqueeze(0)
            return {
                'response': response,
                'concept_mask': concept_mask,
                'seed_sets': entity,
                'labels': movie,
                'db_vec': db_vec,
                'rec': int(movie != 0),
            }
        return {
            'response': None,
            'concept_mask': concept_mask,
            'seed_sets': entity,
            'labels': None,
            'db_vec': db_vec,
            'rec': 0,
        }
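    # A minimal usage sketch at inference time (shapes follow the configured lengths):
    #
    #   batch = tokenizer.encode('Hi, can you recommend a movie for me?')
    #   # batch['concept_mask'] -> integer tensor of shape (1, max_c_length)
    #   # batch['db_vec']       -> tensor of shape (1, n_entity)
    #   # batch['seed_sets']    -> tensor of entity ids detected in the input; 'response' and 'labels' are None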
    def decode(self, outputs, top_k=3, labels=None):
        """Map recommender scores over all entities to the top_k movie ids and names."""
        # Restrict the scores to movie entities: shape goes from (1, 64368) to (1, len(self.entity_ids)), i.e. 6924.
        outputs = outputs[:, torch.LongTensor(self.entity_ids)]
        _, pred_idx = torch.topk(outputs, k=100, dim=1)
        movieIds = []
        movieNames = []
        for i in range(top_k):
            top_pred = pred_idx[0][i]
            entity_id = int(self.entity_ids[top_pred])
            entity = self.entityId2entity[entity_id]
            movie_id = self.entity2id[entity]
            movie_name = self.id2name[str(movie_id)]
            movieIds.append(movie_id)
            movieNames.append(movie_name)
        return movieIds, movieNames
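    # A minimal usage sketch with dummy scores (real scores come from the KGSF recommender module):
    #
    #   logits = torch.randn(1, tokenizer.n_entity)
    #   movie_ids, movie_names = tokenizer.decode(logits, top_k=3)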
if __name__ == "__main__":
    tokenizer = KGSFRecTokenizer()
    # test_sentence = 'She also voiced a character in @131869 and @161259'
    # test_sentence = test_sentence.split(' ')
    test_sentence = 'She also voiced a character in <entity>Despicable Me 3 (2017)</entity> and <entity>How to Train Your Dragon </entity> '
    print(tokenizer.encode(test_sentence))
    # test_sentence_user = ['Hello', '_split_', 'hi', 'how', 'are', 'u', '_split_', 'Great', '.', 'How', 'are', 'you', 'this', 'morning', '?', '_split_', 'would', 'u', 'have', 'any', 'recommendations', 'for', 'me', 'im', 'good', 'thanks', 'fo', 'asking']
    # vector, ml, concept_mask, db_mask = tokenizer.padding_w2v(test_sentence_user, 30)
    # print(vector, ml, concept_mask, db_mask)
    test_conv = [['Hello'], ['hi', 'how', 'are', 'u'], ['Great', '.', 'How', 'are', 'you', 'this', 'morning', '?'], ['would', 'u', 'have', 'any', 'recommendations', 'for', 'me', 'im', 'good', 'thanks', 'fo', 'asking'], ['What', 'type', 'of', 'movie', 'are', 'you', 'looking', 'for', '?'], ['comedies', 'i', 'like', 'kristin', 'wigg'], ['Okay', ',', 'have', 'you', 'seen', '@136983', '?'], ['something', 'like', 'yes', 'have', 'watched', '@140066', '?']]
    # concept_mask, dbpedia_mask = tokenizer.padding_context(test_conv)
    # print(concept_mask, dbpedia_mask)
    result = tokenizer.encode(
        test_sentence,
        user_context=test_conv,
        entity=[28207, 5771],
        system_response=['She', 'also', 'voiced', 'a', 'character', 'in', '@131869', 'and', '@161259'],
        movie=22727,
    )
    print(result)
    print('passed')