Skip to content

Videodb

VideoDBRetriever

Bases: BaseRetriever

Source code in llama-index-integrations/retrievers/llama-index-retrievers-videodb/llama_index/retrievers/videodb/base.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
class VideoDBRetriever(BaseRetriever):
    """Retriever that searches VideoDB and returns matching shots as nodes.

    The search runs against a whole collection, or against a single video
    of that collection when ``video`` is set. Each shot in the search
    result becomes a ``TextNode`` wrapped in a ``NodeWithScore``.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        collection: Optional[str] = "default",
        video: Optional[str] = None,
        score_threshold: Optional[float] = 0.2,
        result_threshold: Optional[int] = 5,
        search_type: Optional[str] = SearchType.semantic,
        base_url: Optional[str] = None,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        """
        Creates a new VideoDB Retriever.

        Args:
            api_key: VideoDB API key. When ``None``, the value of the
                ``VIDEO_DB_API_KEY`` environment variable is used.
            collection: Identifier passed to ``get_collection``.
            video: Identifier passed to ``get_video``. When falsy, the
                search runs on the collection itself.
            score_threshold: Forwarded unchanged to the ``search`` call.
            result_threshold: Forwarded unchanged to the ``search`` call.
            search_type: Forwarded unchanged to the ``search`` call.
            base_url: Forwarded to ``connect`` only when not ``None``.
            callback_manager: Passed to the ``BaseRetriever`` initializer.

        Raises:
            ValueError: If no API key is passed and none is found in the
                environment, or the key is empty.

        """
        if api_key is None:
            api_key = os.environ.get("VIDEO_DB_API_KEY")
        # An empty key is as unusable as a missing one; fail here rather
        # than at the first retrieval. ValueError is a subclass of
        # Exception, so existing ``except Exception`` handlers still work.
        if not api_key:
            raise ValueError(
                "No API key provided. Set an API key either as an environment variable (VIDEO_DB_API_KEY) or pass it as an argument."
            )
        self._api_key = api_key
        self._base_url = base_url
        self.video = video
        self.collection = collection
        self.score_threshold = score_threshold
        self.result_threshold = result_threshold
        self.search_type = search_type
        super().__init__(callback_manager)

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """
        Search VideoDB for ``query_bundle.query_str`` and wrap the shots.

        A new connection is opened on every call.

        Args:
            query_bundle: Query whose ``query_str`` is used as search text.

        Returns:
            One ``NodeWithScore`` per shot, in the order ``get_shots()``
            yields them. The node text is the shot text, the score is the
            shot's ``search_score``, and the metadata holds the collection
            id, video id, video length, video title, and start/end values.

        """
        connect_kwargs = {"api_key": self._api_key}
        # Only override the base URL when the caller supplied one, so that
        # ``connect`` keeps its own default otherwise.
        if self._base_url is not None:
            connect_kwargs["base_url"] = self._base_url
        conn = connect(**connect_kwargs)

        # The collection is needed in both modes: it is either the search
        # target itself or the container the video is fetched from.
        coll = conn.get_collection(self.collection)
        target = coll.get_video(self.video) if self.video else coll
        search_res = target.search(
            query_bundle.query_str,
            search_type=self.search_type,
            score_threshold=self.score_threshold,
            result_threshold=self.result_threshold,
        )

        collection_id = search_res.collection_id
        return [
            NodeWithScore(
                node=TextNode(
                    text=shot.text,
                    metadata={
                        "collection_id": collection_id,
                        "video_id": shot.video_id,
                        "length": shot.video_length,
                        "title": shot.video_title,
                        "start": shot.start,
                        "end": shot.end,
                    },
                ),
                score=shot.search_score,
            )
            for shot in search_res.get_shots()
        ]