Source code for recwizard.pipelines.switch_decode.modeling_switch_decode

import re

import torch
from torch import nn as nn
from torch.nn import functional as F
from recwizard.model_utils import BasePipeline
from .configuration_switch_decode import SwitchDecodeConfig
from recwizard.modules.redial import Beam, BeamSearch, get_best_beam
from ...utility import SEP_TOKEN


class SwitchDecodePipeline(BasePipeline):
    config_class = SwitchDecodeConfig

    def __init__(self, config, temperature=1, sample_movies=False, **kwargs):
        super().__init__(config, **kwargs)
        self.switch = nn.Linear(in_features=self.config.hidden_size + self.config.context_size,
                                out_features=1)
        self.vocab_size = len(self.gen_tokenizer)
        self.context_size = self.gen_module.decoder.context_size
        self.sample_movies = sample_movies
        self.t = temperature
        self.context = None
        self.recs = None

    def switch_decode(self, switch_context, language_output, movie_recommendations, temperature,
                      forbid_movies=None):
        batch_size, max_utterance_length = language_output.shape[:2]
        n_movies = movie_recommendations.data.shape[1]
        max_probabilities, _ = torch.max(F.softmax(movie_recommendations, dim=1), dim=1)
        movie_recommendations = movie_recommendations.unsqueeze(1).expand(
            batch_size, max_utterance_length, n_movies).contiguous()
        switch = self.switch(switch_context)
        # For generation: sample movies
        # (batch, seq_len, vocab_size + n_movies)
        if self.sample_movies:
            # Prepare probability vector for sampling
            movie_probabilities = F.softmax(movie_recommendations.view(-1, n_movies), dim=1)
            # zero out the forbidden movies
            if forbid_movies is not None:
                if batch_size > 1:
                    raise ValueError("forbid_movies functionality only implemented with batch_size=1 for now")
                for movieId in forbid_movies:
                    movie_probabilities[:, movieId] = 0
            # Sample movies
            sampled_movies = movie_probabilities.multinomial(1).view(
                batch_size, max_utterance_length).data.cpu().numpy()
            # Fill a new recommendations vector with sampled movies
            movie_recommendations = torch.zeros(batch_size, max_utterance_length, n_movies,
                                                device=self.device)
            for i in range(batch_size):
                for j in range(max_utterance_length):
                    # compensate bias towards sampled movie by putting the maximum probability
                    # of a movie instead of 1
                    movie_recommendations[i, j, sampled_movies[i, j]] = max_probabilities[i]
            output = torch.cat((
                switch.sigmoid() * F.softmax(language_output / temperature, dim=-1),
                (-switch).sigmoid() * movie_recommendations
            ), dim=2)
        else:
            output = torch.cat((
                F.logsigmoid(switch) + F.log_softmax(language_output / temperature, dim=-1),
                F.logsigmoid(-switch) + F.log_softmax(movie_recommendations / temperature, dim=-1)
            ), dim=2)
        return output
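
    # Reference note (not executed): the gating math above, for switch logit s,
    # language logits l, recommendation scores r, and temperature T. The
    # non-sampling branch returns log-probabilities over the concatenated
    # (vocab_size + n_movies) output space:
    #
    #   output = [ log sigmoid(s)  + log_softmax(l / T),
    #              log sigmoid(-s) + log_softmax(r / T) ]
    #
    # i.e. a two-way mixture where P(token) = sigmoid(s) * P_lang(token) for word
    # tokens and sigmoid(-s) * P_rec(movie) for movie tokens, computed in log space.
    # Since sigmoid(s) + sigmoid(-s) = 1, the concatenation is a single valid
    # distribution over words and movies.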

    def forward(self, rec_inputs, gen_inputs, forbid_movies):
        batch_size, max_conv_length, max_utterance_length = gen_inputs['dialogue'].shape[:3]
        # (batch, seq_len, n_movies)
        recs = self.recs if self.recs is not None else self.rec_module.forward(**rec_inputs)
        pad_tensor = torch.zeros(batch_size, 1, recs.shape[-1], device=recs.device)
        recs = torch.cat((pad_tensor, recs), 1)[:, -2, :]
        context_, hidden, language_output = self.gen_module.forward(**gen_inputs, context=self.context)
        context = self.context if self.context is not None else context_
        # merge batch_size and max_conv_length into one dimension for switch
        context = context.expand(batch_size * max_conv_length, max_utterance_length,
                                 self.context_size).contiguous()
        hidden = hidden.view(batch_size * max_conv_length, max_utterance_length, -1)
        switch_input = torch.cat((context, hidden), dim=-1)
        # FIXME: the conv_length dimension is not unrolled
        language_output = language_output.view(batch_size * max_conv_length, max_utterance_length, -1)
        output = self.switch_decode(switch_input, language_output, recs,
                                    temperature=self.t, forbid_movies=forbid_movies)
        return output.view(batch_size, max_conv_length, max_utterance_length, -1)

    def replace_movie_with_words(self, tokens):
        """
        If an ID corresponds to a movie, replace it with the sequence of word tokens
        that spell out the movie name.

        Args:
            tokens: sequence of token IDs; IDs above the vocabulary size denote movies

        Returns:
            the modified sequence
        """
        res = []
        for token_id in tokens:
            if token_id <= self.vocab_size:
                res.append(token_id)
            else:
                movie_name = self.rec_tokenizer.decode(token_id - self.vocab_size)
                res.extend(self.gen_tokenizer.encode(movie_name, add_special_tokens=False))
        return res
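
    # Hypothetical illustration (IDs are made up, not from a real vocabulary):
    # with vocab_size = 15000, replace_movie_with_words([42, 15321, 7]) keeps 42
    # and 7 as word tokens, decodes movie ID 321 with rec_tokenizer (say, to
    # "The Matrix"), and splices the gen_tokenizer word-token encoding of that
    # title into the returned sequence.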

    def response(self, query: str, beam_size=10, forbid_movies=None, temperature=1, **kwargs):
        if forbid_movies is None:
            forbid_movies = set()
        self.t = temperature
        initial_sequence = [self.gen_tokenizer.tokenizers[1].bos_token_id]
        Beam.end_token = self.gen_tokenizer.tokenizers[1].eos_token_id
        beams = BeamSearch.initial_beams(initial_sequence)
        max_conv_length = len(re.findall(SEP_TOKEN, query)) + 1
        conv_repr = self.gen_module.response(query, self.gen_tokenizer, encoder_only=True)
        # TODO: if the first utterance is the recommender's, add a 0-context at the beginning
        self.context = conv_repr[:, -1:, :]
        self.recs = self.rec_module.response(query, self.rec_tokenizer)
        for i in range(self.config.max_seq_length):
            # compute probabilities for each beam
            probabilities = []
            for beam in beams:
                # add batch and conversation-length dimensions
                gen_inputs = {
                    'dialogue': torch.tensor(beam.sequence, device=self.device,
                                             dtype=torch.long).unsqueeze(0).unsqueeze(0),
                    'lengths': torch.tensor([len(beam.sequence)]).unsqueeze(0),
                }
                beam_forbidden_movies = forbid_movies.union(beam.mentioned_movies)
                prob = self.forward(
                    rec_inputs=None,
                    gen_inputs=gen_inputs,
                    forbid_movies=beam_forbidden_movies,
                )
                # get probabilities for the next token to generate
                probabilities.append(prob[0, 0, -1, :].cpu())
            # update beams
            beams = BeamSearch.update_beams(beams, beam_size, probabilities)
            # replace movie IDs with the corresponding words
            for beam in beams:
                if beam.sequence[-1] > self.vocab_size:
                    # update the list of movies mentioned for preventing repeated recommendations
                    beam.mentioned_movies.add(beam.sequence[-1] - self.vocab_size)
                    beam.sequence[-1:] = self.replace_movie_with_words(beam.sequence[-1:])
        best_beam = get_best_beam(beams)
        return self.gen_tokenizer.batch_decode(best_beam.sequence)
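
# A minimal usage sketch (commented out). The constructor keywords below
# (rec_module, gen_module, rec_tokenizer, gen_tokenizer) are assumptions inferred
# from the attributes this class reads off BasePipeline, and the config values are
# illustrative; consult BasePipeline and SwitchDecodeConfig for the actual signatures.
#
#   config = SwitchDecodeConfig(hidden_size=512, context_size=512, max_seq_length=40)
#   pipeline = SwitchDecodePipeline(
#       config,
#       rec_module=rec_module,        # recommender exposing .forward(...) and .response(...)
#       gen_module=gen_module,        # generator whose .decoder exposes context_size
#       rec_tokenizer=rec_tokenizer,  # maps movie IDs <-> movie names
#       gen_tokenizer=gen_tokenizer,  # word tokenizer; len() gives the vocabulary size
#       temperature=1,
#   )
#   reply = pipeline.response("Hi! Can you recommend a good sci-fi movie?", beam_size=10)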