# Source code for recwizard.modules.redial.autorec

import torch
import torch.nn as nn

from recwizard import BaseConfig, BaseModule
import logging

logger = logging.getLogger(__name__)


# Adapted from https://github.com/RaymondLi0/conversational-recommendations/blob/master/models/autorec.py

[docs]class AutoRec(BaseModule): # We use features from BaseModule to load previous checkpoints """ User-based Autoencoder for Collaborative Filtering """
[docs] def __init__(self, n_movies, layer_sizes, g, f): super().__init__(BaseConfig()) self.n_movies = n_movies self.layer_sizes = layer_sizes if g == 'identity': self.g = lambda x: x elif g == 'sigmoid': self.g = nn.Sigmoid() elif g == 'relu': self.g = nn.ReLU() else: raise ValueError("Got invalid function name for g : {}".format(self.config.g)) self.encoder = UserEncoder(layer_sizes=self.layer_sizes, n_movies=self.n_movies, f=f) self.user_representation_size = self.layer_sizes[-1] self.decoder = nn.Linear(in_features=self.user_representation_size, out_features=self.n_movies)
# if checkpoint is not None: # self.load_from_checkpoint(checkpoint) # if self.cuda_available: # self.cuda()
[docs] def load_checkpoint(self, checkpoint, verbose=True, strict=True, LOAD_PREFIX=''): if verbose: logger.info(f"AutoRec: loading model from {checkpoint}") # Load pretrained model, keeping only the first n_movies. (see merge_indexes function in match_movies) checkpoint = torch.load(checkpoint, map_location=self.device) if "state_dict" in checkpoint: checkpoint = checkpoint["state_dict"] model_dict = self.state_dict() # load all weights except for the weights of the first layer and the decoder model_dict.update({k: v for k, v in checkpoint.items() if k.startswith( LOAD_PREFIX) and k != LOAD_PREFIX + "encoder.layers.0.weight" and "decoder" not in k}) # load first layer and decoder: assume the movies to keep are the n_movies first movies encoder0weight = checkpoint[LOAD_PREFIX + "encoder.layers.0.weight"][:, :self.n_movies] decoderweight = checkpoint[LOAD_PREFIX + "decoder.weight"][:self.n_movies, :] decoderbias = checkpoint[LOAD_PREFIX + "decoder.bias"][:self.n_movies] # If checkpoint has fewer movies than the model, append zeros (no actual recommendations for new movies) # (When using an updated movie list) if encoder0weight.shape[1] < self.n_movies: encoder0weight = torch.cat(( encoder0weight, torch.zeros(encoder0weight.shape[0], self.n_movies - encoder0weight.shape[1], device=self.device)), dim=1) decoderweight = torch.cat(( decoderweight, torch.zeros(self.n_movies - decoderweight.shape[0], decoderweight.shape[1], device=self.device)), dim=0) decoderbias = torch.cat(( decoderbias, torch.zeros(self.n_movies - decoderbias.shape[0], device=self.device)), dim=0) model_dict.update({ LOAD_PREFIX + "encoder.layers.0.weight": encoder0weight, LOAD_PREFIX + "decoder.weight": decoderweight, LOAD_PREFIX + "decoder.bias": decoderbias, }) self.load_state_dict(model_dict, strict=strict) if verbose: logger.info("AutoRec: finish loading")
[docs] def forward(self, input, additional_context=None, range01=True): """ :param input: (batch, n_movies) :param additional_context: potential information to add to user representation (batch, user_rep_size) :param range01: If true, apply sigmoid to the output :return: output recommendations (batch, n_movies) """ # get user representation encoded = self.encoder(input, raw_last_layer=True) # eventually use additional context if additional_context is not None: encoded = self.encoder.f(encoded + additional_context) else: encoded = self.encoder.f(encoded) # decode to get movie recommendations if range01: return self.g(self.decoder(encoded)) else: return self.decoder(encoded)
class UserEncoder(nn.Module):
    """Feed-forward encoder mapping a rating vector to a user representation."""

    def __init__(self, layer_sizes, n_movies, f):
        """
        :param layer_sizes: list giving the size of each layer
        :param n_movies: dimensionality of the input rating vector
        :param f: activation name ('identity', 'sigmoid' or 'relu')
        """
        super(UserEncoder, self).__init__()
        # Pair each layer's input size with its output size: the first layer
        # consumes the full rating vector, each later layer the previous output.
        fan_ins = [n_movies] + list(layer_sizes[:-1])
        self.layers = nn.ModuleList(
            nn.Linear(in_features=d_in, out_features=d_out)
            for d_in, d_out in zip(fan_ins, layer_sizes)
        )
        activations = {
            'identity': lambda x: x,
            'sigmoid': nn.Sigmoid(),
            'relu': nn.ReLU(),
        }
        if f not in activations:
            raise ValueError("Got invalid function name for f : {}".format(f))
        self.f = activations[f]

    def forward(self, input, raw_last_layer=False):
        """Encode ``input``; skip the activation on the last layer if requested."""
        last_idx = len(self.layers) - 1
        hidden = input
        for idx, layer in enumerate(self.layers):
            hidden = layer(hidden)
            # Do not apply the activation to the last layer when asked for raw output.
            if not (raw_last_layer and idx == last_idx):
                hidden = self.f(hidden)
        return hidden
class ReconstructionLoss(nn.Module):
    """Summed MSE over observed targets (targets equal to -1 are "unobserved"
    and ignored), with a running count used to normalize the accumulated loss."""

    def __init__(self):
        super(ReconstructionLoss, self).__init__()
        # Sum of losses. `size_average=False` is deprecated in modern PyTorch;
        # reduction='sum' is the supported equivalent.
        self.mse_loss = nn.MSELoss(reduction='sum')
        # Keep track of the number of observed targets to normalize the loss.
        self.nb_observed_targets = 0

    def forward(self, input, target):
        """Return the summed squared error over the observed entries of ``target``.

        :param input: predictions, same shape as ``target``
        :param target: targets where -1 marks an unobserved entry
        :return: scalar summed MSE over observed entries
        """
        # Only consider the observed targets.
        mask = (target != -1)
        observed_input = torch.masked_select(input, mask)
        observed_target = torch.masked_select(target, mask)
        # Increment the number of observed targets for later normalization.
        self.nb_observed_targets += len(observed_target)
        loss = self.mse_loss(observed_input, observed_target)
        return loss

    def normalize_loss_reset(self, loss):
        """
        Returns loss divided by nb of observed targets, and resets
        ``nb_observed_targets`` to 0.

        :param loss: Total summed loss
        :return: mean loss
        :raises ValueError: if no targets were observed since the last reset
        """
        if self.nb_observed_targets == 0:
            raise ValueError(
                "Nb observed targets was 0. Please evaluate some examples before calling normalize_loss_reset")
        n_loss = loss / self.nb_observed_targets
        self.nb_observed_targets = 0
        return n_loss