# Source code for recwizard.pipelines.expansion.modeling_expansion
from recwizard.model_utils import BasePipeline
from recwizard.modules.monitor import monitor
from recwizard.utility import EntityLink
from .configuration_expansion import ExpansionConfig
class ExpansionPipeline(BasePipeline):
    """Pipeline that chains a recommender module into a generator module.

    The recommender runs first; its ``'logits'`` (or its ``'output'``, see
    ``use_rec_logits``) are handed to the generator as ``recs``.

    NOTE(review): ``rec_module``, ``gen_module``, ``rec_tokenizer`` and
    ``gen_tokenizer`` are used here but not assigned in this class; presumably
    ``BasePipeline`` sets them up from ``config``/``kwargs`` — confirm.
    """

    config_class = ExpansionConfig

    def __init__(self, config, use_rec_logits=True, **kwargs):
        """Build the pipeline.

        Args:
            config: Pipeline configuration, forwarded to ``BasePipeline``.
            use_rec_logits: If True, ``response`` passes the recommender's
                ``'logits'`` to the generator; otherwise it passes the
                recommender's ``'output'``.
            **kwargs: Forwarded unchanged to ``BasePipeline.__init__``.
        """
        super().__init__(config, **kwargs)
        self.use_rec_logits = use_rec_logits
        # Used by ``response`` to map decoded movie names to links.
        self.entity_linker = EntityLink()

    @staticmethod
    def _select_movie_ids(gen_output, rec_output):
        """Return the generator's ``'movieIds'``, falling back to the recommender's.

        The fallback is taken when the generator's value is missing, ``None``
        or an empty list/tuple. Explicit checks are used instead of ``a or b``
        because array-like values (e.g. tensors, ndarrays) raise when their
        truth value is tested.
        """
        movie_ids = gen_output.get('movieIds')
        if movie_ids is None or (isinstance(movie_ids, (list, tuple)) and len(movie_ids) == 0):
            movie_ids = rec_output.get('movieIds')
        return movie_ids

    @monitor
    def response(self, query, return_dict=False, rec_args=None, gen_args=None, **kwargs):
        """Recommend with ``rec_module``, then generate a reply with ``gen_module``.

        Args:
            query: Input passed unchanged to both modules' ``response``.
            return_dict: If True, return a dict bundling both module outputs
                and entity links; otherwise return the generator's output as is.
            rec_args: Extra keyword arguments for ``rec_module.response``
                (must not include ``tokenizer`` or ``return_dict``).
            gen_args: Extra keyword arguments for ``gen_module.response``
                (must not include ``tokenizer``, ``recs`` or ``return_dict``).
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            If ``return_dict`` is False, whatever ``gen_module.response``
            returns. Otherwise a dict with keys ``'rec_output'``,
            ``'gen_output'``, ``'output'`` (``gen_output['output']``) and
            ``'links'`` (decoded movie name -> ``entity_linker(name)``; empty
            when neither module reported ``'movieIds'``).
        """
        rec_args = rec_args or {}
        gen_args = gen_args or {}

        # The recommender is always asked for a dict so its fields can be read below.
        rec_output = self.rec_module.response(
            query,
            tokenizer=self.rec_tokenizer,
            return_dict=True,
            **rec_args,
        )
        recs = rec_output['logits'] if self.use_rec_logits else rec_output['output']

        gen_output = self.gen_module.response(
            query,
            tokenizer=self.gen_tokenizer,
            recs=recs,
            return_dict=return_dict,
            **gen_args,
        )
        if not return_dict:
            return gen_output

        movie_ids = self._select_movie_ids(gen_output, rec_output)
        if movie_ids is None:
            # Neither module reported movie ids: nothing to decode or link.
            movie_names = []
        else:
            movie_names = self.rec_tokenizer.batch_decode(movie_ids)
        links = {name: self.entity_linker(name) for name in movie_names}

        return {
            'rec_output': rec_output,
            'gen_output': gen_output,
            'output': gen_output['output'],
            'links': links,
        }

    def forward(self, **inputs):
        """Run the recommender, then the generator with the recommender's result.

        Args:
            **inputs: Keyword inputs passed to both ``rec_module.forward`` and
                ``gen_module.forward`` (must not include ``recs``).

        Returns:
            Whatever ``gen_module.forward`` returns.
        """
        # NOTE(review): unlike ``response``, the whole ``rec_module.forward``
        # result is passed as ``recs`` regardless of ``use_rec_logits`` —
        # confirm this asymmetry is intended.
        recs = self.rec_module.forward(**inputs)
        return self.gen_module.forward(**inputs, recs=recs)