# Source code for ml_wrappers.model.openai_wrapper

# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

"""Defines a model wrapper for an openai model endpoint."""

import numpy as np

try:
    import openai
    openai_installed = True
except ImportError:
    openai_installed = False
try:
    from raiutils.common.retries import retry_function
    rai_utils_installed = True
except ImportError:
    rai_utils_installed = False


# Dictionary key holding the text of a chat message; used both when
# building the request message and when reading the reply message.
CONTENT = 'content'


def replace_backtick_chars(message):
    """Strip every backtick character out of a message.

    :param message: The text to clean up.
    :type message: str
    :return: A copy of the text without any backtick characters.
    :rtype: str
    """
    backtick, replacement = '`', ''
    cleaned = message.replace(backtick, replacement)
    return cleaned


class ChatCompletion(object):
    """Hold the arguments of one openai chat completion request.

    The arguments are stored at construction time so that ``fetch`` can be
    called later without any parameters.
    """

    def __init__(self, messages, engine, temperature,
                 max_tokens, top_p, frequency_penalty,
                 presence_penalty, stop):
        """Store the request arguments.

        :param messages: The messages.
        :type messages: list
        :param engine: The engine.
        :type engine: str
        :param temperature: The temperature.
        :type temperature: float
        :param max_tokens: The maximum number of tokens.
        :type max_tokens: int
        :param top_p: The top p.
        :type top_p: float
        :param frequency_penalty: The frequency penalty.
        :type frequency_penalty: float
        :param presence_penalty: The presence penalty.
        :type presence_penalty: float
        :param stop: The stop.
        :type stop: list
        """
        # What to send and which engine to send it to.
        self.engine = engine
        self.messages = messages
        # Generation settings, forwarded unchanged by fetch().
        self.temperature = temperature
        self.top_p = top_p
        self.max_tokens = max_tokens
        self.stop = stop
        self.frequency_penalty = frequency_penalty
        self.presence_penalty = presence_penalty

    def fetch(self):
        """Send the stored arguments to the openai chat completion endpoint.

        :return: The response.
        :rtype: dict
        """
        request_kwargs = {
            'engine': self.engine,
            'messages': self.messages,
            'temperature': self.temperature,
            'max_tokens': self.max_tokens,
            'top_p': self.top_p,
            'frequency_penalty': self.frequency_penalty,
            'presence_penalty': self.presence_penalty,
            'stop': self.stop,
        }
        return openai.ChatCompletion.create(**request_kwargs)


class OpenaiWrapperModel(object):
    """A model wrapper for an openai model endpoint."""

    def __init__(self, api_type, api_base, api_version, api_key,
                 engine="gpt-4-32k", temperature=0.7, max_tokens=800,
                 top_p=0.95, frequency_penalty=0, presence_penalty=0,
                 stop=None):
        """Initialize the model.

        :param api_type: The type of the API.
        :type api_type: str
        :param api_base: The base URL for the API.
        :type api_base: str
        :param api_version: The version of the API.
        :type api_version: str
        :param api_key: The API key.
        :type api_key: str
        :param engine: The engine.
        :type engine: str
        :param temperature: The temperature.
        :type temperature: float
        :param max_tokens: The maximum number of tokens.
        :type max_tokens: int
        :param top_p: The top p.
        :type top_p: float
        :param frequency_penalty: The frequency penalty.
        :type frequency_penalty: float
        :param presence_penalty: The presence penalty.
        :type presence_penalty: float
        :param stop: The stop.
        :type stop: list
        """
        self.api_type = api_type
        self.api_base = api_base
        self.api_version = api_version
        self.api_key = api_key
        self.engine = engine
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.top_p = top_p
        self.frequency_penalty = frequency_penalty
        self.presence_penalty = presence_penalty
        self.stop = stop

    def _call_webservice(self, data):
        """Common code to call the webservice.

        Each entry of ``data`` is sent as a single user message in its own
        chat completion request; backticks are removed from every reply.

        :param data: The data to send to the webservice.
        :type data: pandas.DataFrame or list
        :return: The result.
        :rtype: numpy.ndarray
        :raises RuntimeError: If the raiutils or openai package is not
            installed.
        """
        if not rai_utils_installed:
            error = "raiutils package is required to call openai endpoint"
            raise RuntimeError(error)
        if not openai_installed:
            error = "openai package is required to call openai endpoint"
            raise RuntimeError(error)
        if not isinstance(data, list):
            if isinstance(data, np.ndarray):
                data = data.tolist()
            else:
                # Anything else is expected to expose ``.values``
                # (the docstring names pandas.DataFrame).
                data = data.values.tolist()
        # The connection settings are attributes of the openai module
        # itself, so they are (re)applied on every call.
        openai.api_type = self.api_type
        openai.api_base = self.api_base
        openai.api_version = self.api_version
        openai.api_key = self.api_key
        # Retry settings are the same for every document, so they are
        # defined once outside the loop.
        # NOTE(review): retry_delay is presumably in seconds - confirm
        # against raiutils.common.retries.retry_function.
        action_name = "Call openai chat completion"
        err_msg = "Failed to call openai endpoint"
        max_retries = 4
        retry_delay = 60
        answers = []
        for doc in data:
            messages = [{'role': 'user', CONTENT: doc}]
            fetcher = ChatCompletion(messages, self.engine,
                                     self.temperature, self.max_tokens,
                                     self.top_p, self.frequency_penalty,
                                     self.presence_penalty, self.stop)
            response = retry_function(fetcher.fetch, action_name, err_msg,
                                      max_retries=max_retries,
                                      retry_delay=retry_delay)
            # Only the first choice of the response is used.
            reply = response['choices'][0]['message'][CONTENT]
            answers.append(replace_backtick_chars(reply))
        return np.array(answers)

    def predict(self, context, model_input=None):
        """Predict using the model.

        The input is indexed with the key ``'questions'``; a single string
        found there is treated as a one-element list of questions.

        :param context: The context for MLFlow model or the input data.
        :type context: mlflow.pyfunc.model.PythonModelContext or
            pandas.DataFrame
        :param model_input: The input to the model.
        :type model_input: pandas.DataFrame
        :return: The predictions.
        :rtype: numpy.ndarray
        """
        # This is to conform to the scikit-learn API format
        # which MLFlow does not follow
        if model_input is None:
            model_input = context
        questions = model_input['questions']
        if isinstance(questions, str):
            questions = [questions]
        result = self._call_webservice(questions)
        return result