from typing import Union, List
from transformers.utils import ModelOutput
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn.conv.rgcn_conv import RGCNConv
from torch_geometric.nn.conv.gcn_conv import GCNConv
import json
import pickle as pkl
import numpy as np
from recwizard import BaseModule
from .configuration_kgsf_gen import KGSFGenConfig
from .tokenizer_kgsf_gen import KGSFGenTokenizer
from .utils import _create_embeddings, _create_entity_embeddings, _edge_list, seed_everything
from .graph_utils import SelfAttentionLayer, SelfAttentionLayer_batch
from .transformer_utils import TransformerEncoder, TransformerDecoderKG
from recwizard.modules.monitor import monitor
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# KGSF generation module: a Transformer encoder-decoder whose decoder attends over a
# DBpedia entity graph (encoded with an R-GCN) and a ConceptNet concept graph (encoded
# with a GCN), and whose copy network biases generation toward concept and movie tokens.
class KGSFGen(BaseModule):
config_class = KGSFGenConfig
tokenizer_class = KGSFGenTokenizer
def __init__(self, config, **kwargs):
super().__init__(config,**kwargs)
self.dictionary = config.dictionary
self.batch_size = config.batch_size
self.max_r_length = config.max_r_length
self.NULL_IDX = config.padding_idx
self.END_IDX = config.end_idx
self.longest_label = config.longest_label
self.register_buffer('START', torch.LongTensor([config.start_idx]))
self.pad_idx = config.padding_idx
self.embeddings = _create_embeddings(
self.dictionary, config.embedding_size, self.pad_idx, config.embedding_data
)
self.concept_embeddings = _create_entity_embeddings(
config.n_concept + 1, config.dim, 0)
self.concept_padding = 0
self.kg = {int(k): v for k, v in config.subkg.items()}
self.n_positions = config.n_positions
self.encoder = TransformerEncoder(
n_heads=config.n_heads,
n_layers=config.n_layers,
embedding_size=config.embedding_size,
ffn_size=config.ffn_size,
vocabulary_size=len(self.dictionary)+4,
embedding=self.embeddings,
dropout=config.dropout,
attention_dropout=config.attention_dropout,
relu_dropout=config.relu_dropout,
padding_idx=self.pad_idx,
learn_positional_embeddings=config.learn_positional_embeddings,
embeddings_scale=config.embeddings_scale,
reduction=False,
n_positions=self.n_positions,
)
self.decoder = TransformerDecoderKG(
n_heads=config.n_heads,
n_layers=config.n_layers,
embedding_size=config.embedding_size,
ffn_size=config.embedding_size,
vocabulary_size=len(self.dictionary)+4,
embedding=self.embeddings,
dropout=config.dropout,
attention_dropout=config.attention_dropout,
relu_dropout=config.relu_dropout,
padding_idx=self.pad_idx,
learn_positional_embeddings=config.learn_positional_embeddings,
embeddings_scale=config.embeddings_scale,
n_positions=int(self.n_positions),
)
self.db_norm = nn.Linear(config.dim, config.embedding_size)
self.kg_norm = nn.Linear(config.dim, config.embedding_size)
self.db_attn_norm = nn.Linear(config.dim, config.embedding_size)
self.kg_attn_norm = nn.Linear(config.dim, config.embedding_size)
self.criterion = nn.CrossEntropyLoss(reduction='none')
self.self_attn = SelfAttentionLayer_batch(config.dim, config.dim)
self.self_attn_db = SelfAttentionLayer(config.dim, config.dim)
self.user_norm = nn.Linear(config.dim * 2, config.dim)
self.gate_norm = nn.Linear(config.dim, 1)
self.copy_norm = nn.Linear(config.embedding_size * 2 + config.embedding_size, config.embedding_size)
self.representation_bias = nn.Linear(config.embedding_size, len(self.dictionary) + 4)
self.output_en = nn.Linear(config.dim, config.n_entity)
self.embedding_size = config.embedding_size
self.dim = config.dim
edge_list, self.n_relation = _edge_list(self.kg, config.n_entity, hop=2)
edge_list = list(set(edge_list))
print(len(edge_list), self.n_relation)
self.dbpedia_edge_sets = torch.LongTensor(edge_list).to(device)
self.db_edge_idx = self.dbpedia_edge_sets[:, :2].t()
self.db_edge_type = self.dbpedia_edge_sets[:, 2]
self.dbpedia_RGCN = RGCNConv(config.n_entity, self.dim, self.n_relation, num_bases=config.num_bases)
self.concept_edge_sets = torch.LongTensor(config.edge_set).to(device)
self.concept_GCN = GCNConv(self.dim, self.dim)
self.mask4key = torch.Tensor(np.array(config.mask4key)).to(device)
self.mask4movie = torch.Tensor(np.array(config.mask4movie)).to(device)
self.mask4 = self.mask4key + self.mask4movie  # vocabulary mask used by the copy network
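# The graph encoders, entity/concept embeddings, and fusion layers below are frozen
# (requires_grad = False): in KGSF they are expected to be learned during the
# recommendation stage, so only the Transformer generation path is trained here.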
params = [self.dbpedia_RGCN.parameters(), self.concept_GCN.parameters(),
self.concept_embeddings.parameters(),
self.self_attn.parameters(), self.self_attn_db.parameters(), self.user_norm.parameters(),
self.gate_norm.parameters(), self.output_en.parameters()]
for param in params:
for pa in param:
pa.requires_grad = False
self.pretrain = config.pretrain
def _starts(self, bsz):
"""Return bsz start tokens."""
return self.START.detach().expand(bsz, 1)
def decode_greedy(self, encoder_states, encoder_states_kg, encoder_states_db, attention_kg, attention_db, bsz, maxlen):
"""
Greedy search
:param int bsz:
Batch size. Because encoder_states is model-specific, it cannot
infer this automatically.
:param encoder_states:
Output of the encoder model.
:type encoder_states:
Model specific
:param int maxlen:
Maximum decoding length
:return:
pair (logits, choices) of the greedy decode
:rtype:
(FloatTensor[bsz, maxlen, vocab], LongTensor[bsz, maxlen])
"""
xs = self._starts(bsz)
logits = []
for i in range(maxlen):
scores, _ = self.decoder(xs, encoder_states, encoder_states_kg, encoder_states_db)
scores = scores[:, -1:, :]
kg_attn_norm = self.kg_attn_norm(attention_kg)
db_attn_norm = self.db_attn_norm(attention_db)
copy_latent = self.copy_norm(torch.cat([kg_attn_norm.unsqueeze(1), db_attn_norm.unsqueeze(1), scores], -1))
con_logits = self.representation_bias(copy_latent) * self.mask4.unsqueeze(0).unsqueeze(0)
voc_logits = F.linear(scores, self.embeddings.weight)
sum_logits = voc_logits + con_logits
_, preds = sum_logits.max(dim=-1)
logits.append(sum_logits)
xs = torch.cat([xs, preds], dim=1)
# check if everyone has generated an end token
all_finished = ((xs == self.END_IDX).sum(dim=1) > 0).sum().item() == bsz
if all_finished:
break
logits = torch.cat(logits, 1)
return logits, xs
def decode_forced(self, encoder_states, encoder_states_kg, encoder_states_db, attention_kg, attention_db, ys):
"""
Decode with a fixed, true target sequence, computing the loss. Useful for
training or for ranking fixed candidates.
:param ys:
the prediction targets. Contains both the start and end tokens.
:type ys:
LongTensor[bsz, time]
:param encoder_states:
Output of the encoder. Model specific types.
:type encoder_states:
model specific
:return:
pair (logits, choices) containing the logits and MLE predictions
:rtype:
(FloatTensor[bsz, time, vocab], LongTensor[bsz, time])
"""
bsz = ys.size(0)
seqlen = ys.size(1)
inputs = ys.narrow(1, 0, seqlen - 1)
inputs = torch.cat([self._starts(bsz), inputs], 1)
latent, _ = self.decoder(inputs, encoder_states, encoder_states_kg, encoder_states_db)  # (bsz, seqlen, hidden)
kg_attention_latent = self.kg_attn_norm(attention_kg)
db_attention_latent = self.db_attn_norm(attention_db)
copy_latent = self.copy_norm(torch.cat([kg_attention_latent.unsqueeze(1).repeat(1, seqlen, 1), db_attention_latent.unsqueeze(1).repeat(1, seqlen, 1), latent], -1))
con_logits = self.representation_bias(copy_latent) * self.mask4.unsqueeze(0).unsqueeze(0)
logits = F.linear(latent, self.embeddings.weight)
sum_logits = logits + con_logits
_, preds = sum_logits.max(dim=2)
return logits, preds
def compute_loss(self, output, scores):
score_view = scores.view(-1)
output_view = output.view(-1, output.size(-1))
loss = self.criterion(output_view.to(device), score_view.to(device))
return loss
def forward(self, context, response, concept_mask, seed_sets, entity_vector, entity=None, cand_params=None, prev_enc=None, maxlen=None,
bsz=None):
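# Pipeline: (1) encode the dialogue context with the Transformer encoder;
# (2) run the R-GCN over the DBpedia subgraph and the GCN over the concept graph;
# (3) pool the entities/concepts seeded by the user with self-attention;
# (4) decode with teacher forcing when a reference response is given, otherwise greedily.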
if response is not None:
self.longest_label = max(self.longest_label, response.size(1))
else:
maxlen = 20
if bsz is None:
bsz = len(seed_sets)
encoder_states = prev_enc if prev_enc is not None else self.encoder(context)
# graph network---------------------------------------------------------------------------------------------------
db_nodes_features = self.dbpedia_RGCN(None, self.db_edge_idx, self.db_edge_type)
con_nodes_features = self.concept_GCN(self.concept_embeddings.weight, self.concept_edge_sets)
user_representation_list = []
db_con_mask = []
for i, seed_set in enumerate(seed_sets):
if seed_set == []:
user_representation_list.append(torch.zeros(self.dim).to(device))
db_con_mask.append(torch.zeros([1]))
continue
user_representation = db_nodes_features[seed_set]  # embeddings of entities mentioned by this user
user_representation = self.self_attn_db(user_representation)
user_representation_list.append(user_representation)
db_con_mask.append(torch.ones([1]))
db_user_emb = torch.stack(user_representation_list)
db_con_mask = torch.stack(db_con_mask)
graph_con_emb = con_nodes_features[concept_mask]
con_emb_mask = concept_mask == self.concept_padding
con_user_emb = graph_con_emb
con_user_emb, attention = self.self_attn(con_user_emb, con_emb_mask.to(device))
#generation---------------------------------------------------------------------------------------------------
con_nodes_features4gen = con_nodes_features
con_emb4gen = con_nodes_features4gen[concept_mask]
con_mask4gen = concept_mask != self.concept_padding
kg_encoding = (self.kg_norm(con_emb4gen), con_mask4gen.to(device))
db_emb4gen = db_nodes_features[entity_vector]
db_mask4gen = entity_vector != 0
db_encoding = (self.db_norm(db_emb4gen), db_mask4gen.to(device))
if response is not None:  # teacher forcing when a reference response is available
scores, preds = self.decode_forced(encoder_states, kg_encoding, db_encoding, con_user_emb, db_user_emb, response)
gen_loss = torch.mean(self.compute_loss(scores, response))
else:
scores, preds = self.decode_greedy(
encoder_states, kg_encoding, db_encoding, con_user_emb, db_user_emb,
bsz,
maxlen or self.longest_label
)
gen_loss = torch.tensor(1.0)
loss = gen_loss
result = {'loss': loss, 'preds': preds, 'response': response, 'context': context}
return result
@monitor
def response(self, raw_input: str, tokenizer, return_dict=False):
"""Tokenize a raw dialogue string, run the generator, and decode the predicted response."""
tokenized_input = tokenizer.encode(raw_input)
output = self.forward(**tokenized_input)['preds']
decoded_text = tokenizer.decode(output)
return decoded_text
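

# --- Usage sketch (illustrative, not part of the module API) -----------------
# A minimal, hedged example of running the generator on a raw dialogue string.
# The checkpoint name "recwizard/kgsf-gen" and the <entity>-tagged input format
# are assumptions for illustration; substitute the generation checkpoint and
# matching tokenizer you actually use.
if __name__ == "__main__":
    # Both classes follow the Hugging Face from_pretrained convention exposed by
    # recwizard's BaseModule; the identifier below is assumed, not guaranteed.
    tokenizer = KGSFGenTokenizer.from_pretrained("recwizard/kgsf-gen")
    model = KGSFGen.from_pretrained("recwizard/kgsf-gen").to(device).eval()
    query = "User: I loved <entity>The Matrix</entity>. Any similar movies?"
    with torch.no_grad():
        print(model.response(query, tokenizer=tokenizer))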