Source code for recwizard.pipelines.trivial.modeling_trivial

from recwizard.model_utils import BasePipeline
from recwizard.modules.monitor import monitor

from .configuration_trivial import TrivialConfig

class TrivialPipeline(BasePipeline):
    """Pipeline that queries the recommender and the generator independently.

    Both ``self.rec_module`` and ``self.gen_module`` receive the same raw
    ``query`` (each with its own tokenizer); their results are only combined
    when the final response is assembled in :meth:`response`.
    """

    config_class = TrivialConfig

    def forward(self, input_ids, attention_mask, labels=None, **kwargs):
        """Not implemented for this pipeline; use :meth:`response` instead.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement forward(); use response() instead."
        )

    @monitor
    def response(self, query, return_dict=False, rec_args=None, gen_args=None, **kwargs):
        """Answer ``query`` with generated text followed by recommended items.

        Args:
            query: The user query, passed unchanged to both
                ``self.rec_module.response`` and ``self.gen_module.response``.
            return_dict: If True, return the raw logits and outputs of both
                modules instead of a formatted string.
            rec_args: Extra keyword arguments forwarded to
                ``self.rec_module.response`` (``None`` means no extras).
            gen_args: Extra keyword arguments forwarded to
                ``self.gen_module.response`` (``None`` means no extras).
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            If ``return_dict`` is True, a dict with keys ``'rec_logits'``,
            ``'gen_logits'``, ``'rec_output'`` and ``'gen_output'``.
            Otherwise a string: the first generated output, followed by one
            line of the form ``" - <item>"`` per recommended item. When the
            recommender returns no items, only the generated text is returned
            (no dangling empty bullet).
        """
        rec_args = rec_args or {}
        gen_args = gen_args or {}

        rec_output = self.rec_module.response(
            query, tokenizer=self.rec_tokenizer, return_dict=True, **rec_args
        )
        gen_output = self.gen_module.response(
            query, tokenizer=self.gen_tokenizer, return_dict=True, **gen_args
        )

        if return_dict:
            return {
                'rec_logits': rec_output['logits'],
                'gen_logits': gen_output['logits'],
                'rec_output': rec_output['output'],
                'gen_output': gen_output['output'],
            }

        # Only the first generated candidate is shown to the user.
        text = gen_output['output'][0]
        # One "\n - item" line per recommendation; an empty recommendation list
        # therefore adds nothing instead of a trailing "\n - ".
        # NOTE(review): assumes rec_output['output'] yields str items — confirm
        # against rec_module.response.
        bullets = "".join("\n - " + item for item in rec_output['output'])
        return text + bullets