Skip to content

Predibase

PredibaseLLM

Bases: CustomLLM

Predibase LLM.

Examples:

pip install llama-index-llms-predibase

import os

# Placeholder value: substitute your real Predibase API token.
os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}"

from llama_index.llms.predibase import PredibaseLLM

# No predibase_api_key argument is passed, so the constructor falls back to
# the PREDIBASE_API_TOKEN environment variable set above.
llm = PredibaseLLM(
    model_name="llama-2-13b", temperature=0.3, max_new_tokens=512
)
# Run a single completion and print the response's string form.
response = llm.complete("Hello World!")
print(str(response))
Source code in llama-index-integrations/llms/llama-index-llms-predibase/llama_index/llms/predibase/base.py
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
class PredibaseLLM(CustomLLM):
    """Predibase LLM.

    Completion-only LLM that sends prompts to the Predibase deployment
    addressed as ``pb://deployments/<model_name>``. Streaming is not
    implemented.

    The API token is taken from the ``predibase_api_key`` argument or, when
    that is not given, from the ``PREDIBASE_API_TOKEN`` environment variable.

    Examples:
        `pip install llama-index-llms-predibase`

        ```python
        import os

        os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}"

        from llama_index.llms.predibase import PredibaseLLM

        llm = PredibaseLLM(
            model_name="llama-2-13b", temperature=0.3, max_new_tokens=512
        )
        response = llm.complete("Hello World!")
        print(str(response))
        ```
    """

    model_name: str = Field(description="The Predibase model to use.")
    predibase_api_key: str = Field(description="The Predibase API key to use.")
    max_new_tokens: int = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description="The number of tokens to generate.",
        gt=0,
    )
    # pydantic's inclusive-bound keywords are `ge` / `le`; the previous
    # `gte` / `lte` spellings are not constraint names, so the range was never
    # actually validated.
    temperature: float = Field(
        default=DEFAULT_TEMPERATURE,
        description="The temperature to use for sampling.",
        ge=0.0,
        le=1.0,
    )
    context_window: int = Field(
        default=DEFAULT_CONTEXT_WINDOW,
        description="The number of context tokens available to the LLM.",
        gt=0,
    )

    # Predibase SDK client created by `initialize_client`; not a model field.
    _client: Any = PrivateAttr()

    def __init__(
        self,
        model_name: str,
        predibase_api_key: Optional[str] = None,
        max_new_tokens: int = DEFAULT_NUM_OUTPUTS,
        temperature: float = DEFAULT_TEMPERATURE,
        context_window: int = DEFAULT_CONTEXT_WINDOW,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        """Validate the configuration and create the Predibase client.

        Args:
            model_name: Name of the Predibase deployment to prompt.
            predibase_api_key: API token. When omitted (or empty), the
                ``PREDIBASE_API_TOKEN`` environment variable is used.
            max_new_tokens: Number of tokens to generate per completion.
            temperature: Sampling temperature, validated to lie in [0, 1].
            context_window: Number of context tokens reported in ``metadata``.
            callback_manager: Forwarded unchanged to the base class.
            system_prompt: Forwarded unchanged to the base class.
            messages_to_prompt: Forwarded unchanged to the base class.
            completion_to_prompt: Forwarded unchanged to the base class.
            pydantic_program_mode: Forwarded unchanged to the base class.
            output_parser: Forwarded unchanged to the base class.

        Raises:
            ValueError: If no API token is available from either source, or
                the Predibase client rejects the token.
            ImportError: If the ``predibase`` package is not installed.
        """
        # Explicit argument wins; otherwise fall back to the environment.
        predibase_api_key = predibase_api_key or os.environ.get(
            "PREDIBASE_API_TOKEN"
        )
        # Raise instead of `assert`: asserts are stripped under `python -O`,
        # which would let a missing token slip through to the client.
        if not predibase_api_key:
            raise ValueError(
                "A Predibase API key is required: pass `predibase_api_key` or "
                "set the PREDIBASE_API_TOKEN environment variable."
            )

        super().__init__(
            model_name=model_name,
            predibase_api_key=predibase_api_key,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            context_window=context_window,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )

        # Assign the private attribute only after the base initialiser has
        # run: pydantic v2 sets up private-attribute storage during model
        # initialisation, so assigning earlier fails there. This order works
        # on pydantic v1 as well.
        self._client = self.initialize_client(predibase_api_key)

    @staticmethod
    def initialize_client(predibase_api_key: str) -> Any:
        """Create a ``PredibaseClient`` authenticated with the given token.

        Args:
            predibase_api_key: API token handed to ``PredibaseClient``.

        Returns:
            The constructed ``PredibaseClient`` instance.

        Raises:
            ImportError: If the ``predibase`` package cannot be imported.
            ValueError: If ``PredibaseClient`` raises ``ValueError`` for the
                token.
        """
        # Keep each `try` body to the single statement it guards, so an
        # ImportError raised while constructing the client is not misreported
        # as "package not installed".
        try:
            from predibase import PredibaseClient
        except ImportError as e:
            raise ImportError(
                "Could not import Predibase Python package. "
                "Please install it with `pip install predibase`."
            ) from e

        try:
            return PredibaseClient(token=predibase_api_key)
        except ValueError as e:
            raise ValueError("Your API key is not correct. Please try again") from e

    @classmethod
    def class_name(cls) -> str:
        """Return the identifier string for this LLM class."""
        return "PredibaseLLM"

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.max_new_tokens,
            model_name=self.model_name,
        )

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> "CompletionResponse":
        """Run a single completion against the configured deployment.

        Args:
            prompt: Text sent to the deployment as-is.
            formatted: Accepted for interface compatibility; not used here.
            **kwargs: Accepted for interface compatibility; not forwarded.

        Returns:
            A ``CompletionResponse`` whose ``text`` is the ``response``
            attribute of the Predibase result.
        """
        llm = self._client.LLM(f"pb://deployments/{self.model_name}")
        results = llm.prompt(
            prompt, max_new_tokens=self.max_new_tokens, temperature=self.temperature
        )
        return CompletionResponse(text=results.response)

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> "CompletionResponseGen":
        """Streaming completion is not implemented.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

metadata property

metadata: LLMMetadata

Get LLM metadata.