import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from recwizard import BaseModule
from recwizard.modules.monitor import monitor
from recwizard.utility import sort_for_packed_sequence, DeviceManager
from .beam_search import BeamSearch, Beam, get_best_beam
from .configuration_redial_gen import RedialGenConfig
from .tokenizer_redial_gen import RedialGenTokenizer
from .hrnn import HRNN


class RedialGen(BaseModule):
    """
    ReDial generation module: an HRNN conversation encoder produces a context
    representation that conditions a switching decoder over words and movies.
    """
    config_class = RedialGenConfig
    tokenizer_class = RedialGenTokenizer
    LOAD_SAVE_IGNORES = {"encoder.base_encoder", }

    def __init__(self, config: RedialGenConfig, word_embedding=None, **kwargs):
        super().__init__(config, **kwargs)
        self.encoder = HRNN(**config.hrnn_params)
        word_embedding = nn.Embedding.from_pretrained(self.prepare_weight(word_embedding, 'decoder.embedding'))
        DeviceManager.device = self.device
        self.decoder = SwitchingDecoder(word_embedding=word_embedding, **config.decoder_params).to(self.device)
        self.n_movies = config.n_movies

    def forward(self, dialogue, lengths, recs, context=None, state=None, encode_only=False, **hrnn_input):
        """
        Args:
            dialogue: (batch_size, max_conv_length, max_utterance_length) token ids
            lengths: (batch_size, max_conv_length) utterance lengths
            recs: (batch_size, max_conv_length, n_movies) recommendation scores per turn
            context: optional precomputed conversation representation; if given, the encoder is skipped
            state: decoder state (unused in this forward pass)
            encode_only: if True, only return the conversation representation
            **hrnn_input: tensors forwarded to the HRNN conversation encoder
        Returns:
            log-probabilities (batch_size, max_conv_length, max_utterance_length, vocab + n_movies)
        """
        batch_size, max_conv_length = dialogue.shape[:2]
        # only run the encoder when no context is provided
        conv_repr = context if context is not None else self.encoder(**hrnn_input)['conv_repr']  # (batch, max_conv_length, hidden_size)
        if encode_only:
            return conv_repr
        sorted_utterances, sorted_lengths, conv_repr, recs, rev, num_positive_lengths = \
            self.prepare_input_for_decoder(dialogue, lengths, conv_repr, recs)
        # Run decoder
        outputs = self.decoder(
            sorted_utterances,
            sorted_lengths,
            conv_repr,
            recs,
            log_probabilities=True,
            sample_movies=False
        )
        max_utterance_length = outputs.shape[1]
        # Complete the missing sequences (of length 0)
        if num_positive_lengths < batch_size * max_conv_length:
            pad_tensor = torch.zeros(
                batch_size * max_conv_length - num_positive_lengths,
                max_utterance_length,
                outputs.shape[-1],
                device=self.device
            )
            outputs = torch.cat((outputs, pad_tensor), dim=0)
        # retrieve the original order
        outputs = outputs.index_select(0, rev).view(
            batch_size, max_conv_length, max_utterance_length, -1)
        # (batch, max_conv_length, max_utterance_length, vocab + n_movies)
        return outputs

    @monitor
    def response(self, raw_input, tokenizer, recs=None, return_dict=False, beam_size=10, forbid_movies=None,
                 temperature=1, max_seq_length=40, **kwargs):
        gen_input = DeviceManager.copy_to_device(tokenizer([raw_input]).data, device=self.device)
        context = self.forward(**gen_input, recs=recs, encode_only=True)[0]
        if recs is None:
            recs = torch.zeros(1, self.n_movies, device=self.device)
        mentioned_movies = forbid_movies if forbid_movies is not None else set()
        inputs = {
            "initial_sequence": [],
            "forbid_movies": mentioned_movies,
            "beam_size": beam_size,
            "max_seq_length": max_seq_length,
            "movie_recommendations": recs[-1].unsqueeze(0),
            "context": context[-1].unsqueeze(0),
            "sample_movies": False,
            "temperature": temperature
        }
        best_beam = self.decoder.generate(**inputs, tokenizer=tokenizer)
        mentioned_movies.update(best_beam.mentioned_movies)
        output = tokenizer.decode(best_beam.sequence, skip_special_tokens=True)
        if return_dict:
            return {
                "output": output,
                # these movie ids can be passed as forbid_movies in follow-up queries
                "movieIds": list(mentioned_movies)
            }
        return output
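
    # A minimal usage sketch. Assumptions: a published checkpoint named
    # "recwizard/redial-gen" with a matching tokenizer exists; the names are
    # illustrative only, not guaranteed by this module.
    #
    #   module = RedialGen.from_pretrained("recwizard/redial-gen")
    #   tokenizer = RedialGenTokenizer.from_pretrained("recwizard/redial-gen")
    #   reply = module.response("User: Can you recommend a movie?", tokenizer)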

    def prepare_input_for_decoder(self, dialogue, lengths, conv_repr, recs):
        batch_size, max_conv_length = dialogue.shape[:2]
        utterances = dialogue.view(batch_size * max_conv_length, -1)
        # order by descending utterance length
        lengths = lengths.reshape(-1)
        sorted_lengths, sorted_idx, rev = sort_for_packed_sequence(lengths)
        sorted_utterances = utterances.index_select(0, sorted_idx)
        # shift the context vectors one step in time
        pad_tensor = torch.zeros(
            batch_size, 1, int(self.config.hrnn_params['conversation_encoder_hidden_size']), device=self.device)
        conv_repr = torch.cat((pad_tensor, conv_repr), 1).narrow(1, 0, max_conv_length)
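        # e.g. with max_conv_length = 3, the per-turn contexts [c1, c2, c3]
        # become [0, c1, c2]: utterance t is conditioned on the dialogue state
        # *before* it was produced (the same shift is applied to recs below)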
        # and reshape + reorder the same way as the utterances
        conv_repr = conv_repr.contiguous().view(
            batch_size * max_conv_length,
            self.config.hrnn_params['conversation_encoder_hidden_size']
        ).index_select(0, sorted_idx)
        # shift the movie recommendations one step in time
        pad_tensor = torch.zeros(batch_size, 1, self.n_movies, device=self.device)
        recs = torch.cat((pad_tensor, recs), 1).narrow(1, 0, max_conv_length)
        # and reshape + reorder the same way as the utterances
        recs = recs.contiguous().view(batch_size * max_conv_length, -1).index_select(0, sorted_idx)
        # keep only the utterances of length > 0
        num_positive_lengths = torch.sum(lengths > 0)
        sorted_utterances = sorted_utterances[:num_positive_lengths]
        sorted_lengths = sorted_lengths[:num_positive_lengths]
        conv_repr = conv_repr[:num_positive_lengths]
        recs = recs[:num_positive_lengths]
        return sorted_utterances, sorted_lengths, conv_repr, recs, rev, num_positive_lengths


class DecoderGRU(nn.Module):
    """
    Conditioned GRU. The context vector is used as the initial hidden state at each layer of the GRU.
    """

    def __init__(self, hidden_size, context_size, num_layers, word_embedding, peephole):
        super().__init__()
        self.hidden_size = hidden_size
        self.context_size = context_size
        self.num_layers = num_layers
        # peephole: concatenate the context to the input at every time step
        self.peephole = peephole
        if not peephole and context_size != hidden_size:
            raise ValueError("peephole=False: the context size {} must match the hidden size {} in DecoderGRU".format(
                context_size, hidden_size))
        self.embedding = word_embedding
        self.gru = nn.GRU(
            input_size=self.embedding.embedding_dim + context_size * self.peephole,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True
        )

    def set_pretrained_embeddings(self, embedding_matrix):
        """Set embedding weights."""
        self.embedding.weight.data.copy_(torch.from_numpy(embedding_matrix))

    def forward(self, input_sequence, lengths, context=None, state=None):
        """
        If not peephole, use the context vector as the initial hidden state at each layer.
        If peephole, concatenate the context to the embeddings at each time step instead.
        If no context is provided, assume that a state is given (for generation).
        Args:
            input_sequence: (batch_size, seq_len)
            lengths: (batch_size)
            context: (batch, hidden_size) vector on which to condition
            state: (num_layers, batch, hidden_size) GRU state
        Returns:
            output predictions (batch_size, seq_len, hidden_size) [, h_n (num_layers, batch, hidden_size)]
        """
        embedded = self.embedding(input_sequence)
        if context is not None:
            batch_size, context_size = context.data.shape
            seq_len = input_sequence.data.shape[1]
            if self.peephole:
                context_for_input = context.unsqueeze(1).expand(batch_size, seq_len, context_size)
                embedded = torch.cat((embedded, context_for_input), dim=2)
            packed = pack_padded_sequence(embedded, lengths.cpu(), batch_first=True)
            if not self.peephole:
                # no peephole: use the context as the initial hidden state,
                # expanded to the number of layers in the decoder
                context = context.unsqueeze(0).expand(
                    self.num_layers, batch_size, self.hidden_size).contiguous()
                output, _ = self.gru(packed, context)
            else:
                output, _ = self.gru(packed)
            return pad_packed_sequence(output, batch_first=True)[0]
        elif state is not None:
            output, h_n = self.gru(embedded, state)
            return output, h_n
        else:
            raise ValueError("Must provide at least state or context")


class SwitchingDecoder(nn.Module):
    """
    Decoder that takes the recommendations into account.
    A switch chooses whether to output a movie or a word.
    """

    def __init__(self,
                 hidden_size,
                 context_size,
                 num_layers,
                 peephole,
                 word_embedding=None,
                 ):
        super().__init__()
        self.hidden_size = hidden_size
        self.context_size = context_size
        self.num_layers = num_layers
        self.peephole = peephole
        self.vocab_size = word_embedding.num_embeddings
        self.decoder = DecoderGRU(
            hidden_size=self.hidden_size,
            context_size=self.context_size,
            num_layers=self.num_layers,
            word_embedding=word_embedding,
            peephole=self.peephole
        )
        self.language_out = nn.Linear(in_features=self.hidden_size, out_features=self.vocab_size)
        self.switch = nn.Linear(in_features=self.hidden_size + self.context_size, out_features=1)
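        # the switch produces one logit per time step: sigmoid(switch) weights the
        # word distribution and sigmoid(-switch) the movie distribution (see forward)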

    def set_pretrained_embeddings(self, embedding_matrix):
        """Set embedding weights."""
        self.decoder.set_pretrained_embeddings(embedding_matrix)

    def forward(self, input, lengths, context, movie_recommendations, log_probabilities, sample_movies,
                forbid_movies=None, temperature=1):
        """
        Args:
            input: (batch, max_utterance_length) token ids
            lengths: (batch) utterance lengths
            context: (batch, hidden_size) conversation representation
            movie_recommendations: (batch, n_movies) movie recommendations that condition the utterances;
                not necessarily in the [0, 1] range
            log_probabilities: if True, return log-probabilities instead of probabilities
            sample_movies: (for generation) if True, sample a movie for each utterance, returning one-hot vectors
            forbid_movies: (for generation) if provided, specifies movies that cannot be sampled
            temperature: softmax temperature applied to both output heads
        Returns:
            [log-]probabilities (batch, max_utterance_length, vocab + n_movies)
        """
        batch_size, max_utterance_length = input.data.shape[:2]
        # run the language decoder
        # (batch, seq_len, hidden_size)
        hidden = self.decoder(input, lengths, context=context)
        # (batch, seq_len, vocab_size)
        language_output = self.language_out(hidden)
        # used in sampling
        max_probabilities, _ = torch.max(F.softmax(movie_recommendations, dim=1), dim=1)
        # expand context and movie_recommendations to each time step
        context = context.unsqueeze(1).expand(
            batch_size, max_utterance_length, self.context_size).contiguous()
        n_movies = movie_recommendations.data.shape[1]
        movie_recommendations = movie_recommendations.unsqueeze(1).expand(
            batch_size, max_utterance_length, n_movies).contiguous()
        # compute the switch
        # (batch, seq_len, context_size + hidden_size)
        switch_input = torch.cat((context, hidden), dim=2)
        switch = self.switch(switch_input)
        # for generation: sample movies
        # (batch, seq_len, vocab_size + n_movies)
        if sample_movies:
            # prepare the probability vector for sampling
            movie_probabilities = F.softmax(movie_recommendations.view(-1, n_movies), dim=1)
            # zero out the forbidden movies
            if forbid_movies is not None:
                if batch_size > 1:
                    raise ValueError("forbid_movies functionality only implemented with batch_size=1 for now")
                for movieId in forbid_movies:
                    movie_probabilities[:, movieId] = 0
            # sample movies
            sampled_movies = movie_probabilities.multinomial(1).view(
                batch_size, max_utterance_length).data.cpu().numpy()
            # fill a new recommendations vector with the sampled movies
            movie_recommendations = torch.zeros(batch_size, max_utterance_length, n_movies,
                                                device=DeviceManager.device)
            for i in range(batch_size):
                for j in range(max_utterance_length):
                    # compensate the bias towards the sampled movie by using the maximum
                    # probability over movies instead of 1
                    movie_recommendations[i, j, sampled_movies[i, j]] = max_probabilities[i]
            if log_probabilities:
                raise ValueError("Sample movies only works with log_probabilities=False for now.")
            else:
                output = torch.cat((
                    switch.sigmoid() * F.softmax(language_output / temperature, dim=2),
                    (-switch).sigmoid() * movie_recommendations
                ), dim=2)
            return output
        if log_probabilities:
            output = torch.cat((
                F.logsigmoid(switch) + F.log_softmax(language_output / temperature, dim=2),
                F.logsigmoid(-switch) + F.log_softmax(movie_recommendations / temperature, dim=2)
            ), dim=2)
        else:
            output = torch.cat((
                switch.sigmoid() * F.softmax(language_output / temperature, dim=2),
                (-switch).sigmoid() * F.softmax(movie_recommendations / temperature, dim=2)
            ), dim=2)
        return output
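
    # Note on forward's output layout: indices [0, vocab_size) hold the word
    # probabilities and [vocab_size, vocab_size + n_movies) the movie
    # probabilities. In the softmax branches the two heads jointly sum to one,
    # since sigmoid(-switch) = 1 - sigmoid(switch).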

    def replace_movie_with_words(self, tokens, tokenizer):
        """
        Replace ids that correspond to movies with the sequences of word tokens
        that make up the movie names.
        Args:
            tokens: list of token ids
            tokenizer: tokenizer used to encode the movie names
        Returns:
            the modified sequence
        """
        res = []
        for token_id in tokens:
            if token_id <= self.vocab_size:
                res.append(token_id)
            else:
                movie_name = tokenizer.decode([token_id])
                res.extend(tokenizer.tokenizers[1].encode(movie_name, add_special_tokens=False))
        return res
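
    # Hypothetical illustration (ids made up for the example): with
    # vocab_size = 15000, a movie token id 15042 that decodes to "The Matrix"
    # would be replaced by the word-token ids for "The" and "Matrix".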

    def generate(self, initial_sequence=None, tokenizer=None, beam_size=10, max_seq_length=50, temperature=1,
                 forbid_movies=None, **kwargs):
        """
        Beam search sentence generation.
        Args:
            initial_sequence: list giving the initial sequence of tokens
            tokenizer: tokenizer used for the special tokens and to decode movie names
            beam_size: number of beams kept at each step
            max_seq_length: maximum number of generation steps
            temperature: softmax temperature passed to the forward pass
            forbid_movies: set of movie ids that cannot be generated
            kwargs: additional parameters to pass to the model forward pass (e.g. a conditioning context)
        Returns:
            The best beam
        """
        if initial_sequence is None:
            initial_sequence = []
        if forbid_movies is None:
            # guard against the default: forbid_movies is unioned with each
            # beam's mentioned movies below, so it must be a set
            forbid_movies = set()
        initial_sequence = [tokenizer.bos_token_id] + initial_sequence
        Beam.end_token = tokenizer.eos_token_id
        beams = BeamSearch.initial_beams(initial_sequence)
        for i in range(max_seq_length):
            # compute the next-token probabilities for each beam
            probabilities = []
            for beam in beams:
                # add a batch dimension
                model_input = torch.tensor(beam.sequence, device=DeviceManager.device, dtype=torch.long).unsqueeze(0)
                beam_forbidden_movies = forbid_movies.union(beam.mentioned_movies)
                prob = self.forward(
                    input=model_input,
                    lengths=torch.tensor([len(beam.sequence)], device=DeviceManager.device),
                    log_probabilities=False,
                    forbid_movies=beam_forbidden_movies,
                    temperature=temperature,
                    **kwargs
                )
                # probabilities for the next token to generate
                probabilities.append(prob[0, -1, :].cpu())
            # update the beams
            beams = BeamSearch.update_beams(beams, beam_size, probabilities)
            # replace movie ids with the corresponding words
            for beam in beams:
                if beam.sequence[-1] > self.vocab_size:
                    # record the mentioned movie to prevent repeated recommendations
                    beam.mentioned_movies.add(beam.sequence[-1] - self.vocab_size)
                    beam.sequence[-1:] = self.replace_movie_with_words(beam.sequence[-1:], tokenizer)
        return get_best_beam(beams)
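
    # A minimal sketch of calling generate directly (tensors are illustrative;
    # in practice RedialGen.response assembles these inputs):
    #
    #   best = decoder.generate(
    #       tokenizer=tokenizer,
    #       context=context,                  # (1, hidden_size)
    #       movie_recommendations=recs,       # (1, n_movies)
    #       sample_movies=False,
    #       beam_size=10,
    #   )
    #   text = tokenizer.decode(best.sequence, skip_special_tokens=True)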