Skip to content

Predibase

PredibaseLLM #

Bases: CustomLLM

Predibase LLM.

To use, you should have the predibase python package installed, and have your Predibase API key.

The model_name parameter is the Predibase "serverless" base_model ID (see https://docs.predibase.com/user-guide/inference/models for the catalog).

An optional adapter_id parameter is the Predibase ID or the HuggingFace ID of a fine-tuned LLM adapter, whose base model is the model parameter; the fine-tuned adapter must be compatible with its base model; otherwise, an error is raised. If the fine-tuned adapter is hosted at Predibase, adapter_version must be specified.

Examples:

pip install llama-index-llms-predibase

import os

os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}"

from llama_index.llms.predibase import PredibaseLLM

llm = PredibaseLLM(
    model_name="mistral-7b",
    predibase_sdk_version=None,  # optional parameter (defaults to the latest Predibase SDK version if omitted)
    adapter_id="my-adapter-id",  # optional parameter
    adapter_version=3,  # optional parameter (applies to Predibase only)
    temperature=0.3,
    max_new_tokens=512,
)
response = llm.complete("Hello World!")
print(str(response))
Source code in llama-index-integrations/llms/llama-index-llms-predibase/llama_index/llms/predibase/base.py
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
class PredibaseLLM(CustomLLM):
    """Predibase LLM.

    To use, you should have the ``predibase`` python package installed,
    and have your Predibase API key.

    The `model_name` parameter is the Predibase "serverless" base_model ID
    (see https://docs.predibase.com/user-guide/inference/models for the catalog).

    An optional `adapter_id` parameter is the Predibase ID or the HuggingFace ID
    of a fine-tuned LLM adapter, whose base model is the `model` parameter; the
    fine-tuned adapter must be compatible with its base model; otherwise, an
    error is raised.  If the fine-tuned adapter is hosted at Predibase,
    `adapter_version` must be specified.

    Examples:
        `pip install llama-index-llms-predibase`

        ```python
        import os

        os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}"

        from llama_index.llms.predibase import PredibaseLLM

        llm = PredibaseLLM(
            model_name="mistral-7b",
            predibase_sdk_version=None,  # optional parameter (defaults to the latest Predibase SDK version if omitted)
            adapter_id="my-adapter-id",  # optional parameter
            adapter_version=3,  # optional parameter (applies to Predibase only)
            temperature=0.3,
            max_new_tokens=512,
        )
        response = llm.complete("Hello World!")
        print(str(response))
        ```
    """

    model_name: str = Field(description="The Predibase base model to use.")
    predibase_api_key: str = Field(description="The Predibase API key to use.")
    predibase_sdk_version: str = Field(
        default=None,
        description="The optional version (string) of the Predibase SDK (defaults to the latest if not specified).",
    )
    adapter_id: str = Field(
        default=None,
        description="The optional Predibase ID or HuggingFace ID of a fine-tuned adapter to use.",
    )
    adapter_version: str = Field(
        default=None,
        description="The optional version number of fine-tuned adapter use (applies to Predibase only).",
    )
    max_new_tokens: int = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description="The number of tokens to generate.",
        gt=0,
    )
    temperature: float = Field(
        default=DEFAULT_TEMPERATURE,
        description="The temperature to use for sampling.",
        gte=0.0,
        lte=1.0,
    )
    context_window: int = Field(
        default=DEFAULT_CONTEXT_WINDOW,
        description="The number of context tokens available to the LLM.",
        gt=0,
    )

    _client: Any = PrivateAttr()

    def __init__(
        self,
        model_name: str,
        predibase_api_key: Optional[str] = None,
        predibase_sdk_version: Optional[str] = None,
        adapter_id: Optional[str] = None,
        adapter_version: Optional[int] = None,
        max_new_tokens: int = DEFAULT_NUM_OUTPUTS,
        temperature: float = DEFAULT_TEMPERATURE,
        context_window: int = DEFAULT_CONTEXT_WINDOW,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        predibase_api_key = (
            predibase_api_key
            if predibase_api_key
            else os.environ.get("PREDIBASE_API_TOKEN")
        )
        assert predibase_api_key is not None

        super().__init__(
            model_name=model_name,
            predibase_api_key=predibase_api_key,
            predibase_sdk_version=predibase_sdk_version,
            adapter_id=adapter_id,
            adapter_version=adapter_version,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            context_window=context_window,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )

        self._client: Union["PredibaseClient", "Predibase"] = self.initialize_client()

    def initialize_client(
        self,
    ) -> Union["PredibaseClient", "Predibase"]:
        try:
            if self._is_deprecated_sdk_version():
                from predibase import PredibaseClient
                from predibase.pql import get_session
                from predibase.pql.api import Session

                session: Session = get_session(
                    token=self.predibase_api_key,
                    gateway="https://api.app.predibase.com/v1",
                    serving_endpoint="serving.app.predibase.com",
                )
                return PredibaseClient(session=session)

            from predibase import Predibase

            os.environ["PREDIBASE_GATEWAY"] = "https://api.app.predibase.com"
            return Predibase(api_token=self.predibase_api_key)
        except ValueError as e:
            raise ValueError("Your API key is not correct. Please try again") from e

    @classmethod
    def class_name(cls) -> str:
        return "PredibaseLLM"

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.max_new_tokens,
            model_name=self.model_name,
        )

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> "CompletionResponse":
        options: Dict[str, Union[str, float]] = copy.deepcopy(kwargs)
        options.update(
            {
                "max_new_tokens": self.max_new_tokens,
                "temperature": self.temperature,
            }
        )

        response_text: str

        if self._is_deprecated_sdk_version():
            from predibase.pql.api import ServerResponseError
            from predibase.resource.llm.interface import (
                HuggingFaceLLM,
                LLMDeployment,
            )
            from predibase.resource.llm.response import GeneratedResponse
            from predibase.resource.model import Model

            base_llm_deployment: LLMDeployment = self._client.LLM(
                uri=f"pb://deployments/{self.model_name}"
            )

            result: GeneratedResponse
            if self.adapter_id:
                """
                Attempt to retrieve the fine-tuned adapter from a Predibase repository.
                If absent, then load the fine-tuned adapter from a HuggingFace repository.
                """
                adapter_model: Union[Model, HuggingFaceLLM]
                try:
                    adapter_model = self._client.get_model(
                        name=self.adapter_id,
                        version=self.adapter_version,
                        model_id=None,
                    )
                except ServerResponseError:
                    # Predibase does not recognize the adapter ID (query HuggingFace).
                    adapter_model = self._client.LLM(uri=f"hf://{self.adapter_id}")
                result = base_llm_deployment.with_adapter(model=adapter_model).generate(
                    prompt=prompt,
                    options=options,
                )
            else:
                result = base_llm_deployment.generate(
                    prompt=prompt,
                    options=options,
                )
            response_text = result.response
        else:
            import requests
            from lorax.client import Client as LoraxClient
            from lorax.errors import GenerationError
            from lorax.types import Response

            lorax_client: LoraxClient = self._client.deployments.client(
                deployment_ref=self.model_name
            )

            response: Response
            if self.adapter_id:
                """
                Attempt to retrieve the fine-tuned adapter from a Predibase repository.
                If absent, then load the fine-tuned adapter from a HuggingFace repository.
                """
                if self.adapter_version:
                    # Since the adapter version is provided, query the Predibase repository.
                    pb_adapter_id: str = f"{self.adapter_id}/{self.adapter_version}"
                    try:
                        response = lorax_client.generate(
                            prompt=prompt,
                            adapter_id=pb_adapter_id,
                            **options,
                        )
                    except GenerationError as ge:
                        raise ValueError(
                            f'An adapter with the ID "{pb_adapter_id}" cannot be found in the Predibase repository of fine-tuned adapters.'
                        ) from ge
                else:
                    # The adapter version is omitted, hence look for the adapter ID in the HuggingFace repository.
                    try:
                        response = lorax_client.generate(
                            prompt=prompt,
                            adapter_id=self.adapter_id,
                            adapter_source="hub",
                            **options,
                        )
                    except GenerationError as ge:
                        raise ValueError(
                            f"""Either an adapter with the ID "{self.adapter_id}" cannot be found in a HuggingFace repository, \
or it is incompatible with the base model (please make sure that the adapter configuration is consistent).
"""
                        ) from ge
            else:
                try:
                    response = lorax_client.generate(
                        prompt=prompt,
                        **options,
                    )
                except requests.JSONDecodeError as jde:
                    raise ValueError(
                        f"""An LLM with the deployment ID "{self.model_name}" cannot be found at Predibase \
(please refer to "https://docs.predibase.com/user-guide/inference/models" for the list of supported models).
"""
                    ) from jde
            response_text = response.generated_text

        return CompletionResponse(text=response_text)

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> "CompletionResponseGen":
        raise NotImplementedError

    def _is_deprecated_sdk_version(self) -> bool:
        try:
            import semantic_version
            from semantic_version.base import Version

            from predibase.version import __version__ as current_version

            sdk_semver_deprecated: Version = semantic_version.Version(
                version_string="2024.4.8"
            )
            actual_current_version: str = self.predibase_sdk_version or current_version
            sdk_semver_current: Version = semantic_version.Version(
                version_string=actual_current_version
            )
            return not (
                (sdk_semver_current > sdk_semver_deprecated)
                or ("+dev" in actual_current_version)
            )
        except ImportError as e:
            raise ImportError(
                "Could not import Predibase Python package. "
                "Please install it with `pip install semantic_version predibase`."
            ) from e

metadata property #

metadata: LLMMetadata

Get LLM metadata.