Adapter

LinearAdapterEmbeddingModel module-attribute

LinearAdapterEmbeddingModel = AdapterEmbeddingModel
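
`LinearAdapterEmbeddingModel` is simply an alias for `AdapterEmbeddingModel`, so both names refer to the same class. A quick check, assuming both names are exported from the package as the module attribute above suggests:

```python
from llama_index.embeddings.adapter import (
    AdapterEmbeddingModel,
    LinearAdapterEmbeddingModel,
)

# The alias binds to the very same class object.
assert LinearAdapterEmbeddingModel is AdapterEmbeddingModel
```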

AdapterEmbeddingModel

Bases: BaseEmbedding

Adapter for any embedding model.

This is a wrapper that adds an adapter layer on top of any embedding model, which is useful for finetuning an embedding model on a downstream task. The base embedding model can be any model; it does not need to expose gradients.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `base_embed_model` | `BaseEmbedding` | Base embedding model. | *required* |
| `adapter_path` | `str` | Path to adapter. | *required* |
| `adapter_cls` | `Optional[Type[Any]]` | Adapter class. Defaults to `None`, in which case a linear adapter is used. | `None` |
| `transform_query` | `bool` | Whether to transform query embeddings. Defaults to `True`. | `True` |
| `device` | `Optional[str]` | Device to use. Defaults to `None`. | `None` |
| `embed_batch_size` | `int` | Batch size for embedding. Defaults to 10. | `DEFAULT_EMBED_BATCH_SIZE` |
| `callback_manager` | `Optional[CallbackManager]` | Callback manager. Defaults to `None`. | `None` |
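
For example, a minimal usage sketch; the model name and adapter path below are illustrative placeholders, assuming an adapter previously saved by a finetuning run:

```python
from llama_index.embeddings.adapter import AdapterEmbeddingModel
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

# Any BaseEmbedding implementation can serve as the base model;
# it does not need to expose gradients.
base_embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")

embed_model = AdapterEmbeddingModel(
    base_embed_model=base_embed_model,
    adapter_path="model_output_test",  # placeholder path to a saved adapter
)

# Query embeddings are passed through the adapter (transform_query=True);
# text/document embeddings come straight from the base model.
query_embedding = embed_model.get_query_embedding("What is a linear adapter?")
```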
Source code in llama-index-integrations/embeddings/llama-index-embeddings-adapter/llama_index/embeddings/adapter/base.py
# Module-level imports required by this excerpt (reconstructed; the original
# listing begins partway into base.py).
import logging
from typing import Any, List, Optional, Type, cast

from llama_index.core.base.embeddings.base import (
    DEFAULT_EMBED_BATCH_SIZE,
    BaseEmbedding,
)
from llama_index.core.bridge.pydantic import PrivateAttr
from llama_index.core.callbacks import CallbackManager
from llama_index.core.utils import infer_torch_device

logger = logging.getLogger(__name__)


class AdapterEmbeddingModel(BaseEmbedding):
    """Adapter for any embedding model.

    This is a wrapper around any embedding model that adds an adapter layer \
        on top of it.
    This is useful for finetuning an embedding model on a downstream task.
    The embedding model can be any model - it does not need to expose gradients.

    Args:
        base_embed_model (BaseEmbedding): Base embedding model.
        adapter_path (str): Path to adapter.
        adapter_cls (Optional[Type[Any]]): Adapter class. Defaults to None, in which \
            case a linear adapter is used.
        transform_query (bool): Whether to transform query embeddings. Defaults to True.
        device (Optional[str]): Device to use. Defaults to None.
        embed_batch_size (int): Batch size for embedding. Defaults to 10.
        callback_manager (Optional[CallbackManager]): Callback manager. \
            Defaults to None.

    """

    _base_embed_model: BaseEmbedding = PrivateAttr()
    _adapter: Any = PrivateAttr()
    _transform_query: bool = PrivateAttr()
    _device: Optional[str] = PrivateAttr()
    _target_device: Any = PrivateAttr()

    def __init__(
        self,
        base_embed_model: BaseEmbedding,
        adapter_path: str,
        adapter_cls: Optional[Type[Any]] = None,
        transform_query: bool = True,
        device: Optional[str] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        """Init params."""
        import torch
        from llama_index.embeddings.adapter.utils import BaseAdapter, LinearLayer

        if device is None:
            device = infer_torch_device()
            logger.info(f"Use pytorch device: {device}")
        self._target_device = torch.device(device)

        self._base_embed_model = base_embed_model

        if adapter_cls is None:
            adapter_cls = LinearLayer
        else:
            adapter_cls = cast(Type[BaseAdapter], adapter_cls)

        adapter = adapter_cls.load(adapter_path)
        self._adapter = cast(BaseAdapter, adapter)
        self._adapter.to(self._target_device)

        self._transform_query = transform_query
        super().__init__(
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            model_name=f"Adapter for {base_embed_model.model_name}",
        )

    @classmethod
    def class_name(cls) -> str:
        return "AdapterEmbeddingModel"

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        import torch

        query_embedding = self._base_embed_model._get_query_embedding(query)
        if self._transform_query:
            query_embedding_t = torch.tensor(query_embedding).to(self._target_device)
            query_embedding_t = self._adapter.forward(query_embedding_t)
            query_embedding = query_embedding_t.tolist()

        return query_embedding

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        import torch

        query_embedding = await self._base_embed_model._aget_query_embedding(query)
        if self._transform_query:
            query_embedding_t = torch.tensor(query_embedding).to(self._target_device)
            query_embedding_t = self._adapter.forward(query_embedding_t)
            query_embedding = query_embedding_t.tolist()

        return query_embedding

    def _get_text_embedding(self, text: str) -> List[float]:
        # Text (document) embeddings are not adapted; only query embeddings
        # go through the adapter (when transform_query is True).
        return self._base_embed_model._get_text_embedding(text)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        return await self._base_embed_model._aget_text_embedding(text)
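
The listing above shows everything `adapter_cls` must provide: a `load()` classmethod that returns an instance, a torch.nn.Module-style `.to(device)`, and a `forward()` over an embedding tensor. A hypothetical custom adapter, sketched under the assumption that `BaseAdapter` follows `torch.nn.Module` conventions:

```python
import torch

from llama_index.embeddings.adapter.utils import BaseAdapter


class ResidualAdapter(BaseAdapter):
    """Hypothetical adapter: a linear layer with a residual connection."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim
        self.linear = torch.nn.Linear(dim, dim)

    def forward(self, embed: torch.Tensor) -> torch.Tensor:
        # The residual connection keeps the adapted query embedding close
        # to the base model's output.
        return embed + self.linear(embed)

    def save(self, output_path: str) -> None:
        # Save config alongside weights so load() can rebuild the module.
        torch.save({"dim": self.dim, "state_dict": self.state_dict()}, output_path)

    @classmethod
    def load(cls, input_path: str) -> "ResidualAdapter":
        ckpt = torch.load(input_path)
        adapter = cls(ckpt["dim"])
        adapter.load_state_dict(ckpt["state_dict"])
        return adapter
```

It could then be plugged in via `AdapterEmbeddingModel(base_embed_model=..., adapter_path="residual_adapter.pt", adapter_cls=ResidualAdapter)`: the wrapper calls `ResidualAdapter.load(adapter_path)`, moves the adapter to the target device, and applies it to query embeddings only.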