Source code for langchain.embeddings.openai

"""Wrapper around OpenAI embedding models."""
from __future__ import annotations

import logging
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Literal,
    Optional,
    Set,
    Tuple,
    Union,
)

import numpy as np
from pydantic import BaseModel, Extra, root_validator
from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)


def _create_retry_decorator(embeddings: OpenAIEmbeddings) -> Callable[[Any], Any]:
    """Build the tenacity ``retry`` decorator used for OpenAI embedding calls.

    Args:
        embeddings: Wrapper instance; its ``max_retries`` attribute bounds the
            number of attempts.

    Returns:
        A decorator that retries the wrapped callable on the OpenAI error
        types listed below, logging a warning before every sleep and
        re-raising the last error once the attempts are exhausted.
    """
    import openai

    # Exponential back-off (2^x * 1 second), clamped so that the wait starts
    # at 4 seconds and never exceeds 10 seconds.
    shortest_wait = 4
    longest_wait = 10

    stop_policy = stop_after_attempt(embeddings.max_retries)
    wait_policy = wait_exponential(
        multiplier=1, min=shortest_wait, max=longest_wait
    )

    # Error types for which the request is attempted again.
    retryable_errors = (
        openai.error.Timeout,
        openai.error.APIError,
        openai.error.APIConnectionError,
        openai.error.RateLimitError,
        openai.error.ServiceUnavailableError,
    )
    retry_policy = retry_if_exception_type(retryable_errors[0])
    for error_type in retryable_errors[1:]:
        retry_policy = retry_policy | retry_if_exception_type(error_type)

    return retry(
        reraise=True,
        stop=stop_policy,
        wait=wait_policy,
        retry=retry_policy,
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )


def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
    """Call ``embeddings.client.create`` under the tenacity retry policy.

    Args:
        embeddings: Wrapper instance providing ``client`` and the retry
            configuration.
        **kwargs: Forwarded unchanged to ``embeddings.client.create``.

    Returns:
        Whatever ``embeddings.client.create`` returns.
    """
    with_retries = _create_retry_decorator(embeddings)

    def _embed_with_retry(**kwargs: Any) -> Any:
        return embeddings.client.create(**kwargs)

    # Apply the decorator explicitly instead of using ``@`` syntax.
    guarded_call = with_retries(_embed_with_retry)
    return guarded_call(**kwargs)


class OpenAIEmbeddings(BaseModel, Embeddings):
    """Wrapper around OpenAI embedding models.

    To use, you should have the ``openai`` python package installed, and the
    environment variable ``OPENAI_API_KEY`` set with your API key or pass it
    as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain.embeddings import OpenAIEmbeddings
            openai = OpenAIEmbeddings(openai_api_key="my-api-key")

    In order to use the library with Microsoft Azure endpoints, you need to set
    the OPENAI_API_TYPE, OPENAI_API_BASE, OPENAI_API_KEY and, optionally,
    OPENAI_API_VERSION. The OPENAI_API_TYPE must be set to 'azure' and the
    others correspond to the properties of your endpoint.
    In addition, the deployment name must be passed as the ``deployment``
    parameter, since that is the value sent as the ``engine`` of each request.

    Example:
        .. code-block:: python

            import os
            os.environ["OPENAI_API_TYPE"] = "azure"
            os.environ["OPENAI_API_BASE"] = "https://<your-endpoint>.openai.azure.com/"
            os.environ["OPENAI_API_KEY"] = "your AzureOpenAI key"

            from langchain.embeddings.openai import OpenAIEmbeddings
            embeddings = OpenAIEmbeddings(
                deployment="your-embeddings-deployment-name",
                model="your-embeddings-model-name"
            )
            text = "This is a test query."
            query_result = embeddings.embed_query(text)

    """

    client: Any  #: :meta private:
    model: str = "text-embedding-ada-002"
    # NOTE: ``model`` below is resolved while the class body executes, so the
    # default deployment is the default model name above, not a ``model``
    # value later passed to the constructor.
    deployment: str = model  # to support Azure OpenAI Service custom deployment names
    embedding_ctx_length: int = 8191
    openai_api_key: Optional[str] = None
    openai_organization: Optional[str] = None
    allowed_special: Union[Literal["all"], Set[str]] = set()
    disallowed_special: Union[Literal["all"], Set[str], Tuple[()]] = "all"
    chunk_size: int = 1000
    """Maximum number of texts to embed in each batch"""
    max_retries: int = 6
    """Maximum number of retries to make when generating."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment.

        Reads the API key (and optional organization) from ``values`` or the
        environment, configures the ``openai`` module with them and stores
        ``openai.Embedding`` under ``values["client"]``.

        Raises:
            ValueError: If the ``openai`` package cannot be imported.
        """
        openai_api_key = get_from_dict_or_env(
            values, "openai_api_key", "OPENAI_API_KEY"
        )
        openai_organization = get_from_dict_or_env(
            values,
            "openai_organization",
            "OPENAI_ORGANIZATION",
            default="",
        )
        # Only the import itself is guarded, so that an unrelated ImportError
        # is never reported as "openai is not installed".
        try:
            import openai
        except ImportError as exc:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            ) from exc

        openai.api_key = openai_api_key
        if openai_organization:
            openai.organization = openai_organization
        values["client"] = openai.Embedding
        return values

    # please refer to
    # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
    def _get_len_safe_embeddings(
        self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None
    ) -> List[List[float]]:
        """Embed texts of arbitrary length by chunking them at the token level.

        Each text is tokenized and split into pieces of at most
        ``embedding_ctx_length`` tokens. The pieces are embedded in batches,
        then the piece embeddings of each text are averaged (weighted by the
        number of tokens in each piece) and normalized to unit length.

        Args:
            texts: The texts to embed.
            engine: Engine / deployment name sent with every request.
            chunk_size: Number of token chunks per request. Falls back to
                ``self.chunk_size`` when ``None`` or ``0``.

        Returns:
            One embedding per input text, in the same order.

        Raises:
            ValueError: If the ``tiktoken`` package cannot be imported.
        """
        # Only the import itself is guarded, so that an ImportError raised
        # further down is never reported as "tiktoken is not installed".
        try:
            import tiktoken
        except ImportError as exc:
            raise ValueError(
                "Could not import tiktoken python package. "
                "This is needed in order to for OpenAIEmbeddings. "
                "Please install it with `pip install tiktoken`."
            ) from exc

        encoding = tiktoken.model.encoding_for_model(self.model)

        # tokens[k] is one chunk of token ids; indices[k] is the position in
        # ``texts`` of the text that chunk was cut from.
        tokens: List[List[int]] = []
        indices: List[int] = []
        for i, text in enumerate(texts):
            # replace newlines, which can negatively affect performance.
            text = text.replace("\n", " ")
            token = encoding.encode(
                text,
                allowed_special=self.allowed_special,
                disallowed_special=self.disallowed_special,
            )
            for j in range(0, len(token), self.embedding_ctx_length):
                tokens.append(token[j : j + self.embedding_ctx_length])
                indices.append(i)

        batched_embeddings: List[List[float]] = []
        _chunk_size = chunk_size or self.chunk_size
        for i in range(0, len(tokens), _chunk_size):
            response = embed_with_retry(
                self,
                input=tokens[i : i + _chunk_size],
                engine=engine,
            )
            batched_embeddings.extend(r["embedding"] for r in response["data"])

        # Group chunk embeddings by source text. The averaging weight of a
        # chunk is its token count, as in the cookbook referenced above; the
        # embedding dimension is identical for every chunk and would make the
        # average unweighted.
        results: List[List[List[float]]] = [[] for _ in texts]
        num_tokens: List[List[int]] = [[] for _ in texts]
        for i, text_index in enumerate(indices):
            results[text_index].append(batched_embeddings[i])
            num_tokens[text_index].append(len(tokens[i]))

        embeddings: List[List[float]] = []
        for chunk_embeddings, weights in zip(results, num_tokens):
            if not chunk_embeddings:
                # The text produced no tokens at all: embed the empty string.
                average = np.asarray(
                    embed_with_retry(self, input="", engine=engine)["data"][0][
                        "embedding"
                    ]
                )
            else:
                average = np.average(chunk_embeddings, axis=0, weights=weights)
            embeddings.append((average / np.linalg.norm(average)).tolist())
        return embeddings

    def _embedding_func(self, text: str, *, engine: str) -> List[float]:
        """Call out to OpenAI's embedding endpoint.

        Args:
            text: The text to embed.
            engine: Engine / deployment name sent with the request.

        Returns:
            Embedding for the text.
        """
        # handle large input text
        if self.embedding_ctx_length > 0:
            return self._get_len_safe_embeddings([text], engine=engine)[0]

        # replace newlines, which can negatively affect performance.
        text = text.replace("\n", " ")
        response = embed_with_retry(self, input=[text], engine=engine)
        return response["data"][0]["embedding"]

    def embed_documents(
        self, texts: List[str], chunk_size: Optional[int] = 0
    ) -> List[List[float]]:
        """Call out to OpenAI's embedding endpoint for embedding search docs.

        Args:
            texts: The list of texts to embed.
            chunk_size: The chunk size of embeddings. If None (or 0), will use
                the chunk size specified by the class.

        Returns:
            List of embeddings, one for each text.
        """
        # handle batches of large input text
        if self.embedding_ctx_length > 0:
            # Forward chunk_size so a caller-supplied batch size is honoured
            # on this path as well.
            return self._get_len_safe_embeddings(
                texts, engine=self.deployment, chunk_size=chunk_size
            )

        results: List[List[float]] = []
        _chunk_size = chunk_size or self.chunk_size
        for i in range(0, len(texts), _chunk_size):
            response = embed_with_retry(
                self,
                input=texts[i : i + _chunk_size],
                engine=self.deployment,
            )
            results.extend(r["embedding"] for r in response["data"])
        return results

    def embed_query(self, text: str) -> List[float]:
        """Call out to OpenAI's embedding endpoint for embedding query text.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        return self._embedding_func(text, engine=self.deployment)