Source code for recwizard.pipelines.fill_blank.modeling_fill_blank

import re

from recwizard.model_utils import BasePipeline
from recwizard.modules.monitor import monitor
from recwizard.utility import SEP_TOKEN, EntityLink

from .configuration_fill_blank import FillBlankConfig


class FillBlankPipeline(BasePipeline):
    """Two-stage pipeline: generate a response template, then fill its blanks.

    The generator module first produces a response that may contain
    placeholders matching ``config.rec_pattern`` (for example
    ``"<movie>, <movie>, and <movie> are some good options."``). If any
    placeholders are present, the recommender module is asked for exactly
    that many items. The items are then substituted into the placeholders
    from left to right.

    NOTE(review): ``gen_module``, ``gen_tokenizer``, ``rec_module`` and
    ``rec_tokenizer`` are not set here; presumably ``BasePipeline`` provides
    them. Confirm against the base class.
    """

    config_class = FillBlankConfig

    def __init__(self, config, use_resp=True, **kwargs):
        """Build the pipeline.

        Args:
            config: A ``FillBlankConfig``. Its ``rec_pattern`` is compiled
                into the placeholder regex.
            use_resp: If True, the generated template is appended to the
                query (joined by ``SEP_TOKEN``) before it is passed to the
                recommender.
            **kwargs: Forwarded to ``BasePipeline.__init__``.
        """
        super().__init__(config, **kwargs)
        self.movie_pattern = re.compile(config.rec_pattern)
        self.use_resp = use_resp
        self.entity_linker = EntityLink()

    @monitor
    def response(self, query, return_dict=False, gen_args=None, rec_args=None):
        """Generate a response and fill its placeholders with recommendations.

        Args:
            query: Conversation context. It is passed to the generator and,
                possibly extended with the generated template, to the
                recommender. Presumably a ``str``, since ``SEP_TOKEN + resp``
                is concatenated onto it.
            return_dict: If True, both sub-modules are asked for dict output
                (read via the ``"output"`` key), and a dict of intermediate
                results is returned instead of the bare response.
            gen_args: Extra keyword arguments for ``gen_module.response``.
            rec_args: Extra keyword arguments for ``rec_module.response``.

        Returns:
            The filled response, or, when ``return_dict`` is True, a dict with
            keys ``gen_input``, ``gen_output``, ``rec_input``, ``rec_output``,
            ``output`` (the filled response) and ``links``.
        """
        gen_args = gen_args or {}
        rec_args = rec_args or {}

        # Stage 1: generate the response template.
        gen_output = self.gen_module.response(
            query, tokenizer=self.gen_tokenizer, return_dict=return_dict, **gen_args
        )
        resp = gen_output["output"] if return_dict else gen_output

        # Stage 2: request exactly as many recommendations as placeholders.
        k_movies = len(self.movie_pattern.findall(resp))
        if k_movies > 0:
            rec_input = query
            if self.use_resp:
                # Let the recommender also see the generated template.
                rec_input += SEP_TOKEN + resp
            rec_output = self.rec_module.response(
                rec_input,
                topk=k_movies,
                tokenizer=self.rec_tokenizer,
                return_dict=return_dict,
                **rec_args,
            )
            rec = rec_output["output"] if return_dict else rec_output
            resp = self._fill_placeholders(resp, rec)
        else:
            rec_input = {}
            rec_output = {}

        if not return_dict:
            return resp
        return {
            "gen_input": query,
            "gen_output": gen_output,
            "rec_input": rec_input,
            "rec_output": rec_output,
            "output": resp,
            "links": self._resolve_links(rec_output),
        }

    def _fill_placeholders(self, resp, recs):
        """Replace placeholder matches in ``resp`` with items from ``recs``, in order.

        All placeholders are replaced in one pass with a callable
        replacement, so each recommendation is inserted literally. With a
        string replacement, ``re.sub`` would interpret backslashes and group
        references such as ``\\1``. Inserted text is never rescanned for
        further placeholders. If ``recs`` runs out early, the remaining
        placeholders are left as they were instead of raising ``IndexError``.
        """
        remaining = iter(recs)
        return self.movie_pattern.sub(
            lambda match: next(remaining, match.group(0)), resp
        )

    def _resolve_links(self, rec_output):
        """Build the ``links`` mapping returned when ``return_dict`` is True.

        The lookup order is:
        1. A truthy ``"links"`` entry supplied by the recommender (e.g. an
           LLM-based one) is returned as-is.
        2. Otherwise, ``"movieIds"`` are decoded with ``rec_tokenizer`` and
           each decoded name is mapped to ``entity_linker(name)``.
        3. Otherwise, an empty dict is returned.
        """
        if rec_output.get("links"):
            return rec_output["links"]
        if "movieIds" in rec_output:
            movie_names = self.rec_tokenizer.batch_decode(rec_output["movieIds"])
            return {name: self.entity_linker(name) for name in movie_names}
        return {}