# Source code for recwizard.modules.kgsf.modeling_kgsf_rec

from typing import Union, List
from transformers.utils import ModelOutput
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn.conv.rgcn_conv import RGCNConv
from torch_geometric.nn.conv.gcn_conv import GCNConv
import json
import pickle as pkl
from recwizard import BaseModule
from recwizard.utility.utils import WrapSingleInput
from .tokenizer_kgsf_rec import KGSFRecTokenizer
from .configuration_kgsf_rec import KGSFRecConfig
from .utils import _create_embeddings,_create_entity_embeddings, _edge_list
from .graph_utils import SelfAttentionLayer,SelfAttentionLayer_batch
from recwizard.modules.monitor import monitor

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class KGSFRec(BaseModule):
    """KGSF recommendation module.

    Scores candidate entities (movies) for a dialogue context by combining a
    DBpedia entity-graph encoder with a concept-graph encoder.
    """

    config_class = KGSFRecConfig
    tokenizer_class = KGSFRecTokenizer
    # Submodules skipped when loading/saving checkpoints.
    LOAD_SAVE_IGNORES = {'START', 'decoder'}
[docs] def __init__(self, config, **kwargs): super().__init__(config,**kwargs) dictionary = config.dictionary self.batch_size = config.batch_size self.max_r_length = config.max_r_length self.NULL_IDX = config.padding_idx self.END_IDX = config.end_idx self.longest_label = config.longest_label self.pad_idx = config.padding_idx self.embeddings = _create_embeddings( dictionary, config.embedding_size, self.pad_idx, config.embedding_data ) self.concept_embeddings = _create_entity_embeddings( config.n_concept + 1, config.dim, 0) self.concept_padding = 0 self.kg = {int(k): v for k, v in config.subkg.items()} self.n_positions = config.n_positions self.criterion = nn.CrossEntropyLoss(reduce=False) self.self_attn = SelfAttentionLayer_batch(config.dim, config.dim) self.self_attn_db = SelfAttentionLayer(config.dim, config.dim) self.user_norm = nn.Linear(config.dim * 2, config.dim) self.gate_norm = nn.Linear(config.dim, 1) self.copy_norm = nn.Linear(config.embedding_size * 2 + config.embedding_size, config.embedding_size) self.representation_bias = nn.Linear(config.embedding_size, len(dictionary) + 4) self.info_con_norm = nn.Linear(config.dim, config.dim) self.info_db_norm = nn.Linear(config.dim, config.dim) self.info_output_db = nn.Linear(config.dim, config.n_entity) self.info_output_con = nn.Linear(config.dim, config.n_concept + 1) self.info_con_loss = nn.MSELoss(size_average=False, reduce=False) self.info_db_loss = nn.MSELoss(size_average=False, reduce=False) self.output_en = nn.Linear(config.dim, config.n_entity) self.embedding_size = config.embedding_size self.dim = config.dim edge_list, self.n_relation = _edge_list(self.kg, config.n_entity, hop=2) edge_list = list(set(edge_list)) self.dbpedia_edge_sets = torch.LongTensor(edge_list).to(device) self.db_edge_idx = self.dbpedia_edge_sets[:, :2].t() self.db_edge_type = self.dbpedia_edge_sets[:, 2] self.dbpedia_RGCN = RGCNConv(config.n_entity, self.dim, self.n_relation, num_bases=config.num_bases) self.concept_edge_sets = 
torch.LongTensor(config.edge_set).to(device) self.concept_GCN = GCNConv(self.dim, self.dim) self.pretrain = config.pretrain
def infomax_loss(self, db_nodes_features, con_user_emb, db_label, mask): con_emb=self.info_con_norm(con_user_emb) db_scores = F.linear(con_emb, db_nodes_features, self.info_output_db.bias) info_db_loss=torch.sum(self.info_db_loss(db_scores,db_label.to(device).float()),dim=-1)*mask.to(device) return torch.mean(info_db_loss) def get_total_loss(self, rec_loss, info_db_loss): if self.pretrain: return info_db_loss else: return rec_loss+0.025*info_db_loss
[docs] def forward(self, response, concept_mask, seed_sets, labels, db_vec, rec, test=True, cand_params=None, prev_enc=None, maxlen=None, bsz=None): print(type(concept_mask),type(seed_sets),type(db_vec),type(rec)) if bsz == None: bsz = len(seed_sets) # graph network db_nodes_features = self.dbpedia_RGCN(None, self.db_edge_idx, self.db_edge_type) # entity_encoder con_nodes_features=self.concept_GCN(self.concept_embeddings.weight,self.concept_edge_sets) # word_encoder user_representation_list = [] db_con_mask=[] for i, seed_set in enumerate(seed_sets): if len(seed_set) == 0: user_representation_list.append(torch.zeros(self.dim).to(device)) db_con_mask.append(torch.zeros([1])) continue user_representation = db_nodes_features[seed_set] # torch can reflect user_representation = self.self_attn_db(user_representation) user_representation_list.append(user_representation) db_con_mask.append(torch.ones([1])) db_user_emb=torch.stack(user_representation_list) db_con_mask=torch.stack(db_con_mask) graph_con_emb=con_nodes_features[concept_mask] con_emb_mask=concept_mask==self.concept_padding con_user_emb=graph_con_emb con_user_emb,attention=self.self_attn(con_user_emb,con_emb_mask.to(device)) user_emb=self.user_norm(torch.cat([con_user_emb,db_user_emb],dim=-1)) uc_gate = F.sigmoid(self.gate_norm(user_emb)) user_emb = uc_gate * db_user_emb + (1 - uc_gate) * con_user_emb entity_scores = F.linear(user_emb, db_nodes_features, self.output_en.bias) if response is not None: info_db_loss = self.infomax_loss(db_nodes_features,con_user_emb,db_vec,db_con_mask) rec_loss=self.criterion(entity_scores.squeeze(1).squeeze(1).float(), labels.to(device)) rec_loss = torch.sum(rec_loss*rec.float().to(device)) loss = self.get_total_loss(rec_loss,info_db_loss) labels = torch.tensor(labels) rec_loss = torch.tensor(rec_loss) else: loss = None rec_loss = None result = {'loss': loss, 'labels': labels, 'rec_scores': torch.tensor(entity_scores), 'rec_loss': rec_loss} print(result,labels) return result
@WrapSingleInput @monitor def response(self, raw_input: Union[List[str], str], tokenizer, return_dict=False, topk=3): movieIds_lst = [] movieNames_lst = [] for raw_single in raw_input: inputs = tokenizer.encode(raw_single) logits = self.forward(**inputs)['rec_scores'] movieIds, movieNames = tokenizer.decode(logits) movieIds_lst.append(movieIds) movieNames_lst.append(movieNames) if return_dict: return { 'logits': logits, 'movieIds': movieIds_lst, 'movieNames': movieNames_lst, } return movieNames_lst