Skip to content

Colbert rerank

ColbertRerank

Bases: BaseNodePostprocessor

Source code in llama-index-integrations/postprocessor/llama-index-postprocessor-colbert-rerank/llama_index/postprocessor/colbert_rerank/base.py
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
class ColbertRerank(BaseNodePostprocessor):
    """Rerank nodes with a token-level (late-interaction) similarity score.

    The query and every candidate document are encoded into per-token
    embeddings by the same model. A document's score is the mean, over query
    tokens, of the maximum cosine similarity between that query token and any
    document token. Nodes are returned sorted by this score, highest first,
    truncated to ``top_n``.
    """

    model: str = Field(description="Colbert model name.")
    top_n: int = Field(description="Number of nodes to return sorted by score.")
    device: str = Field(
        default="cpu",
        description="Device to use for sentence transformer.",
    )
    keep_retrieval_score: bool = Field(
        default=False,
        description="Whether to keep the retrieval score in metadata.",
    )
    # Loaded in __init__; kept private so they are not treated as model fields.
    _model: Any = PrivateAttr()
    _tokenizer: Any = PrivateAttr()

    def __init__(
        self,
        top_n: int = 5,
        model: str = "colbert-ir/colbertv2.0",
        tokenizer: str = "colbert-ir/colbertv2.0",
        device: Optional[str] = None,
        keep_retrieval_score: Optional[bool] = False,
    ):
        """Load the tokenizer and model and place the model on ``device``.

        Args:
            top_n: Number of highest-scoring nodes to return.
            model: Name or path passed to ``AutoModel.from_pretrained``.
            tokenizer: Name or path passed to ``AutoTokenizer.from_pretrained``.
            device: Torch device string. When ``None`` the device is chosen by
                ``infer_torch_device()``.
            keep_retrieval_score: When true, the node's previous score is saved
                in its metadata under ``"retrieval_score"`` before being
                overwritten. ``None`` is treated as ``False``.
        """
        device = infer_torch_device() if device is None else device
        # Initialise the pydantic model first: private attributes can only be
        # assigned reliably once the base initialiser has run.
        super().__init__(
            top_n=top_n,
            model=model,
            # NOTE(review): `tokenizer` is not a declared field; it is
            # forwarded exactly as the original code did - confirm it is needed.
            tokenizer=tokenizer,
            device=device,
            keep_retrieval_score=(
                False if keep_retrieval_score is None else keep_retrieval_score
            ),
        )
        self._tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        # Honour the requested device; inputs are moved to it in _calculate_sim.
        self._model = AutoModel.from_pretrained(model).to(device)

    @classmethod
    def class_name(cls) -> str:
        """Return the name identifying this postprocessor class."""
        return "ColbertRerank"

    def _calculate_sim(self, query: str, documents_text_list: List[str]) -> List[float]:
        """Score each document against the query.

        Args:
            query: The query text.
            documents_text_list: Candidate document texts, scored one by one.

        Returns:
            One float per document, in input order: the mean over query tokens
            of the maximum cosine similarity to any document token.
        """
        # Query: [batch_size, query_length, embedding_size] -> [batch_size, query_length, 1, embedding_size]
        # Document: [batch_size, doc_length, embedding_size] -> [batch_size, 1, doc_length, embedding_size]
        max_length = 512
        rerank_score_list: List[float] = []

        # Pure inference: disable autograd so no graph is retained per forward pass.
        with torch.no_grad():
            # Truncate the query with the same limit as the documents so an
            # over-long query cannot exceed the encoder's maximum input length.
            query_encoding = self._tokenizer(
                query, return_tensors="pt", truncation=True, max_length=max_length
            ).to(self.device)
            query_embedding = self._model(**query_encoding).last_hidden_state

            for document_text in documents_text_list:
                document_encoding = self._tokenizer(
                    document_text,
                    return_tensors="pt",
                    truncation=True,
                    max_length=max_length,
                ).to(self.device)
                document_embedding = self._model(**document_encoding).last_hidden_state

                sim_matrix = torch.nn.functional.cosine_similarity(
                    query_embedding.unsqueeze(2),
                    document_embedding.unsqueeze(1),
                    dim=-1,
                )

                # Take the maximum similarity for each query token (across all document tokens)
                # sim_matrix shape: [batch_size, query_length, doc_length]
                max_sim_scores, _ = torch.max(sim_matrix, dim=2)
                # A single text is encoded per call, so the mean holds exactly
                # one element; convert it to a plain Python float.
                rerank_score_list.append(torch.mean(max_sim_scores, dim=1).item())

        return rerank_score_list

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Rescore ``nodes`` against the query and return the best ``top_n``.

        Args:
            nodes: Candidate nodes; their ``score`` is overwritten in place.
            query_bundle: Carries the query string used for scoring.

        Returns:
            At most ``top_n`` nodes, ordered by descending rerank score.

        Raises:
            ValueError: If ``query_bundle`` is ``None``.
        """
        if query_bundle is None:
            raise ValueError("Missing query bundle in extra info.")
        if not nodes:
            return []

        nodes_text_list = [
            str(node.node.get_content(metadata_mode=MetadataMode.EMBED))
            for node in nodes
        ]

        with self.callback_manager.event(
            CBEventType.RERANKING,
            payload={
                EventPayload.NODES: nodes,
                EventPayload.MODEL_NAME: self.model,
                EventPayload.QUERY_STR: query_bundle.query_str,
                EventPayload.TOP_K: self.top_n,
            },
        ) as event:
            scores = self._calculate_sim(query_bundle.query_str, nodes_text_list)

            assert len(scores) == len(nodes)

            for node, score in zip(nodes, scores):
                if self.keep_retrieval_score:
                    # keep the retrieval score in metadata
                    node.node.metadata["retrieval_score"] = node.score
                node.score = float(score)

            # Highest score first; a missing or zero score sorts as 0.
            reranked_nodes = sorted(
                nodes, key=lambda x: x.score or 0.0, reverse=True
            )[: self.top_n]
            event.on_end(payload={EventPayload.NODES: reranked_nodes})

        return reranked_nodes