import math
import os
import torch
from torch import nn
from torch.nn import functional as F
from torch_geometric.nn import RGCNConv
from transformers import AutoModel
from recwizard.utility import deterministic_seed, Singleton
# Adapted from https://github.com/RUCAIBox/UniCRS/blob/main/src/model_prompt.py
class KGPrompt(nn.Module):
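    """Knowledge-graph prompt encoder adapted from UniCRS.

    Fuses RGCN-encoded entity embeddings with token embeddings from a frozen
    RoBERTa encoder to build prefix prompts for the recommendation and
    conversation subtasks.
    """
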
    def __init__(
self, hidden_size, token_hidden_size, n_head, n_layer, n_block,
num_bases, kg_info, edge_index, edge_type, num_tokens,
n_prefix_rec=None, n_prefix_conv=None
):
super().__init__()
num_entities = kg_info['num_entities']
num_relations = kg_info['num_relations']
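        # Frozen RoBERTa text encoder, registered under a Singleton key so it can be shared.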
self.text_encoder = Singleton('unicrs.KGPrompt.text_encoder', AutoModel.from_pretrained("roberta-base"))
self.text_encoder.requires_grad_(False)
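        # Resize the (frozen) embedding table to the prompt tokenizer's vocabulary size,
        # under a fixed seed so the newly added embedding rows are reproducible.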
with deterministic_seed():
self.text_encoder.resize_token_embeddings(num_tokens)
self.hidden_size = hidden_size
self.n_head = n_head
self.head_dim = hidden_size // n_head
self.n_layer = n_layer
self.n_block = n_block
self.n_prefix_rec = n_prefix_rec
self.n_prefix_conv = n_prefix_conv
entity_hidden_size = hidden_size // 2
self.kg_encoder = RGCNConv(entity_hidden_size, entity_hidden_size, num_relations=num_relations,
num_bases=num_bases)
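        # Learnable entity node features, initialized with Glorot/Xavier-uniform bounds.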
self.node_embeds = nn.Parameter(torch.empty(num_entities, entity_hidden_size))
stdv = math.sqrt(6.0 / (self.node_embeds.size(-2) + self.node_embeds.size(-1)))
self.node_embeds.data.uniform_(-stdv, stdv)
self.edge_index = nn.Parameter(edge_index, requires_grad=False)
self.edge_type = nn.Parameter(edge_type, requires_grad=False)
self.entity_proj1 = nn.Sequential(
nn.Linear(entity_hidden_size, entity_hidden_size // 2),
nn.ReLU(),
nn.Linear(entity_hidden_size // 2, entity_hidden_size),
)
self.entity_proj2 = nn.Linear(entity_hidden_size, hidden_size)
self.token_proj1 = nn.Sequential(
nn.Linear(token_hidden_size, token_hidden_size // 2),
nn.ReLU(),
nn.Linear(token_hidden_size // 2, token_hidden_size),
)
self.token_proj2 = nn.Linear(token_hidden_size, hidden_size)
self.cross_attn = nn.Linear(hidden_size, hidden_size, bias=False)
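        # Residual MLP plus a linear head that expands each fused prompt embedding
        # into n_layer * n_block per-layer hidden states.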
self.prompt_proj1 = nn.Sequential(
nn.Linear(hidden_size, hidden_size // 2),
nn.ReLU(),
nn.Linear(hidden_size // 2, hidden_size),
)
self.prompt_proj2 = nn.Linear(hidden_size, n_layer * n_block * hidden_size)
if self.n_prefix_rec is not None:
self.rec_prefix_embeds = nn.Parameter(torch.empty(n_prefix_rec, hidden_size))
nn.init.normal_(self.rec_prefix_embeds)
self.rec_prefix_proj = nn.Sequential(
nn.Linear(hidden_size, hidden_size // 2),
nn.ReLU(),
nn.Linear(hidden_size // 2, hidden_size)
)
if self.n_prefix_conv is not None:
self.conv_prefix_embeds = nn.Parameter(torch.empty(n_prefix_conv, hidden_size))
nn.init.normal_(self.conv_prefix_embeds)
self.conv_prefix_proj = nn.Sequential(
nn.Linear(hidden_size, hidden_size // 2),
nn.ReLU(),
nn.Linear(hidden_size // 2, hidden_size)
)

    def set_and_fix_node_embed(self, node_embeds: torch.Tensor):
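        """Overwrite the node embeddings with precomputed values and freeze them."""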
self.node_embeds.data = node_embeds
self.node_embeds.requires_grad_(False)

    def get_entity_embeds(self):
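        """Run the RGCN over the knowledge graph and project entity embeddings to ``hidden_size``."""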
node_embeds = self.node_embeds
entity_embeds = self.kg_encoder(node_embeds, self.edge_index, self.edge_type) + node_embeds
entity_embeds = self.entity_proj1(entity_embeds) + entity_embeds
entity_embeds = self.entity_proj2(entity_embeds)
return entity_embeds

    def forward(self, entity_ids=None, prompt=None, rec_mode=False, use_rec_prefix=False,
                use_conv_prefix=False):
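        """Build prefix prompt embeddings from entity ids and/or an encoded text prompt.

        Returns the prompt states with shape
        (n_layer, n_block, batch_size, n_head, prompt_len, head_dim); in ``rec_mode``
        the full entity embedding table is returned as well.
        """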
        token_embeds = self.text_encoder(**prompt).last_hidden_state if prompt is not None else None
batch_size, entity_embeds, entity_len, token_len = None, None, None, None
if entity_ids is not None:
batch_size, entity_len = entity_ids.shape[:2]
_entity_embeds = self.get_entity_embeds()
entity_embeds = _entity_embeds[entity_ids] # (batch_size, entity_len, hidden_size)
if token_embeds is not None:
batch_size, token_len = token_embeds.shape[:2]
token_embeds = self.token_proj1(token_embeds) + token_embeds # (batch_size, token_len, hidden_size)
token_embeds = self.token_proj2(token_embeds)
if entity_embeds is not None and token_embeds is not None:
            # (batch_size, token_len, entity_len)
            attn_weights = self.cross_attn(token_embeds) @ entity_embeds.permute(0, 2, 1)
            attn_weights /= self.hidden_size
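            # rec_mode pools token information into entity positions; otherwise
            # entity information is pooled into token positions.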
if rec_mode:
token_weights = F.softmax(attn_weights, dim=1).permute(0, 2, 1)
prompt_embeds = token_weights @ token_embeds + entity_embeds
prompt_len = entity_len
else:
entity_weights = F.softmax(attn_weights, dim=2)
prompt_embeds = entity_weights @ entity_embeds + token_embeds
prompt_len = token_len
elif entity_embeds is not None:
prompt_embeds = entity_embeds
prompt_len = entity_len
else:
prompt_embeds = token_embeds
prompt_len = token_len
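        # Optionally prepend the learned task-specific prefix embeddings.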
if self.n_prefix_rec is not None and use_rec_prefix:
prefix_embeds = self.rec_prefix_proj(self.rec_prefix_embeds) + self.rec_prefix_embeds
prefix_embeds = prefix_embeds.expand(prompt_embeds.shape[0], -1, -1)
prompt_embeds = torch.cat([prefix_embeds, prompt_embeds], dim=1)
prompt_len += self.n_prefix_rec
if self.n_prefix_conv is not None and use_conv_prefix:
prefix_embeds = self.conv_prefix_proj(self.conv_prefix_embeds) + self.conv_prefix_embeds
prefix_embeds = prefix_embeds.expand(prompt_embeds.shape[0], -1, -1)
prompt_embeds = torch.cat([prefix_embeds, prompt_embeds], dim=1)
prompt_len += self.n_prefix_conv
prompt_embeds = self.prompt_proj1(prompt_embeds) + prompt_embeds
prompt_embeds = self.prompt_proj2(prompt_embeds)
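        # Split into per-layer, per-block, per-head states for prefix-tuned attention.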
prompt_embeds = prompt_embeds.reshape(
batch_size, prompt_len, self.n_layer, self.n_block, self.n_head, self.head_dim
).permute(2, 3, 0, 4, 1, 5) # (n_layer, n_block, batch_size, n_head, prompt_len, head_dim)
if rec_mode:
return prompt_embeds, _entity_embeds
else:
return prompt_embeds