Source code for recwizard.modules.kbrd.transformer_encoder_decoder

""" This file is adapted from the KBRD original implementation: https://github.com/THUDM/KBRD
"""

import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class TorchGeneratorModel(nn.Module):
    """
    This interface expects you to implement a model with the following attributes:

    Attributes:
        `START`: LongTensor representing the start of sentence
        `END`: LongTensor representing the end of sentence
        `NULL_IDX`: index of the null (padding) token
        `model.encoder`: takes input, returns tuple (enc_out, enc_hidden, attn_mask)
        `model.decoder`: takes decoder params and returns decoder outputs after attn
        `model.output`: takes decoder outputs and returns a distribution over the dictionary
    """

    def __init__(
        self,
        padding_idx=0,
        start_idx=1,
        end_idx=2,
        unknown_idx=3,
        input_dropout=0,
        longest_label=1,
    ):
        super().__init__()
        self.NULL_IDX = padding_idx
        self.END_IDX = end_idx
        self.register_buffer("START", torch.LongTensor([start_idx]))
        self.longest_label = longest_label

    def _starts(self, bsz):
        """Return bsz start tokens."""
        return self.START.detach().expand(bsz, 1)

    def decode_greedy(self, encoder_states, bsz, maxlen):
        """
        Greedy search.

        Args:
            encoder_states: output of the encoder model
            bsz: batch size
            maxlen: max number of tokens to decode

        Return:
            pair (logits, choices) of the greedy decode, shapes:
            (batch_size, max_len, #vocab), (batch_size, max_len)
        """
        xs = self._starts(bsz)
        incr_state = None
        logits = []
        for i in range(maxlen):
            # todo, break early if all beams saw EOS
            scores, incr_state = self.decoder(xs, encoder_states, incr_state)
            scores = scores[:, -1:, :]
            scores = self.output(scores)
            _, preds = scores.max(dim=-1)
            logits.append(scores)
            xs = torch.cat([xs, preds], dim=1)
            # check if everyone has generated an end token
            all_finished = ((xs == self.END_IDX).sum(dim=1) > 0).sum().item() == bsz
            if all_finished:
                break
        logits = torch.cat(logits, 1)
        return logits, xs

    def decode_forced(self, encoder_states, ys):
        """
        Decode with a fixed, true sequence, computing loss.

        Useful for training, or ranking fixed candidates.

        Args:
            encoder_states: output of the encoder model,
                shape: (batch_size, seq_len, hidden_size)
            ys: target tokens, shape: (batch_size, tgt_len)

        Return:
            pair (logits, choices) containing the logits and MLE predictions, shapes:
            (batch_size, tgt_len, #vocab), (batch_size, tgt_len)
        """
        bsz = ys.size(0)
        seqlen = ys.size(1)
        inputs = ys.narrow(1, 0, seqlen - 1)
        inputs = torch.cat([self._starts(bsz), inputs], 1)
        latent, _ = self.decoder(inputs, encoder_states)
        logits = self.output(latent)
        _, preds = logits.max(dim=2)
        return logits, preds

    def reorder_encoder_states(self, encoder_states, indices):
        """
        Reorder encoder states according to a new set of indices.

        This is an abstract method, and *must* be implemented by the user. It
        provides beam search with a model-agnostic interface for reordering the
        encoder output, e.g. to sort hypotheses or expand beams.

        For example, assume that encoder_states is a bsz x 1 tensor of values

        ```python
        indices = [0, 2, 2]
        encoder_states = [[0.1]
                          [0.2]
                          [0.3]]
        ```

        then the output will be

        ```python
        output = [[0.1]
                  [0.3]
                  [0.3]]
        ```

        Args:
            encoder_states: output from encoder. Type is model specific.
            indices (List[int]): the indices to select over. The user must
                support non-tensor inputs.

        Return:
            The re-ordered encoder states. It should be of the same type as
            encoder_states, and it must be a valid input to the decoder.
        """
        raise NotImplementedError(
            "reorder_encoder_states must be implemented by the model"
        )

    def reorder_decoder_incremental_state(self, incremental_state, inds):
        """
        Reorder incremental state for the decoder.

        Used to expand selected beams in beam search. Unlike
        reorder_encoder_states, implementing this method is optional. However,
        without incremental decoding, decoding a single beam becomes O(n^2)
        instead of O(n), which can make beam search impractically slow.

        In order to fall back to non-incremental decoding, just return None
        from this method.

        Args:
            incremental_state: second output of model.decoder
            inds (torch.LongTensor): indices to select and reorder over.

        Returns:
            The re-ordered decoder incremental states. It should be the same
            type as incremental_state, and usable as an input to the decoder.
            This method should return None if the model does not support
            incremental decoding.
        """
        raise NotImplementedError(
            "reorder_decoder_incremental_state must be implemented by model"
        )

    def forward(
        self, *xs, ys=None, cand_params=None, prev_enc=None, maxlen=None, bsz=None
    ):
        """
        Get output predictions from the model.

        Args:
            xs (torch.LongTensor): input to the encoder,
                shape: (batch_size, seq_len)
            ys (torch.LongTensor): expected output from the decoder,
                shape: (batch_size, seq_len)
            cand_params (torch.FloatTensor): parameters for candidate generation,
                shape: (batch_size, num_cands, num_params)
            prev_enc (torch.FloatTensor): if you know you'll pass in the same xs
                multiple times, you can pass in the encoder output from the last
                forward pass to skip recalculating the same encoder output.
            maxlen (int): max number of tokens to decode. If not set, the length
                of the longest label this model has seen is used. Ignored when
                ys is not None.
            bsz (int): if ys is not provided, then you must specify the bsz for
                greedy decoding.

        Returns:
            (scores, preds, encoder_states) tuple

            - scores contains the model's predicted token scores.
              (FloatTensor[batch_size, seq_len, #vocab])
            - preds contains the teacher-forced or greedy token predictions.
              (LongTensor[batch_size, seq_len])
            - encoder_states are the output of model.encoder. Model specific
              types. Feed this back in to skip encoding on the next call.
        """
        if ys is not None:
            # TODO: get rid of longest_label
            # keep track of longest label we've ever seen
            # we'll never produce longer ones than that during prediction
            self.longest_label = max(self.longest_label, ys.size(1))

        # use cached encoding if available
        encoder_states = prev_enc if prev_enc is not None else self.encoder(*xs)

        if ys is not None:
            # use teacher forcing
            scores, preds = self.decode_forced(encoder_states, ys)
        else:
            scores, preds = self.decode_greedy(
                encoder_states, bsz, maxlen or self.longest_label
            )

        return scores, preds, encoder_states


def _normalize(tensor, norm_layer):
    """Broadcast layer norm."""
    size = tensor.size()
    return norm_layer(tensor.view(-1, size[-1])).view(size)


def _create_embeddings(vocab_size, embedding_size, padding_idx):
    """Create and initialize word embeddings."""
    e = nn.Embedding(vocab_size, embedding_size, padding_idx)
    nn.init.normal_(e.weight, mean=0, std=embedding_size**-0.5)
    nn.init.constant_(e.weight[padding_idx], 0)
    return e


def _build_encoder(config, embedding=None, padding_idx=None, reduction=True):
    return TransformerEncoder(
        n_heads=config.n_heads,
        n_layers=config.n_layers,
        embedding_size=config.gen_dim,
        ffn_size=config.ffn_size,
        vocabulary_size=config.vocab_size,
        embedding=embedding,
        dropout=config.dropout,
        attention_dropout=config.attention_dropout,
        relu_dropout=config.relu_dropout,
        padding_idx=padding_idx,
        learn_positional_embeddings=config.learn_positional_embeddings,
        embeddings_scale=config.embeddings_scale,
        reduction=reduction,
        n_positions=config.n_positions,
    )


def _build_decoder(config, embedding=None, padding_idx=None):
    return TransformerDecoder(
        n_heads=config.n_heads,
        n_layers=config.n_layers,
        embedding_size=config.gen_dim,
        ffn_size=config.ffn_size,
        vocabulary_size=config.vocab_size,
        embedding=embedding,
        dropout=config.dropout,
        attention_dropout=config.attention_dropout,
        relu_dropout=config.relu_dropout,
        padding_idx=padding_idx,
        learn_positional_embeddings=config.learn_positional_embeddings,
        embeddings_scale=config.embeddings_scale,
        n_positions=config.n_positions,
    )


def create_position_codes(n_pos, dim, out):
    """Fill `out` with fixed sinusoidal position codes."""
    position_enc = np.array(
        [
            [pos / np.power(10000, 2 * j / dim) for j in range(dim // 2)]
            for pos in range(n_pos)
        ]
    )

    out.detach_()
    out.requires_grad = False
    out[:, 0::2] = torch.FloatTensor(np.sin(position_enc)).type_as(out)
    out[:, 1::2] = torch.FloatTensor(np.cos(position_enc)).type_as(out)


class TransformerResponseWrapper(nn.Module):
    """Transformer response wrapper. Pushes input through transformer and MLP."""

    def __init__(self, transformer, hdim):
        super(TransformerResponseWrapper, self).__init__()
        dim = transformer.out_dim
        self.transformer = transformer
        self.mlp = nn.Sequential(nn.Linear(dim, hdim), nn.ReLU(), nn.Linear(hdim, dim))

    def forward(self, *args):
        return self.mlp(self.transformer(*args))
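

# Illustrative sketch (an assumption, not part of the original KBRD source): shows
# how the sinusoidal position codes above can be inspected. The sizes are arbitrary;
# even columns hold sin(pos / 10000^(2j/dim)) and odd columns the matching cosine.
def _demo_position_codes(n_pos=16, dim=8):  # pragma: no cover
    out = torch.zeros(n_pos, dim)
    create_position_codes(n_pos, dim, out)
    # out[pos, 0::2] == sin(pos / 10000^(2j/dim)), out[pos, 1::2] == cos(...)
    return out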


class TransformerEncoder(nn.Module):

    def __init__(
        self,
        n_heads,
        n_layers,
        embedding_size,
        ffn_size,
        vocabulary_size,
        embedding=None,
        dropout=0.0,
        attention_dropout=0.0,
        relu_dropout=0.0,
        padding_idx=0,
        learn_positional_embeddings=False,
        embeddings_scale=False,
        reduction=True,
        n_positions=1024,
    ):
        """Transformer encoder module.

        Args:
            n_heads (int): the number of multihead attention heads.
            n_layers (int): number of transformer layers.
            embedding_size (int): the embedding size. Must be a multiple of n_heads.
            ffn_size (int): the size of the hidden layer in the FFN.
            vocabulary_size (int): size of the vocabulary.
            embedding (nn.Embedding): an embedding matrix for the bottom layer of
                the transformer. If None, one is created for this encoder.
            dropout (float): Dropout used around embeddings and before layer
                normalizations. This is used in Vaswani 2017 and works well on
                large datasets.
            attention_dropout (float): Dropout performed after the multihead
                attention softmax. This is not used in Vaswani 2017.
            relu_dropout (float): Dropout used after the ReLU in the FFN. Not used
                in Vaswani 2017, but used in Tensor2Tensor.
            embeddings_scale (bool): Scale embeddings relative to their
                dimensionality. Found useful in fairseq.
            learn_positional_embeddings (bool): If off, sinusoidal embeddings are
                used. If on, position embeddings are learned from scratch.
            padding_idx (int): Reserved padding index in the embeddings matrix.
            reduction (bool): If on, return the mean-pooled sequence representation;
                otherwise return per-token outputs together with the mask.
            n_positions (int): Size of the position embeddings matrix.
        """
        super(TransformerEncoder, self).__init__()

        self.embedding_size = embedding_size
        self.ffn_size = ffn_size
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.dim = embedding_size
        self.embeddings_scale = embeddings_scale
        self.reduction = reduction
        self.padding_idx = padding_idx
        # this is --dropout, not --relu-dropout or --attention-dropout
        self.dropout = nn.Dropout(p=dropout)
        self.out_dim = embedding_size

        assert (
            embedding_size % n_heads == 0
        ), "Transformer embedding size must be a multiple of n_heads"

        # check input formats:
        if embedding is not None:
            assert (
                embedding_size is None or embedding_size == embedding.weight.shape[1]
            ), "Embedding dim must match the embedding size."

        # an embedding matrix must be provided (the branch below is intentionally unreachable)
        if embedding is not None:
            self.embeddings = embedding
        else:
            assert False
            assert padding_idx is not None
            self.embeddings = nn.Embedding(
                vocabulary_size, embedding_size, padding_idx=padding_idx
            )
            nn.init.normal_(self.embeddings.weight, 0, embedding_size**-0.5)

        # create the positional embeddings
        self.position_embeddings = nn.Embedding(n_positions, embedding_size)
        if not learn_positional_embeddings:
            create_position_codes(
                n_positions, embedding_size, out=self.position_embeddings.weight
            )
        else:
            nn.init.normal_(self.position_embeddings.weight, 0, embedding_size**-0.5)

        # build the model
        self.layers = nn.ModuleList()
        for _ in range(self.n_layers):
            self.layers.append(
                TransformerEncoderLayer(
                    n_heads,
                    embedding_size,
                    ffn_size,
                    attention_dropout=attention_dropout,
                    relu_dropout=relu_dropout,
                    dropout=dropout,
                )
            )

    def forward(self, input):
        """
        input is a LongTensor of token indices, shape [batch, seq_len].

        The mask computed from it is a BoolTensor of shape [batch, seq_len],
        filled with 1 when inside the sequence and 0 outside.
        """
        mask = input != self.padding_idx
        positions = (mask.cumsum(dim=1, dtype=torch.int64) - 1).clamp_(min=0)
        tensor = self.embeddings(input)
        if self.embeddings_scale:
            tensor = tensor * np.sqrt(self.dim)
        tensor = tensor + self.position_embeddings(positions).expand_as(tensor)
        # --dropout on the embeddings
        tensor = self.dropout(tensor)

        tensor *= mask.unsqueeze(-1).type_as(tensor)
        for i in range(self.n_layers):
            tensor = self.layers[i](tensor, mask)

        if self.reduction:
            divisor = mask.type_as(tensor).sum(dim=1).unsqueeze(-1).clamp(min=1e-7)
            output = tensor.sum(dim=1) / divisor
            return output
        else:
            output = tensor
            return output, mask
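

# Illustrative sketch (an assumption, not part of the original module): builds a tiny
# TransformerEncoder around a fresh embedding table and checks the output shapes.
# All sizes here are arbitrary.
def _demo_encoder():  # pragma: no cover
    emb = _create_embeddings(vocab_size=100, embedding_size=8, padding_idx=0)
    enc = TransformerEncoder(
        n_heads=2,
        n_layers=1,
        embedding_size=8,
        ffn_size=16,
        vocabulary_size=100,
        embedding=emb,
        reduction=False,
    )
    tokens = torch.randint(1, 100, (4, 5))  # (batch, seq_len) of token ids
    output, mask = enc(tokens)  # output: (4, 5, 8), mask: (4, 5)
    return output, mask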


class TransformerEncoderLayer(nn.Module):

    def __init__(
        self,
        n_heads,
        embedding_size,
        ffn_size,
        attention_dropout=0.0,
        relu_dropout=0.0,
        dropout=0.0,
    ):
        super().__init__()
        self.dim = embedding_size
        self.ffn_dim = ffn_size
        self.attention = MultiHeadAttention(
            n_heads,
            embedding_size,
            dropout=attention_dropout,  # --attention-dropout
        )
        self.norm1 = nn.LayerNorm(embedding_size)
        self.ffn = TransformerFFN(embedding_size, ffn_size, relu_dropout=relu_dropout)
        self.norm2 = nn.LayerNorm(embedding_size)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, tensor, mask):
        tensor = tensor + self.dropout(self.attention(tensor, mask=mask))
        tensor = _normalize(tensor, self.norm1)
        tensor = tensor + self.dropout(self.ffn(tensor))
        tensor = _normalize(tensor, self.norm2)
        tensor *= mask.unsqueeze(-1).type_as(tensor)
        return tensor


class TransformerDecoder(nn.Module):

    def __init__(
        self,
        n_heads,
        n_layers,
        embedding_size,
        ffn_size,
        vocabulary_size,
        embedding=None,
        dropout=0.0,
        attention_dropout=0.0,
        relu_dropout=0.0,
        embeddings_scale=True,
        learn_positional_embeddings=False,
        padding_idx=None,
        n_positions=1024,
    ):
        """Transformer decoder module.

        Args:
            n_heads (int): the number of multihead attention heads.
            n_layers (int): number of transformer layers.
            embedding_size (int): the embedding size. Must be a multiple of n_heads.
            ffn_size (int): the size of the hidden layer in the FFN.
            vocabulary_size (int): size of the vocabulary.
            embedding (nn.Embedding): an embedding matrix for the bottom layer of
                the transformer. The caller is expected to pass one in (KBRD
                shares it with the encoder).
            dropout (float): Dropout used around embeddings and before layer
                normalizations. This is used in Vaswani 2017 and works well on
                large datasets.
            attention_dropout (float): Dropout performed after the multihead
                attention softmax. This is not used in Vaswani 2017.
            relu_dropout (float): Dropout used after the ReLU in the FFN. Not used
                in Vaswani 2017, but used in Tensor2Tensor.
            embeddings_scale (bool): Scale embeddings relative to their
                dimensionality. Found useful in fairseq.
            learn_positional_embeddings (bool): If off, sinusoidal embeddings are
                used. If on, position embeddings are learned from scratch.
            padding_idx (int): Reserved padding index in the embeddings matrix.
            n_positions (int): Size of the position embeddings matrix.
        """
        super().__init__()
        self.embedding_size = embedding_size
        self.ffn_size = ffn_size
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.dim = embedding_size
        self.embeddings_scale = embeddings_scale
        self.dropout = nn.Dropout(p=dropout)  # --dropout
        self.out_dim = embedding_size
        assert (
            embedding_size % n_heads == 0
        ), "Transformer embedding size must be a multiple of n_heads"

        self.embeddings = embedding

        # create the positional embeddings
        self.position_embeddings = nn.Embedding(n_positions, embedding_size)
        if not learn_positional_embeddings:
            create_position_codes(
                n_positions, embedding_size, out=self.position_embeddings.weight
            )
        else:
            nn.init.normal_(self.position_embeddings.weight, 0, embedding_size**-0.5)

        # build the model
        self.layers = nn.ModuleList()
        for _ in range(self.n_layers):
            self.layers.append(
                TransformerDecoderLayer(
                    n_heads,
                    embedding_size,
                    ffn_size,
                    attention_dropout=attention_dropout,
                    relu_dropout=relu_dropout,
                    dropout=dropout,
                )
            )

    def forward(self, input, encoder_state, incr_state=None):
        encoder_output, encoder_mask = encoder_state

        seq_len = input.size(1)
        positions = input.new(seq_len).long()
        positions = torch.arange(seq_len, out=positions).unsqueeze(0)
        tensor = self.embeddings(input)
        if self.embeddings_scale:
            tensor = tensor * np.sqrt(self.dim)
        tensor = tensor + self.position_embeddings(positions).expand_as(tensor)
        tensor = self.dropout(tensor)  # --dropout

        for layer in self.layers:
            tensor = layer(tensor, encoder_output, encoder_mask)

        return tensor, None
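

# Illustrative sketch (an assumption, not part of the original module): runs a tiny
# encoder/decoder pair on random token ids to show how the decoder consumes the
# (encoder_output, encoder_mask) tuple. Sizes are arbitrary.
def _demo_decoder():  # pragma: no cover
    emb = _create_embeddings(vocab_size=100, embedding_size=8, padding_idx=0)
    enc = TransformerEncoder(
        n_heads=2, n_layers=1, embedding_size=8, ffn_size=16,
        vocabulary_size=100, embedding=emb, reduction=False,
    )
    dec = TransformerDecoder(
        n_heads=2, n_layers=1, embedding_size=8, ffn_size=16,
        vocabulary_size=100, embedding=emb,
    )
    src = torch.randint(1, 100, (4, 5))  # context tokens
    tgt = torch.randint(1, 100, (4, 3))  # shifted target tokens
    latent, _ = dec(tgt, enc(src))  # latent: (4, 3, 8)
    return latent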


class TransformerDecoderLayer(nn.Module):

    def __init__(
        self,
        n_heads,
        embedding_size,
        ffn_size,
        attention_dropout=0.0,
        relu_dropout=0.0,
        dropout=0.0,
    ):
        super().__init__()
        self.dim = embedding_size
        self.ffn_dim = ffn_size
        self.dropout = nn.Dropout(p=dropout)

        self.self_attention = MultiHeadAttention(
            n_heads, embedding_size, dropout=attention_dropout
        )
        self.norm1 = nn.LayerNorm(embedding_size)

        self.encoder_attention = MultiHeadAttention(
            n_heads, embedding_size, dropout=attention_dropout
        )
        self.norm2 = nn.LayerNorm(embedding_size)

        self.ffn = TransformerFFN(embedding_size, ffn_size, relu_dropout=relu_dropout)
        self.norm3 = nn.LayerNorm(embedding_size)

    def forward(self, x, encoder_output, encoder_mask):
        decoder_mask = self._create_selfattn_mask(x)
        # first self attn
        residual = x
        # don't peek into the future!
        x = self.self_attention(query=x, mask=decoder_mask)
        x = self.dropout(x)  # --dropout
        x = x + residual
        x = _normalize(x, self.norm1)

        residual = x
        x = self.encoder_attention(
            query=x, key=encoder_output, value=encoder_output, mask=encoder_mask
        )
        x = self.dropout(x)  # --dropout
        x = residual + x
        x = _normalize(x, self.norm2)

        # finally the ffn
        residual = x
        x = self.ffn(x)
        x = self.dropout(x)  # --dropout
        x = residual + x
        x = _normalize(x, self.norm3)

        return x

    def _create_selfattn_mask(self, x):
        # figure out how many timesteps we need
        bsz = x.size(0)
        time = x.size(1)
        # make sure that we don't look into the future
        mask = torch.tril(x.new(time, time).fill_(1))
        # broadcast across batch
        mask = mask.unsqueeze(0).expand(bsz, -1, -1)
        return mask
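

# Illustrative sketch (an assumption, not part of the original module): shows the
# lower-triangular causal mask produced by _create_selfattn_mask, which prevents
# each position from attending to later positions.
def _demo_selfattn_mask():  # pragma: no cover
    layer = TransformerDecoderLayer(n_heads=2, embedding_size=8, ffn_size=16)
    x = torch.randn(4, 3, 8)
    mask = layer._create_selfattn_mask(x)  # (4, 3, 3), ones on and below the diagonal
    return mask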


class TransformerGeneratorModel(TorchGeneratorModel):

    def __init__(self, config, kbrd_rec):
        self.pad_idx = config.pad_idx
        self.start_idx = config.start_idx
        self.end_idx = config.end_idx
        super().__init__(self.pad_idx, self.start_idx, self.end_idx)
        self.embeddings = _create_embeddings(
            config.vocab_size, config.gen_dim, self.pad_idx
        )
        self.encoder = _build_encoder(
            config,
            self.embeddings,
            self.pad_idx,
            reduction=False,
        )
        self.decoder = _build_decoder(
            config,
            self.embeddings,
            self.pad_idx,
        )

        self.user_representation_to_bias_1 = nn.Linear(config.rec_dim, 512)
        self.user_representation_to_bias_2 = nn.Linear(512, config.vocab_size)

        # the recommender module is frozen; only the generator is trained
        self.kbrd = kbrd_rec
        for param in self.kbrd.parameters():
            param.requires_grad = False

    def reorder_encoder_states(self, encoder_states, indices):
        enc, mask = encoder_states
        if not torch.is_tensor(indices):
            indices = torch.LongTensor(indices).to(enc.device)
        enc = torch.index_select(enc, 0, indices)
        mask = torch.index_select(mask, 0, indices)
        return enc, mask

    def reorder_decoder_incremental_state(self, incremental_state, inds):
        # no support for incremental decoding at this time
        return None

    def output(self, tensor):
        # project back to vocabulary
        output = F.linear(tensor, self.embeddings.weight)
        if hasattr(self, "user_representation"):
            up_bias = self.user_representation_to_bias_2(
                F.relu(self.user_representation_to_bias_1(self.user_representation))
            )
            # expand to the whole sequence
            up_bias = up_bias.unsqueeze(dim=1)
            output += up_bias
        return output


class BasicAttention(nn.Module):

    def __init__(self, dim=1, attn="cosine"):
        super().__init__()
        self.softmax = nn.Softmax(dim=dim)
        if attn == "cosine":
            self.cosine = nn.CosineSimilarity(dim=dim)
        self.attn = attn
        self.dim = dim

    def forward(self, xs, ys):
        if self.attn == "cosine":
            l1 = self.cosine(xs, ys).unsqueeze(self.dim - 1)
        else:
            l1 = torch.bmm(xs, ys.transpose(1, 2))
            if self.attn == "sqrt":
                d_k = ys.size(-1)
                l1 = l1 / math.sqrt(d_k)
        l2 = self.softmax(l1)
        lhs_emb = torch.bmm(l2, ys)
        # add back the query
        lhs_emb = lhs_emb.add(xs)
        return lhs_emb.squeeze(self.dim - 1), l2


class MultiHeadAttention(nn.Module):

    def __init__(self, n_heads, dim, dropout=0):
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads
        self.dim = dim

        self.attn_dropout = nn.Dropout(p=dropout)  # --attention-dropout
        self.q_lin = nn.Linear(dim, dim)
        self.k_lin = nn.Linear(dim, dim)
        self.v_lin = nn.Linear(dim, dim)
        # TODO: merge for the initialization step
        nn.init.xavier_normal_(self.q_lin.weight)
        nn.init.xavier_normal_(self.k_lin.weight)
        nn.init.xavier_normal_(self.v_lin.weight)
        # and set biases to 0
        self.out_lin = nn.Linear(dim, dim)
        nn.init.xavier_normal_(self.out_lin.weight)

    def forward(self, query, key=None, value=None, mask=None):
        # Input is [B, query_len, dim]
        # Mask is [B, key_len] (padding mask) or [B, query_len, key_len] (causal mask)
        batch_size, query_len, dim = query.size()
        assert (
            dim == self.dim
        ), f"Dimensions do not match: {dim} query vs {self.dim} configured"
        assert mask is not None, "Mask is None, please specify a mask"
        n_heads = self.n_heads
        dim_per_head = dim // n_heads
        scale = math.sqrt(dim_per_head)

        def prepare_head(tensor):
            # input is [batch_size, seq_len, n_heads * dim_per_head]
            # output is [batch_size * n_heads, seq_len, dim_per_head]
            bsz, seq_len, _ = tensor.size()
            tensor = tensor.view(batch_size, tensor.size(1), n_heads, dim_per_head)
            tensor = (
                tensor.transpose(1, 2)
                .contiguous()
                .view(batch_size * n_heads, seq_len, dim_per_head)
            )
            return tensor

        # q, k, v are the transformed values
        if key is None and value is None:
            # self attention
            key = value = query
        elif value is None:
            # key and value are the same, but query differs
            # self attention
            value = key
        _, key_len, dim = key.size()

        q = prepare_head(self.q_lin(query))
        k = prepare_head(self.k_lin(key))
        v = prepare_head(self.v_lin(value))

        dot_prod = q.div_(scale).bmm(k.transpose(1, 2))
        # [B * n_heads, query_len, key_len]
        attn_mask = (
            (mask == 0)
            .view(batch_size, 1, -1, key_len)
            .repeat(1, n_heads, 1, 1)
            .expand(batch_size, n_heads, query_len, key_len)
            .view(batch_size * n_heads, query_len, key_len)
        )
        assert attn_mask.shape == dot_prod.shape
        dot_prod.masked_fill_(
            attn_mask, -torch.finfo(dot_prod.dtype).max
        )  # TODO: previous implementation used neginf, check if they are equivalent

        attn_weights = F.softmax(dot_prod, dim=-1).type_as(query)
        attn_weights = self.attn_dropout(attn_weights)  # --attention-dropout

        attentioned = attn_weights.bmm(v)
        attentioned = (
            attentioned.type_as(query)
            .view(batch_size, n_heads, query_len, dim_per_head)
            .transpose(1, 2)
            .contiguous()
            .view(batch_size, query_len, dim)
        )

        out = self.out_lin(attentioned)

        return out
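

# Illustrative sketch (an assumption, not part of the original module): a quick shape
# check of MultiHeadAttention in the self-attention configuration, with an all-ones
# padding mask. Sizes are arbitrary.
def _demo_attention():  # pragma: no cover
    attn = MultiHeadAttention(n_heads=2, dim=8)
    x = torch.randn(4, 5, 8)  # (batch, seq_len, dim)
    mask = torch.ones(4, 5, dtype=torch.bool)  # no padding positions
    out = attn(x, mask=mask)  # (4, 5, 8)
    return out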


class TransformerFFN(nn.Module):

    def __init__(self, dim, dim_hidden, relu_dropout=0):
        super(TransformerFFN, self).__init__()
        self.relu_dropout = nn.Dropout(p=relu_dropout)
        self.lin1 = nn.Linear(dim, dim_hidden)
        self.lin2 = nn.Linear(dim_hidden, dim)
        nn.init.xavier_uniform_(self.lin1.weight)
        nn.init.xavier_uniform_(self.lin2.weight)

    def forward(self, x):
        x = F.relu(self.lin1(x))
        x = self.relu_dropout(x)  # --relu-dropout
        x = self.lin2(x)
        return x
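

# Illustrative sketch (an assumption, not part of the original module): the
# position-wise feed-forward block maps (batch, seq_len, dim) back to the same shape.
def _demo_ffn():  # pragma: no cover
    ffn = TransformerFFN(dim=8, dim_hidden=16)
    return ffn(torch.randn(4, 5, 8))  # (4, 5, 8)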