Source code for recwizard.modules.redial.hrnn

import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from transformers import AutoModel

from recwizard.utility import sort_for_packed_sequence
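
The helper below is not part of this module; it is a minimal sketch of what recwizard.utility.sort_for_packed_sequence presumably returns, inferred from how it is used in this file: the lengths sorted in decreasing order, the permutation that sorts them, and the inverse permutation used to restore the original order.

def _sort_for_packed_sequence_sketch(lengths: torch.Tensor):
    """Illustrative stand-in (assumption), not the actual recwizard.utility implementation."""
    # sort lengths in decreasing order, as required by pack_padded_sequence
    sorted_lengths, sorted_idx = lengths.sort(descending=True)
    # inverse permutation: index_select(0, rev) restores the original order
    rev = sorted_idx.argsort()
    return sorted_lengths, sorted_idx, rev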


class HRNN(nn.Module):
    """
    Hierarchical Recurrent Neural Network.

    params.keys():
        ['use_gensen', 'use_movie_occurrences', 'sentence_encoder_hidden_size',
         'conversation_encoder_hidden_size', 'sentence_encoder_num_layers',
         'conversation_encoder_num_layers', 'use_dropout', 'embedding_dimension']

    Input:
        Input["dialogue"]: (batch, max_conv_length, max_utterance_length) Long Tensor
        Input["senders"]: (batch, max_conv_length) Float Tensor
        Input["lengths"]: (batch, max_conv_length) list
        Input["movie_occurrences"] (optional): (batch, max_conv_length, max_utterance_length)
            for word occurrence, (batch, max_conv_length) for sentence occurrence. Float Tensor
    """
    # LOAD_SAVE_IGNORES = ('base_encoder',)

    def __init__(self,
                 sentence_encoder_model,
                 sentence_encoder_hidden_size,
                 sentence_encoder_num_layers,
                 conversation_encoder_hidden_size,
                 conversation_encoder_num_layers,
                 use_movie_occurrences,
                 conv_bidirectional=False,
                 return_all=True,
                 return_sentence_representations=False,
                 use_dropout=False,
                 ):
        super().__init__()
        self.conv_bidirectional = conv_bidirectional
        self.return_all = return_all
        self.return_sentence_representations = return_sentence_representations
        self.base_encoder = AutoModel.from_pretrained(sentence_encoder_model)
        self.sentence_encoder_hidden_size = sentence_encoder_hidden_size
        self.sentence_encoder = nn.GRU(
            input_size=self.base_encoder.config.hidden_size + (use_movie_occurrences == "word"),
            hidden_size=sentence_encoder_hidden_size,
            num_layers=sentence_encoder_num_layers,
            batch_first=True,
            bidirectional=True
        )
        self.conversation_encoder = nn.GRU(
            # concatenation of the two directions of the sentence encoder
            # + sender information + movie occurrences
            input_size=2 * sentence_encoder_hidden_size + 1 + (use_movie_occurrences == "sentence"),
            hidden_size=conversation_encoder_hidden_size,
            num_layers=conversation_encoder_num_layers,
            batch_first=True,
            bidirectional=conv_bidirectional
        )
        self.use_movie_occurrences = use_movie_occurrences
        self.dropout = nn.Dropout(p=use_dropout) if use_dropout else None
    def get_sentence_representations(self, input_ids, attention_mask, senders, movie_occurrences=None):
        # lengths of the tokenized utterances
        lengths = attention_mask.sum(dim=-1)
        batch_size, max_conversation_length = len(lengths), len(lengths[0])

        # order by decreasing utterance length
        lengths = lengths.reshape((-1))
        sorted_lengths, sorted_idx, rev = sort_for_packed_sequence(lengths)

        # consider sequences of length > 0 only
        num_positive_lengths = torch.sum(lengths > 0)
        sorted_lengths = sorted_lengths[:num_positive_lengths]
        sorted_idx = sorted_idx[:num_positive_lengths]

        # reshape and reorder the batch encoding by utterance length
        batch_encoding = {
            'input_ids': input_ids.view(batch_size * max_conversation_length, -1).index_select(0, sorted_idx),
            'attention_mask': attention_mask.view(batch_size * max_conversation_length, -1).index_select(0, sorted_idx)
        }
        with torch.no_grad():
            embedded = self.base_encoder(**batch_encoding, output_hidden_states=True,
                                         return_dict=True).last_hidden_state
        # (<= batch_size * max_conversation_length, max_sentence_length, embedding_size (2048 for GenSen))
        if self.dropout:
            embedded = self.dropout(embedded)
        if self.use_movie_occurrences == "word":
            if movie_occurrences is None:
                raise ValueError("Please specify movie occurrences")
            # reshape and reorder movie occurrences by utterance length,
            # keeping only the indices where sequence_length > 0
            movie_occurrences = movie_occurrences.view(
                batch_size * max_conversation_length, -1).index_select(0, sorted_idx)
            embedded = torch.cat((embedded, movie_occurrences.unsqueeze(2)), 2)

        packed_sentences = pack_padded_sequence(embedded, sorted_lengths.cpu(), batch_first=True)
        # Apply the encoder and get the final hidden states
        _, sentence_representations = self.sentence_encoder(packed_sentences)
        # (2 * num_layers, <= batch_size * max_conv_length, hidden_size)

        # Concatenate the hidden states of the last layer (the two directions of the GRU)
        sentence_representations = torch.cat((sentence_representations[-1], sentence_representations[-2]), 1)
        # (num_positive_lengths, 2 * hidden_size)
        if self.dropout:
            sentence_representations = self.dropout(sentence_representations)

        # Complete the missing sequences (of length 0)
        if num_positive_lengths < batch_size * max_conversation_length:
            pad_tensor = torch.zeros(
                batch_size * max_conversation_length - num_positive_lengths,
                2 * self.sentence_encoder_hidden_size,
                device=sentence_representations.device
            )
            sentence_representations = torch.cat((sentence_representations, pad_tensor), 0)
        # (batch_size * max_conversation_length, 2 * hidden_size)

        # Retrieve the original sentence order and reshape to separate conversations
        sentence_representations = sentence_representations.index_select(0, rev).view(
            batch_size,
            max_conversation_length,
            2 * self.sentence_encoder_hidden_size
        )
        # Append sender information
        sentence_representations = torch.cat([sentence_representations, senders.unsqueeze(2)], 2)
        # Append movie occurrence information if required
        if self.use_movie_occurrences == "sentence":
            if movie_occurrences is None:
                raise ValueError("Please specify movie occurrences")
            sentence_representations = torch.cat(
                (sentence_representations, torch.sign(torch.sum(movie_occurrences, dim=-1, keepdim=True))), 2)
        # (batch_size, max_conv_length, 2 * hidden_size + 1 + use_movie_occurrences)
        return sentence_representations
    def forward(self, input_ids, attention_mask, senders, movie_occurrences, conversation_lengths, **kwargs):
        # get sentence representations
        sentence_representations = self.get_sentence_representations(
            input_ids, attention_mask, senders, movie_occurrences)
        # (batch_size, max_conv_length, 2 * sent_hidden_size + 1 + use_movie_occurrences)

        # Pass the whole conversation into the conversation GRU
        sorted_lengths, sorted_idx, rev = sort_for_packed_sequence(conversation_lengths)
        # reorder by decreasing conversation length
        sorted_representations = sentence_representations.index_select(0, sorted_idx)
        packed_sequences = pack_padded_sequence(sorted_representations, sorted_lengths.cpu(), batch_first=True)

        conversation_representations, last_state = self.conversation_encoder(packed_sequences)
        # retrieve the original order
        conversation_representations, _ = pad_packed_sequence(conversation_representations, batch_first=True)
        conversation_representations = conversation_representations.index_select(0, rev)
        # last_state: (num_layers * num_directions, batch, conv_hidden_size)
        last_state = last_state.index_select(1, rev)
        if self.dropout:
            conversation_representations = self.dropout(conversation_representations)
            last_state = self.dropout(last_state)

        res = {}
        if self.return_all:
            # return the last layer of the GRU for each time step
            # (batch_size, max_conv_length, hidden_size * num_directions)
            res["conv_repr"] = conversation_representations
            if self.return_sentence_representations:
                # also return the sentence representations
                res["sent_repr"] = sentence_representations
        else:
            # return the last hidden state only
            if self.conv_bidirectional:
                # Concatenate the hidden states of the last layer (the two directions of the GRU)
                # (batch_size, num_directions * hidden_size)
                res["conv_repr"] = torch.cat((last_state[-1], last_state[-2]), 1)
            else:
                # Return the hidden state of the last layer
                res["conv_repr"] = last_state[-1]
        return res
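
A minimal usage sketch (not from the RecWizard source): the "bert-base-uncased" checkpoint, the hidden sizes, and the dummy tensors below are illustrative assumptions chosen only to show the expected input shapes and the shape of the returned conversation representations.

if __name__ == "__main__":
    model = HRNN(
        sentence_encoder_model="bert-base-uncased",   # assumption: any Hugging Face encoder checkpoint
        sentence_encoder_hidden_size=256,
        sentence_encoder_num_layers=1,
        conversation_encoder_hidden_size=256,
        conversation_encoder_num_layers=1,
        use_movie_occurrences="sentence",
    )
    model.eval()

    batch_size, max_conv_len, max_utt_len = 2, 3, 10
    input_ids = torch.randint(1, 1000, (batch_size, max_conv_len, max_utt_len))          # dummy token ids
    attention_mask = torch.ones(batch_size, max_conv_len, max_utt_len, dtype=torch.long)  # all tokens valid
    senders = torch.ones(batch_size, max_conv_len)                                        # speaker flags
    movie_occurrences = torch.zeros(batch_size, max_conv_len, max_utt_len)                # no movie mentions
    conversation_lengths = torch.tensor([3, 2])

    out = model(input_ids, attention_mask, senders, movie_occurrences, conversation_lengths)
    # with return_all=True (the default), "conv_repr" is
    # (batch_size, max_conv_length, conversation_encoder_hidden_size)
    print(out["conv_repr"].shape)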