Source code for recwizard.modules.unicrs.kg_prompt

import math
import os

import torch
from torch import nn
from torch.nn import functional as F
from torch_geometric.nn import RGCNConv
from transformers import AutoModel

from recwizard.utility import deterministic_seed, Singleton


# Adapted from https://github.com/RUCAIBox/UniCRS/blob/main/src/model_prompt.py

class KGPrompt(nn.Module):
    """Knowledge-graph-aware prompt encoder.

    Entity embeddings are produced by an RGCN over a fixed graph; token
    embeddings come from a frozen text encoder. When both are available they
    are fused with a single cross-attention step. Learnable prefix vectors can
    optionally be prepended, and every prompt position is finally projected to
    ``n_layer * n_block * hidden_size`` features and reshaped to
    ``(n_layer, n_block, batch_size, n_head, prompt_len, head_dim)``.
    """

    def __init__(
        self,
        hidden_size,
        token_hidden_size,
        n_head,
        n_layer,
        n_block,
        num_bases,
        kg_info,
        edge_index,
        edge_type,
        num_tokens,
        n_prefix_rec=None,
        n_prefix_conv=None,
    ):
        """Build the encoder.

        Args:
            hidden_size: width of entity/prompt embeddings. The RGCN runs at
                ``hidden_size // 2``. Must be divisible by ``n_head`` for the
                final reshape in ``forward`` to succeed.
            token_hidden_size: width of the text encoder's ``last_hidden_state``.
            n_head: number of heads the output is split into.
            n_layer: leading output dimension (``prompt_proj2`` emits
                ``n_layer * n_block * hidden_size`` features).
            n_block: second output dimension.
            num_bases: ``num_bases`` passed to ``RGCNConv``.
            kg_info: mapping with ``'num_entities'`` and ``'num_relations'``.
            edge_index: graph connectivity passed to ``RGCNConv``.
            edge_type: per-edge relation ids passed to ``RGCNConv``.
            num_tokens: size the text encoder's token embedding table is
                resized to.
            n_prefix_rec: number of learnable recommendation-prefix vectors,
                or ``None`` to disable them.
            n_prefix_conv: number of learnable conversation-prefix vectors,
                or ``None`` to disable them.
        """
        super().__init__()
        num_entities = kg_info['num_entities']
        num_relations = kg_info['num_relations']

        # Frozen text encoder obtained through ``Singleton`` under a fixed key
        # (presumably shared between instances -- confirm in recwizard.utility).
        self.text_encoder = Singleton('unicrs.KGPrompt.text_encoder', AutoModel.from_pretrained("roberta-base"))
        self.text_encoder.requires_grad_(False)
        # Resizing may add randomly initialised rows; deterministic_seed()
        # presumably pins the RNG so this is reproducible.
        with deterministic_seed():
            self.text_encoder.resize_token_embeddings(num_tokens)

        self.hidden_size = hidden_size
        self.n_head = n_head
        self.head_dim = hidden_size // n_head
        self.n_layer = n_layer
        self.n_block = n_block
        self.n_prefix_rec = n_prefix_rec
        self.n_prefix_conv = n_prefix_conv

        # Keep construction order stable: parameters are randomly initialised
        # as they are created, so reordering would change seeded initialisation.
        entity_hidden_size = hidden_size // 2
        self.kg_encoder = RGCNConv(entity_hidden_size, entity_hidden_size, num_relations=num_relations, num_bases=num_bases)
        self.node_embeds = nn.Parameter(torch.empty(num_entities, entity_hidden_size))
        # Glorot/Xavier-uniform bound: sqrt(6 / (fan_in + fan_out)).
        stdv = math.sqrt(6.0 / (self.node_embeds.size(-2) + self.node_embeds.size(-1)))
        self.node_embeds.data.uniform_(-stdv, stdv)
        # Non-trainable parameters: they move with .to(device) and appear in state_dict.
        self.edge_index = nn.Parameter(edge_index, requires_grad=False)
        self.edge_type = nn.Parameter(edge_type, requires_grad=False)

        self.entity_proj1 = self._bottleneck_mlp(entity_hidden_size)
        self.entity_proj2 = nn.Linear(entity_hidden_size, hidden_size)

        self.token_proj1 = self._bottleneck_mlp(token_hidden_size)
        self.token_proj2 = nn.Linear(token_hidden_size, hidden_size)

        self.cross_attn = nn.Linear(hidden_size, hidden_size, bias=False)
        self.prompt_proj1 = self._bottleneck_mlp(hidden_size)
        self.prompt_proj2 = nn.Linear(hidden_size, n_layer * n_block * hidden_size)

        if self.n_prefix_rec is not None:
            self.rec_prefix_embeds = nn.Parameter(torch.empty(n_prefix_rec, hidden_size))
            nn.init.normal_(self.rec_prefix_embeds)
            self.rec_prefix_proj = self._bottleneck_mlp(hidden_size)
        if self.n_prefix_conv is not None:
            self.conv_prefix_embeds = nn.Parameter(torch.empty(n_prefix_conv, hidden_size))
            nn.init.normal_(self.conv_prefix_embeds)
            self.conv_prefix_proj = self._bottleneck_mlp(hidden_size)

    @staticmethod
    def _bottleneck_mlp(dim):
        """Return ``Linear(dim, dim // 2) -> ReLU -> Linear(dim // 2, dim)``; callers add the residual."""
        return nn.Sequential(
            nn.Linear(dim, dim // 2),
            nn.ReLU(),
            nn.Linear(dim // 2, dim),
        )

    def set_and_fix_node_embed(self, node_embeds: torch.Tensor):
        """Replace the node embedding table with ``node_embeds`` and stop training it.

        Only the table is frozen; ``kg_encoder`` and the entity projections
        applied on top of it in ``get_entity_embeds`` remain trainable.
        """
        self.node_embeds.data = node_embeds
        self.node_embeds.requires_grad_(False)

    def get_entity_embeds(self):
        """Return embeddings for every graph node, shape ``(num_entities, hidden_size)``.

        RGCN message passing with a residual connection, then the residual
        bottleneck ``entity_proj1``, then ``entity_proj2`` up to ``hidden_size``.
        """
        node_embeds = self.node_embeds
        entity_embeds = self.kg_encoder(node_embeds, self.edge_index, self.edge_type) + node_embeds
        entity_embeds = self.entity_proj1(entity_embeds) + entity_embeds
        entity_embeds = self.entity_proj2(entity_embeds)
        return entity_embeds

    @staticmethod
    def _prepend_prefix(prompt_embeds, prefix_embeds, prefix_proj):
        """Project ``prefix_embeds`` (with residual), broadcast it over the batch, and prepend it."""
        prefix = prefix_proj(prefix_embeds) + prefix_embeds
        # (n_prefix, hidden_size) -> (batch_size, n_prefix, hidden_size)
        prefix = prefix.expand(prompt_embeds.shape[0], -1, -1)
        return torch.cat([prefix, prompt_embeds], dim=1)

    def forward(self, entity_ids=None, prompt=None, rec_mode=False, use_rec_prefix=False, use_conv_prefix=False):
        """Compute prompt embeddings from entity ids and/or text.

        Args:
            entity_ids: integer tensor whose first two dims are
                ``(batch_size, entity_len)``; it indexes rows of
                ``get_entity_embeds()``. ``None`` skips the entity path.
            prompt: mapping of keyword arguments for the text encoder
                (presumably tokenizer output -- confirm with callers), or
                ``None`` to skip the text path.
            rec_mode: when both inputs are given, attend over tokens per
                entity (``prompt_len = entity_len``) instead of over entities
                per token (``prompt_len = token_len``). It also changes the
                return value (see below).
            use_rec_prefix: prepend the recommendation prefix, if it was configured.
            use_conv_prefix: prepend the conversation prefix, if it was configured.
                It is prepended in front of the recommendation prefix.

        Returns:
            The prompt tensor of shape
            ``(n_layer, n_block, batch_size, n_head, prompt_len, head_dim)``.
            When ``rec_mode`` is true, the return value is instead the tuple
            ``(prompt tensor, full entity table)``; the entity table is
            ``None`` when ``entity_ids`` was not given.

        Raises:
            ValueError: if both ``entity_ids`` and ``prompt`` are ``None``.
        """
        # Only run the (frozen) text encoder when text was actually supplied.
        token_embeds = None
        if prompt is not None:
            token_embeds = self.text_encoder(**prompt).last_hidden_state

        batch_size, entity_len, token_len = None, None, None
        entity_embeds, all_entity_embeds = None, None
        if entity_ids is not None:
            batch_size, entity_len = entity_ids.shape[:2]
            all_entity_embeds = self.get_entity_embeds()
            entity_embeds = all_entity_embeds[entity_ids]  # (batch_size, entity_len, hidden_size)
        if token_embeds is not None:
            batch_size, token_len = token_embeds.shape[:2]
            token_embeds = self.token_proj1(token_embeds) + token_embeds
            token_embeds = self.token_proj2(token_embeds)  # (batch_size, token_len, hidden_size)

        if entity_embeds is not None and token_embeds is not None:
            # (batch_size, token_len, entity_len). Scaled by hidden_size rather
            # than sqrt(hidden_size); kept so trained weights behave the same.
            attn_weights = self.cross_attn(token_embeds) @ entity_embeds.permute(0, 2, 1)
            attn_weights = attn_weights / self.hidden_size
            if rec_mode:
                token_weights = F.softmax(attn_weights, dim=1).permute(0, 2, 1)
                prompt_embeds = token_weights @ token_embeds + entity_embeds
                prompt_len = entity_len
            else:
                entity_weights = F.softmax(attn_weights, dim=2)
                prompt_embeds = entity_weights @ entity_embeds + token_embeds
                prompt_len = token_len
        elif entity_embeds is not None:
            prompt_embeds = entity_embeds
            prompt_len = entity_len
        elif token_embeds is not None:
            prompt_embeds = token_embeds
            prompt_len = token_len
        else:
            raise ValueError("KGPrompt.forward requires at least one of `entity_ids` or `prompt`.")

        if self.n_prefix_rec is not None and use_rec_prefix:
            prompt_embeds = self._prepend_prefix(prompt_embeds, self.rec_prefix_embeds, self.rec_prefix_proj)
            prompt_len += self.n_prefix_rec
        if self.n_prefix_conv is not None and use_conv_prefix:
            prompt_embeds = self._prepend_prefix(prompt_embeds, self.conv_prefix_embeds, self.conv_prefix_proj)
            prompt_len += self.n_prefix_conv

        prompt_embeds = self.prompt_proj1(prompt_embeds) + prompt_embeds
        prompt_embeds = self.prompt_proj2(prompt_embeds)
        prompt_embeds = prompt_embeds.reshape(
            batch_size, prompt_len, self.n_layer, self.n_block, self.n_head, self.head_dim
        ).permute(2, 3, 0, 4, 1, 5)  # (n_layer, n_block, batch_size, n_head, prompt_len, head_dim)

        if rec_mode:
            return prompt_embeds, all_entity_embeds
        return prompt_embeds