
Cloudflare workersai

CloudflareEmbedding

Bases: BaseEmbedding

Cloudflare Workers AI class for generating text embeddings.

This class allows for the generation of text embeddings using Cloudflare Workers AI with the BAAI general embedding models.

Args:
account_id (str): The Cloudflare Account ID.
auth_token (str, Optional): The Cloudflare Auth Token. Alternatively, set the environment variable CLOUDFLARE_AUTH_TOKEN.
model (str): The model ID for the embedding service. Cloudflare provides several embedding models; see https://developers.cloudflare.com/workers-ai/models/#text-embeddings. Defaults to "@cf/baai/bge-base-en-v1.5".
embed_batch_size (int): The batch size for embedding generation. Cloudflare currently accepts at most 100 inputs per request. Defaults to llama_index's default.

Note: Ensure you have a valid Cloudflare account and have access to the necessary AI services and models. The account ID and authorization token are sensitive details; secure them appropriately.
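A minimal usage sketch, assuming the integration is installed as llama-index-embeddings-cloudflare-workersai; the account ID and token below are placeholders:

from llama_index.embeddings.cloudflare_workersai import CloudflareEmbedding

# Placeholder credentials; supply your own account ID and token, or set the
# CLOUDFLARE_AUTH_TOKEN environment variable instead of passing auth_token.
embed_model = CloudflareEmbedding(
    account_id="my_account_id",
    auth_token="my_api_token",
    model="@cf/baai/bge-base-en-v1.5",
    embed_batch_size=10,  # must not exceed Cloudflare's limit of 100
)

# Single-text and batch embedding through the standard BaseEmbedding interface.
embedding = embed_model.get_text_embedding("Why is the sky blue?")
embeddings = embed_model.get_text_embedding_batch(
    ["Why is the sky blue?", "Why are roses red?"]
)
print(len(embedding), len(embeddings))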

Source code in llama-index-integrations/embeddings/llama-index-embeddings-cloudflare-workersai/llama_index/embeddings/cloudflare_workersai/base.py
class CloudflareEmbedding(BaseEmbedding):
    """
    Cloudflare Workers AI class for generating text embeddings.

    This class allows for the generation of text embeddings using Cloudflare Workers AI with the BAAI general embedding models.

    Args:
        account_id (str): The Cloudflare Account ID.
        auth_token (str, Optional): The Cloudflare Auth Token. Alternatively, set up environment variable `CLOUDFLARE_AUTH_TOKEN`.
        model (str): The model ID for the embedding service. Cloudflare provides different models for embeddings, check https://developers.cloudflare.com/workers-ai/models/#text-embeddings. Defaults to "@cf/baai/bge-base-en-v1.5".
        embed_batch_size (int): The batch size for embedding generation. Cloudflare's current limit is 100 at max. Defaults to llama_index's default.

    Note:
        Ensure you have a valid Cloudflare account and have access to the necessary AI services and models. The account ID and authorization token are sensitive details; secure them appropriately.
    """

    account_id: str = Field(default=None, description="The Cloudflare Account ID.")
    auth_token: str = Field(default=None, description="The Cloudflare Auth Token.")
    model: str = Field(
        default="@cf/baai/bge-base-en-v1.5",
        description="The model to use when calling Cloudflare AI API",
    )

    _session: Any = PrivateAttr()

    def __init__(
        self,
        account_id: str,
        auth_token: Optional[str] = None,
        model: str = "@cf/baai/bge-base-en-v1.5",
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            model=model,
            **kwargs,
        )
        self.account_id = account_id
        self.auth_token = get_from_param_or_env(
            "auth_token", auth_token, "CLOUDFLARE_AUTH_TOKEN", ""
        )
        self.model = model
        self._session = requests.Session()
        self._session.headers.update({"Authorization": f"Bearer {self.auth_token}"})

    @classmethod
    def class_name(cls) -> str:
        return "CloudflareEmbedding"

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        return self._get_text_embedding(query)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """The asynchronous version of _get_query_embedding."""
        return await self._aget_text_embedding(query)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        return self._get_text_embeddings([text])[0]

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get text embedding."""
        result = await self._aget_text_embeddings([text])
        return result[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings."""
        response = self._session.post(
            API_URL_TEMPLATE.format(self.account_id, self.model), json={"text": texts}
        ).json()

        if "result" not in response:
            print(response)
            raise RuntimeError("Failed to fetch embeddings")

        return response["result"]["data"]

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously get text embeddings."""
        import aiohttp

        async with aiohttp.ClientSession(trust_env=True) as session:
            headers = {
                "Authorization": f"Bearer {self.auth_token}",
                "Accept-Encoding": "identity",
            }
            async with session.post(
                API_URL_TEMPLATE.format(self.account_id, self.model),
                json={"text": texts},
                headers=headers,
            ) as response:
                resp = await response.json()
                if "result" not in resp:
                    raise RuntimeError("Failed to fetch embeddings asynchronously")

                return resp["result"]["data"]