Skip to content

Selection

Output parsers.

SelectionOutputParser #

Bases: BaseOutputParser

Source code in llama-index-core/llama_index/core/output_parsers/selection.py
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
class SelectionOutputParser(BaseOutputParser):
    """Output parser that turns LLM selection output into ``Answer`` objects.

    ``parse`` decodes the text as JSON (falling back to YAML), normalizes the
    result to a list of dictionaries, and builds one ``Answer`` per dictionary.
    ``format`` appends the format instructions to a prompt template.
    """

    # Keys an answer dictionary is expected to carry: the annotated fields of
    # ``Answer``.
    REQUIRED_KEYS = frozenset(Answer.__annotations__)

    def _filter_dict(self, json_dict: dict) -> dict:
        """Descend into nested dictionaries looking for the answer payload.

        Entries whose key is in ``REQUIRED_KEYS`` are skipped. For every other
        entry, a dict value (or each dict item of a list value) is searched
        recursively and the result replaces the current candidate, so the
        last nested dictionary visited wins. If nothing nested is found,
        ``json_dict`` itself is returned.

        NOTE: the returned dictionary is not verified to contain all of
        ``REQUIRED_KEYS``.
        """
        output_dict = json_dict
        for key, val in json_dict.items():
            if key in self.REQUIRED_KEYS:
                continue
            elif isinstance(val, dict):
                output_dict = self._filter_dict(val)
            elif isinstance(val, list):
                for item in val:
                    if isinstance(item, dict):
                        output_dict = self._filter_dict(item)

        return output_dict

    def _format_output(self, output: List[dict]) -> List[dict]:
        """Return one dictionary per entry of ``output``.

        Entries that already contain every key in ``REQUIRED_KEYS`` are kept
        as they are; any other entry is replaced by the result of
        ``_filter_dict``.
        """
        output_json = []
        for json_dict in output:
            has_required_keys = all(key in json_dict for key in self.REQUIRED_KEYS)
            if not has_required_keys:
                json_dict = self._filter_dict(json_dict)

            output_json.append(json_dict)

        return output_json

    def parse(self, output: str) -> Any:
        """Parse raw LLM ``output`` into a ``StructuredOutput`` of answers.

        Args:
            output: The raw text produced by the LLM.

        Returns:
            A ``StructuredOutput`` whose ``raw_output`` is ``output`` and whose
            ``parsed_output`` is the list of ``Answer`` objects.

        Raises:
            ImportError: If the text is not valid JSON and PyYAML is not
                installed for the fallback parse.
            OutputParserException: If the text is neither valid JSON nor
                valid YAML.
            ValueError: If the parsed result is empty.
        """
        json_string = _marshal_llm_to_json(output)
        try:
            json_obj = json.loads(json_string)
        except json.JSONDecodeError as e_json:
            # Import in its own try block: with the import inside the parsing
            # try, a missing PyYAML made the ``except yaml.YAMLError`` clause
            # itself fail on the unbound name, so the install hint below was
            # never raised.
            try:
                import yaml
            except ImportError as exc:
                raise ImportError("Please pip install PyYAML.") from exc

            try:
                # NOTE: parsing again with pyyaml
                #       pyyaml is less strict, and allows for trailing commas
                #       right now we rely on this since guidance program generates
                #       trailing commas
                json_obj = yaml.safe_load(json_string)
            except yaml.YAMLError as e_yaml:
                raise OutputParserException(
                    f"Got invalid JSON object. Error: {e_json} {e_yaml}. "
                    f"Got JSON string: {json_string}"
                ) from e_yaml

        # A single answer object is treated as a one-element list.
        if isinstance(json_obj, dict):
            json_obj = [json_obj]

        if not json_obj:
            raise ValueError(f"Failed to convert output to JSON: {output!r}")

        json_output = self._format_output(json_obj)
        answers = [Answer.from_dict(json_dict) for json_dict in json_output]
        return StructuredOutput(raw_output=output, parsed_output=answers)

    def format(self, prompt_template: str) -> str:
        """Return ``prompt_template`` followed by the brace-escaped ``FORMAT_STR``."""
        return prompt_template + "\n\n" + _escape_curly_braces(FORMAT_STR)