Source code for recwizard.modules.redial.hrnn_for_classification

import torch
import torch.nn as nn
import torch.nn.functional as F

from recwizard.utility import pad_and_stack
from recwizard import BaseConfig, BaseModule
from .hrnn import HRNN

class HRNNForClassification(BaseModule):
    def __init__(self, hrnn_params, output_classes, return_liked_probability=True,
                 multiple_items_per_example=True):
        """
        Args:
            hrnn_params: keyword arguments for the HRNN encoder.
            output_classes: mapping from output name to number of classes; one
                linear head is created per entry.
            return_liked_probability: if True, ``forward`` returns the
                probability of the "liked" class at each utterance instead of
                raw logits from every head.
            multiple_items_per_example: should be set to True when each
                conversation corresponds to an example (e.g. when generating
                output). Should be set to False in training, because there each
                item corresponds to an example.
        """
        super().__init__(BaseConfig())
        self.encoder = HRNN(return_all=return_liked_probability, **hrnn_params)
        conv_bidirectional = hrnn_params['conv_bidirectional']
        encoder_output_size = hrnn_params['conversation_encoder_hidden_size']
        self.linears = nn.ModuleDict(
            {output: nn.Linear((1 + conv_bidirectional) * encoder_output_size, num_classes)
             for output, num_classes in output_classes.items()})
        self.return_liked_probability = return_liked_probability
        self.multiple_items_per_example = multiple_items_per_example
    def on_pretrain_finished(self, requires_grad=False):
        # Optionally freeze the pretrained weights and switch to the
        # inference-style output format (per-utterance liked probabilities,
        # one example per conversation).
        self.requires_grad_(requires_grad)
        self.return_liked_probability = True
        self.multiple_items_per_example = True
    def forward(self, input_ids, attention_mask, senders, movie_occurrences, conversation_lengths):
        """
        Args:
            input_ids: padded utterance token ids for each conversation in the batch.
            attention_mask: attention mask matching ``input_ids``.
            senders: per-utterance speaker indicators.
            movie_occurrences: for each conversation, one occurrence tensor per mentioned item.
            conversation_lengths: number of utterances in each conversation.

        Returns:
            A dict mapping each output name to its predictions. With
            ``multiple_items_per_example=True``, each value is a nested list
            indexed by [conversation][item].
        """
        if self.multiple_items_per_example:
            batch_size = len(input_ids)
            # If one example holds multiple items, first flatten (expand) the
            # examples so that each one has only a single item to classify.
            indices = [(i, j)  # i-th conversation in the batch, j-th item in the conversation
                       for (i, conv_item_occurrences) in enumerate(movie_occurrences)
                       for j in range(len(conv_item_occurrences))]
            batch_indices = torch.tensor([i[0] for i in indices], device=input_ids.device)
            # Flatten item occurrences to shape
            # (total_num_mentions_in_batch, max_conv_length, max_utt_length).
            movie_occurrences = pad_and_stack([oc for conv_oc in movie_occurrences for oc in conv_oc])
            input_ids = input_ids.index_select(0, batch_indices)
            attention_mask = attention_mask.index_select(0, batch_indices)
            senders = senders.index_select(0, batch_indices)
            conversation_lengths = conversation_lengths.index_select(0, batch_indices)
        conv_repr = self.encoder(
            input_ids, attention_mask, senders, movie_occurrences, conversation_lengths,
        )['conv_repr']
        if self.multiple_items_per_example:
            # Zero out the conversation representation until the item is first mentioned.
            mask = movie_occurrences.sum(dim=2) > 0
            mask = mask.cumsum(dim=1) > 0  # (total_num_mentions_in_batch, max_conv_length)
            conv_repr = conv_repr * mask.unsqueeze(-1)
        if self.return_liked_probability:
            # Return the probability of the "liked" class at each utterance in the
            # dialogue; conv_repr has shape
            # (batch_size, max_conv_length, hidden_size * num_directions).
            output = {"i_liked": F.softmax(self.linears["i_liked"](conv_repr), dim=-1)[:, :, 1]}
        else:
            output = {key: linear(conv_repr) for key, linear in self.linears.items()}
        if self.multiple_items_per_example:
            # Recover the batch shape from the flattened output.
            batch_output = {key: [[] for _ in range(batch_size)] for key in output.keys()}
            for key in output.keys():
                for idx, (i, j) in enumerate(indices):
                    batch_output[key][i].append(output[key][idx])
            output = batch_output
        return output
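# Usage sketch (illustrative; not part of the original module). Only
# 'conv_bidirectional' and 'conversation_encoder_hidden_size' are read directly
# in __init__ above; the remaining hrnn_params must match the HRNN constructor
# in .hrnn, and the head names / class counts below merely mirror the targets
# consumed by RedialSentimentAnalysisLoss, so treat both as assumptions:
#
#   model = HRNNForClassification(
#       hrnn_params={
#           'conv_bidirectional': True,
#           'conversation_encoder_hidden_size': 256,
#           # ... remaining HRNN constructor kwargs
#       },
#       output_classes={'i_suggested': 1, 'i_seen': 2, 'i_liked': 2,
#                       'r_suggested': 1, 'r_seen': 2, 'r_liked': 2},
#       return_liked_probability=False,    # training: raw logits from every head
#       multiple_items_per_example=False,  # training: one item per example
#   )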
class RedialSentimentAnalysisLoss(nn.Module):
    def __init__(self, class_weight, use_targets):
        super().__init__()
        self.class_weight = class_weight
        # String that specifies which targets to consider and, optionally, their
        # weights, e.g. "1 1 1 suggested seen liked": with six tokens, the first
        # three are integer weights for the suggested/seen/liked terms.
        self.use_targets = use_targets
        if len(use_targets.split()) == 6:
            self.weights = [int(x) for x in use_targets.split()[:3]]
        else:
            self.weights = [1, 1, 1]
        self.suggested_criterion = nn.BCEWithLogitsLoss()
        self.seen_criterion = nn.CrossEntropyLoss()
        if self.class_weight and "liked" in self.class_weight:
            self.liked_criterion = nn.CrossEntropyLoss(weight=torch.Tensor(self.class_weight["liked"]))
        else:
            self.liked_criterion = nn.CrossEntropyLoss()
    def forward(self, output, target):
        loss = 0
        if "suggested" in self.use_targets:
            loss += self.weights[0] * (
                    self.suggested_criterion(output["i_suggested"].squeeze(1), target[:, 0].float())
                    + self.suggested_criterion(output["r_suggested"].squeeze(1), target[:, 3].float()))
        if "seen" in self.use_targets:
            loss += self.weights[1] * (
                    self.seen_criterion(output["i_seen"], target[:, 1])
                    + self.seen_criterion(output["r_seen"], target[:, 4]))
        if "liked" in self.use_targets:
            loss += self.weights[2] * (
                    self.liked_criterion(output["i_liked"], target[:, 2])
                    + self.liked_criterion(output["r_liked"], target[:, 5]))
        return loss
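if __name__ == "__main__":
    # Minimal smoke test for the loss (illustrative; not part of the original
    # module). Shapes follow forward() above: target packs six columns in the
    # order (i_suggested, i_seen, i_liked, r_suggested, r_seen, r_liked); the
    # two-class heads for "seen"/"liked" are an assumption for this sketch.
    n = 4
    criterion = RedialSentimentAnalysisLoss(class_weight=None,
                                            use_targets="suggested seen liked")
    output = {
        "i_suggested": torch.randn(n, 1), "r_suggested": torch.randn(n, 1),
        "i_seen": torch.randn(n, 2), "r_seen": torch.randn(n, 2),
        "i_liked": torch.randn(n, 2), "r_liked": torch.randn(n, 2),
    }
    target = torch.stack([torch.randint(0, 2, (n,)) for _ in range(6)], dim=1)
    print(criterion(output, target))  # scalar loss over all three target groups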