"""Source code for recwizard.pipelines.chatgpt.modeling_chatgpt_agent."""

import re
from recwizard.model_utils import BasePipeline
from .configuration_chatgpt_agent import ChatgptAgentConfig
import sys
sys.path.append('./src')
import openai


class ChatgptAgent(BasePipeline):
    """
    The CRS model based on OpenAI's GPT models.

    A generator module drafts a response containing recommendation
    placeholders. A recommender module produces one item per placeholder.
    An OpenAI chat model is then asked to substitute the items into the draft.

    Attributes:
        model_name (str): The name of the OpenAI chat model to query.
        prompt (str): Extra instructions appended to the rewrite request sent to the GPT model.
        movie_pattern (re.Pattern): Compiled ``config.rec_pattern``; its matches in the
            generated response are counted to decide how many items to recommend.
        answer_type (str): The type of the answer; used as the placeholder name
            ``<answer_type>`` in the rewrite request.
        temperature (float): Sampling temperature passed to the chat completion call.
    """

    config_class = ChatgptAgentConfig

    def __init__(self, config, prompt=None, model_name=None, temperature=1, **kwargs):
        """
        Initializes the instance of this CRS model.

        Args:
            config (ChatgptAgentConfig): The config file.
            prompt (str, optional): A prompt to override the prompt from config file.
            model_name (str, optional): The specified GPT model's name; overrides
                ``config.model_name`` when given.
            temperature (float, optional): Sampling temperature for the chat
                completion call. Defaults to 1.
            **kwargs: Forwarded to ``BasePipeline.__init__``.
        """
        super().__init__(config, **kwargs)
        self.movie_pattern = re.compile(config.rec_pattern)
        self.answer_type = config.answer_type
        self.model_name = config.model_name if model_name is None else model_name
        self.prompt = config.prompt if prompt is None else prompt
        self.temperature = temperature

    def response(self, query, **kwargs):
        """
        Response to the user's input.

        Args:
            query (str): The user's input.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            str: The GPT model's rewrite of the generated response, with the
            placeholders replaced by the recommended items.
        """
        # NOTE(review): gen_module / rec_module / their tokenizers and self.config
        # are presumably set up by BasePipeline; confirm against that class.
        gen_output = self.gen_module.response(query, tokenizer=self.gen_tokenizer)

        # Count placeholders only in the last response segment. If the response
        # marker is absent, rfind returns -1 and slicing from -1 would keep just
        # the final character, so fall back to the whole generated output.
        resp_start = gen_output.rfind(self.config.resp_prompt)
        resp = gen_output[resp_start:] if resp_start != -1 else gen_output
        n_movies = len(self.movie_pattern.findall(resp))

        rec = self.rec_module.response(gen_output, topk=n_movies, tokenizer=self.rec_tokenizer)

        # Text after the first 'System: ' marker, up to the next one, if any.
        # Use the full output rather than raising IndexError when the marker is missing.
        segments = gen_output.split('System: ')
        pure_output = segments[1] if len(segments) > 1 else gen_output

        instruction = ('Please replace the "<{}>" in the text "{}" by the following phrases "{}" respectively. '
                       .format(self.answer_type, pure_output, rec))
        instruction += self.prompt
        messages = [
            {
                "role": "user",
                "content": instruction,
            }
        ]
        # NOTE(review): openai.ChatCompletion is the legacy (openai<1.0) interface;
        # it was removed in openai>=1.0 — confirm the pinned library version.
        completion = openai.ChatCompletion.create(
            model=self.model_name,
            messages=messages,
            temperature=self.temperature,
        )
        return completion['choices'][0]['message']['content']