Skip to content

vllm.entrypoints.openai.tool_parsers.ernie45_tool_parser

logger module-attribute

logger = init_logger(__name__)

Ernie45ToolParser

Bases: ToolParser

Source code in vllm/entrypoints/openai/tool_parsers/ernie45_tool_parser.py
@ToolParserManager.register_module("ernie45")
class Ernie45ToolParser(ToolParser):
    """Tool-call parser for ERNIE 4.5 models, registered as ``"ernie45"``.

    The model wraps each tool call in ``<tool_call>{json}</tool_call>``,
    where the JSON object holds ``name`` and ``arguments`` keys. Text before
    the first ``<tool_call>`` is surfaced as ordinary message content.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        """Set up parser state and look up the marker token ids.

        Expected ERNIE thinking-model output (``\\n`` is a newline)::

            abc\\n</think>\\n\\n\\n<tool_call>\\ndef\\n</tool_call>\\n

        Args:
            tokenizer: tokenizer of the served model; its vocabulary is used
                to resolve the ids of the marker tokens below.

        Raises:
            ValueError: if no model tokenizer is available after the base
                class initialisation.
        """
        super().__init__(tokenizer)
        self.current_tool_name_sent = False
        self.prev_tool_call_arr: list[dict] = []
        # -1 means "no tool call completed yet"; otherwise the index of the
        # next tool call to be emitted.
        self.current_tool_id = -1
        self.streamed_args_for_tool: list[str] = []
        self.think_end_token = "</think>"
        self.response_start_token: str = "<response>"
        self.response_end_token: str = "</response>"
        self.tool_call_start_token = "<tool_call>"
        self.tool_call_end_token = "</tool_call>"
        self.tool_calls_start_token = self.tool_call_start_token
        self.newline_token: str = "<0x0A>"

        self.tool_call_regex = re.compile(
            r"<tool_call>\s*(?P<json>\{.*?\})\s*</tool_call>", re.DOTALL
        )

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )

        # vocab.get(...) yields None for markers missing from the vocabulary
        # (assuming vocab is a dict -- TODO confirm against ToolParser).
        self.think_end_token_id = self.vocab.get(self.think_end_token)
        self.response_start_token_id = self.vocab.get(self.response_start_token)
        self.response_end_token_id = self.vocab.get(self.response_end_token)
        self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
        self.newline_token_id = self.vocab.get(self.newline_token)
        # Tokens after which a leading newline in the next delta is dropped.
        self.parser_token_ids = [
            self.think_end_token_id,
            self.response_start_token_id,
            self.response_end_token_id,
        ]

        # Streaming buffer: holds a partially received <tool_call> block, or
        # text left over after a completed one.
        self._buffer = ""

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract every complete tool call from a finished model output.

        Args:
            model_output: the full generated text.
            request: the originating request (not used by this parser).

        Returns:
            ``tools_called=True`` with the parsed calls when at least one
            ``<tool_call>{...}</tool_call>`` block is found; ``content`` is
            then the text before the first ``<tool_call>`` with trailing
            newlines stripped (``None`` if empty). Otherwise, including on a
            parse error, ``tools_called=False`` and the whole output is
            returned unchanged as content.
        """
        # Cheap pre-check before running the regex.
        if self.tool_calls_start_token not in model_output:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        try:
            tool_calls = []
            for tool_call_json in self.tool_call_regex.findall(model_output):
                tool_call_dict = json.loads(tool_call_json)
                args_str = json.dumps(
                    tool_call_dict.get("arguments", {}), ensure_ascii=False
                )
                tool_calls.append(
                    ToolCall(
                        type="function",
                        function=FunctionCall(
                            name=tool_call_dict.get("name", ""),
                            arguments=args_str,
                        ),
                    )
                )
        except Exception:
            logger.exception("Error in extracting tool call from response.")
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        # A "<tool_call>" marker without any complete block (e.g. truncated
        # generation) is not a tool call: keep the text instead of reporting
        # tools_called=True with nothing and discarding everything after it.
        if not tool_calls:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        content = model_output[
            : model_output.find(self.tool_calls_start_token)
        ].rstrip("\n")
        return ExtractedToolCallInformation(
            tools_called=True,
            tool_calls=tool_calls,
            content=content if content else None,
        )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Incrementally extract content and tool calls while streaming.

        Deltas are accumulated in ``self._buffer``. While the buffer holds no
        ``<tool_call>``, the delta is forwarded as content with the
        ``<response>``/``</response>`` markers removed; text arriving after a
        completed tool call is dropped. Once a complete
        ``<tool_call>...</tool_call>`` block is buffered it is parsed and
        emitted as a single ``DeltaToolCall`` carrying the full arguments.

        Args:
            previous_text: unused.
            current_text: unused.
            delta_text: newly generated text.
            previous_token_ids: token ids generated before this delta.
            current_token_ids: unused.
            delta_token_ids: token ids of this delta.
            request: forwarded to :meth:`extract_tool_calls`.

        Returns:
            A ``DeltaMessage`` with content and/or one tool call, or ``None``
            when there is nothing to emit.
        """
        self._buffer += delta_text
        cur_text = self._buffer
        start_idx = cur_text.find(self.tool_call_start_token)
        if start_idx == -1:
            self._buffer = ""
            # At least one tool call has been completed: drop further text.
            if self.current_tool_id > 0:
                cur_text = ""
            # Strip leading newlines while only newlines have been generated.
            if self.current_tool_id == -1 and all(
                token_id == self.newline_token_id for token_id in previous_token_ids
            ):
                cur_text = cur_text.strip("\n")

            # Handle <response> </response> when no tool call is triggered.
            # cur_text normally equals delta_text here.
            content = cur_text
            if self.response_start_token_id in delta_token_ids:
                content = content.lstrip("\n")
                # The marker text may be absent from the decoded text; slicing
                # with a -1 index would then cut off real content.
                response_start_idx = content.find(self.response_start_token)
                if response_start_idx != -1:
                    content = content[
                        response_start_idx + len(self.response_start_token) :
                    ]
                # if have </response>, remove it
                response_end_idx = content.rfind(self.response_end_token)
                if response_end_idx != -1:
                    content = content[:response_end_idx]
            elif self.response_end_token_id in delta_token_ids:
                response_end_idx = content.rfind(self.response_end_token)
                if response_end_idx != -1:
                    content = content[:response_end_idx]
            # remove \n after </think> or <response> or </response>
            if (
                len(previous_token_ids) > 0
                and previous_token_ids[-1] in self.parser_token_ids
            ) and (
                len(delta_token_ids) > 0 and delta_token_ids[0] == self.newline_token_id
            ):
                content = content.lstrip("\n")

            return DeltaMessage(content=content if content else None)

        logger.debug("cur_text = %s", cur_text)
        # Look for the closing tag only after the opening one, so a stray
        # "</tool_call>" earlier in the buffer cannot be paired with it.
        end_idx = cur_text.find(self.tool_call_end_token, start_idx)
        if end_idx == -1:
            # Tool call still incomplete: keep it buffered and flush only the
            # text that precedes it.
            self._buffer = cur_text[start_idx:]
            content = cur_text[:start_idx].rstrip("\n")
            return DeltaMessage(content=content if content else None)

        block_end = end_idx + len(self.tool_call_end_token)
        extracted_tool_calls = self.extract_tool_calls(cur_text[:block_end], request)
        if not extracted_tool_calls.tool_calls:
            logger.warning("Failed to extract any tool calls.")
            # Discard the malformed block: keeping it would make every later
            # call re-parse it and fail, swallowing the rest of the stream.
            # Tool-call state is left untouched because nothing was called.
            self._buffer = cur_text[block_end:]
            content = cur_text[:start_idx].rstrip("\n")
            return DeltaMessage(content=content) if content else None

        if self.current_tool_id == -1:
            self.current_tool_id = 0
            self.prev_tool_call_arr = []
            self.streamed_args_for_tool = []
        while len(self.prev_tool_call_arr) <= self.current_tool_id:
            self.prev_tool_call_arr.append({})
        while len(self.streamed_args_for_tool) <= self.current_tool_id:
            self.streamed_args_for_tool.append("")

        tool_call = extracted_tool_calls.tool_calls[0]
        self.prev_tool_call_arr[self.current_tool_id] = {
            "name": tool_call.function.name,
            "arguments": json.loads(tool_call.function.arguments),
        }
        self.streamed_args_for_tool[self.current_tool_id] = (
            tool_call.function.arguments
        )
        delta = DeltaMessage(
            content=extracted_tool_calls.content,
            tool_calls=[
                DeltaToolCall(
                    index=self.current_tool_id,
                    id=tool_call.id,
                    type=tool_call.type,
                    function=DeltaFunctionCall(
                        name=tool_call.function.name,
                        arguments=tool_call.function.arguments,
                    ),
                )
            ],
        )
        self.current_tool_id += 1
        # Keep whatever followed the closing tag for the next call.
        self._buffer = cur_text[block_end:]
        return delta

_buffer instance-attribute

_buffer = ''

current_tool_id instance-attribute

current_tool_id = -1

current_tool_name_sent instance-attribute

current_tool_name_sent = False

newline_token instance-attribute

newline_token: str = '<0x0A>'

newline_token_id instance-attribute

newline_token_id = get(newline_token)

parser_token_ids instance-attribute

parser_token_ids = [
    think_end_token_id,
    response_start_token_id,
    response_end_token_id,
]

prev_tool_call_arr instance-attribute

prev_tool_call_arr: list[dict] = []

response_end_token instance-attribute

response_end_token: str = '</response>'

response_end_token_id instance-attribute

response_end_token_id = get(response_end_token)

response_start_token instance-attribute

response_start_token: str = '<response>'

response_start_token_id instance-attribute

response_start_token_id = get(response_start_token)

streamed_args_for_tool instance-attribute

streamed_args_for_tool: list[str] = []

think_end_token instance-attribute

think_end_token = '</think>'

think_end_token_id instance-attribute

think_end_token_id = get(think_end_token)

tool_call_end_token instance-attribute

tool_call_end_token = '</tool_call>'

tool_call_end_token_id instance-attribute

tool_call_end_token_id = get(tool_call_end_token)

tool_call_regex instance-attribute

tool_call_regex = compile(
    "<tool_call>\\s*(?P<json>\\{.*?\\})\\s*</tool_call>",
    DOTALL,
)

tool_call_start_token instance-attribute

tool_call_start_token = '<tool_call>'

tool_call_start_token_id instance-attribute

tool_call_start_token_id = get(tool_call_start_token)

tool_calls_start_token instance-attribute

tool_calls_start_token = tool_call_start_token

__init__

__init__(tokenizer: AnyTokenizer)
    Ernie thinking model format:
    abc\n</think>\n\n\n<tool_call>\ndef\n</tool_call>\n

Source code in vllm/entrypoints/openai/tool_parsers/ernie45_tool_parser.py
def __init__(self, tokenizer: AnyTokenizer):
    """Set up parser state and look up the marker token ids.

    Expected ERNIE thinking-model output (``\\n`` is a newline)::

        abc\\n</think>\\n\\n\\n<tool_call>\\ndef\\n</tool_call>\\n

    Args:
        tokenizer: tokenizer of the served model; its vocabulary is used to
            resolve the ids of the marker tokens below.

    Raises:
        ValueError: if no model tokenizer is available after the base class
            initialisation.
    """
    super().__init__(tokenizer)
    self.current_tool_name_sent = False
    self.prev_tool_call_arr: list[dict] = []
    # -1 means "no tool call completed yet"; otherwise the index of the next
    # tool call to be emitted.
    self.current_tool_id = -1
    self.streamed_args_for_tool: list[str] = []
    self.think_end_token = "</think>"
    self.response_start_token: str = "<response>"
    self.response_end_token: str = "</response>"
    self.tool_call_start_token = "<tool_call>"
    self.tool_call_end_token = "</tool_call>"
    self.tool_calls_start_token = self.tool_call_start_token
    self.newline_token: str = "<0x0A>"

    self.tool_call_regex = re.compile(
        r"<tool_call>\s*(?P<json>\{.*?\})\s*</tool_call>", re.DOTALL
    )

    if not self.model_tokenizer:
        raise ValueError(
            "The model tokenizer must be passed to the ToolParser "
            "constructor during construction."
        )

    # vocab.get(...) yields None for markers missing from the vocabulary
    # (assuming vocab is a dict -- TODO confirm against ToolParser).
    self.think_end_token_id = self.vocab.get(self.think_end_token)
    self.response_start_token_id = self.vocab.get(self.response_start_token)
    self.response_end_token_id = self.vocab.get(self.response_end_token)
    self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
    self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
    self.newline_token_id = self.vocab.get(self.newline_token)
    # Tokens after which a leading newline in the next delta is dropped.
    self.parser_token_ids = [
        self.think_end_token_id,
        self.response_start_token_id,
        self.response_end_token_id,
    ]

    # Streaming buffer: holds a partially received <tool_call> block, or text
    # left over after a completed one.
    self._buffer = ""

extract_tool_calls

extract_tool_calls(
    model_output: str, request: ChatCompletionRequest
) -> ExtractedToolCallInformation
Source code in vllm/entrypoints/openai/tool_parsers/ernie45_tool_parser.py
def extract_tool_calls(
    self,
    model_output: str,
    request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
    """Extract every complete tool call from a finished model output.

    Args:
        model_output: the full generated text.
        request: the originating request (not used by this parser).

    Returns:
        ``tools_called=True`` with the parsed calls when at least one
        ``<tool_call>{...}</tool_call>`` block is found; ``content`` is then
        the text before the first ``<tool_call>`` with trailing newlines
        stripped (``None`` if empty). Otherwise, including on a parse error,
        ``tools_called=False`` and the whole output is returned unchanged as
        content.
    """
    # Cheap pre-check before running the regex.
    if self.tool_calls_start_token not in model_output:
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )

    try:
        tool_calls = []
        for tool_call_json in self.tool_call_regex.findall(model_output):
            tool_call_dict = json.loads(tool_call_json)
            args_str = json.dumps(
                tool_call_dict.get("arguments", {}), ensure_ascii=False
            )
            tool_calls.append(
                ToolCall(
                    type="function",
                    function=FunctionCall(
                        name=tool_call_dict.get("name", ""),
                        arguments=args_str,
                    ),
                )
            )
    except Exception:
        logger.exception("Error in extracting tool call from response.")
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )

    # A "<tool_call>" marker without any complete block (e.g. truncated
    # generation) is not a tool call: keep the text instead of reporting
    # tools_called=True with nothing and discarding everything after it.
    if not tool_calls:
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )

    content = model_output[
        : model_output.find(self.tool_calls_start_token)
    ].rstrip("\n")
    return ExtractedToolCallInformation(
        tools_called=True,
        tool_calls=tool_calls,
        content=content if content else None,
    )

extract_tool_calls_streaming

extract_tool_calls_streaming(
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> DeltaMessage | None
Source code in vllm/entrypoints/openai/tool_parsers/ernie45_tool_parser.py
def extract_tool_calls_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> DeltaMessage | None:
    """Incrementally extract content and tool calls while streaming.

    Deltas are accumulated in ``self._buffer``. While the buffer holds no
    ``<tool_call>``, the delta is forwarded as content with the
    ``<response>``/``</response>`` markers removed; text arriving after a
    completed tool call is dropped. Once a complete
    ``<tool_call>...</tool_call>`` block is buffered it is parsed and emitted
    as a single ``DeltaToolCall`` carrying the full arguments.

    Args:
        previous_text: unused.
        current_text: unused.
        delta_text: newly generated text.
        previous_token_ids: token ids generated before this delta.
        current_token_ids: unused.
        delta_token_ids: token ids of this delta.
        request: forwarded to ``extract_tool_calls``.

    Returns:
        A ``DeltaMessage`` with content and/or one tool call, or ``None``
        when there is nothing to emit.
    """
    self._buffer += delta_text
    cur_text = self._buffer
    start_idx = cur_text.find(self.tool_call_start_token)
    if start_idx == -1:
        self._buffer = ""
        # At least one tool call has been completed: drop further text.
        if self.current_tool_id > 0:
            cur_text = ""
        # Strip leading newlines while only newlines have been generated.
        if self.current_tool_id == -1 and all(
            token_id == self.newline_token_id for token_id in previous_token_ids
        ):
            cur_text = cur_text.strip("\n")

        # Handle <response> </response> when no tool call is triggered.
        # cur_text normally equals delta_text here.
        content = cur_text
        if self.response_start_token_id in delta_token_ids:
            content = content.lstrip("\n")
            # The marker text may be absent from the decoded text; slicing
            # with a -1 index would then cut off real content.
            response_start_idx = content.find(self.response_start_token)
            if response_start_idx != -1:
                content = content[
                    response_start_idx + len(self.response_start_token) :
                ]
            # if have </response>, remove it
            response_end_idx = content.rfind(self.response_end_token)
            if response_end_idx != -1:
                content = content[:response_end_idx]
        elif self.response_end_token_id in delta_token_ids:
            response_end_idx = content.rfind(self.response_end_token)
            if response_end_idx != -1:
                content = content[:response_end_idx]
        # remove \n after </think> or <response> or </response>
        if (
            len(previous_token_ids) > 0
            and previous_token_ids[-1] in self.parser_token_ids
        ) and (
            len(delta_token_ids) > 0 and delta_token_ids[0] == self.newline_token_id
        ):
            content = content.lstrip("\n")

        return DeltaMessage(content=content if content else None)

    logger.debug("cur_text = %s", cur_text)
    # Look for the closing tag only after the opening one, so a stray
    # "</tool_call>" earlier in the buffer cannot be paired with it.
    end_idx = cur_text.find(self.tool_call_end_token, start_idx)
    if end_idx == -1:
        # Tool call still incomplete: keep it buffered and flush only the
        # text that precedes it.
        self._buffer = cur_text[start_idx:]
        content = cur_text[:start_idx].rstrip("\n")
        return DeltaMessage(content=content if content else None)

    block_end = end_idx + len(self.tool_call_end_token)
    extracted_tool_calls = self.extract_tool_calls(cur_text[:block_end], request)
    if not extracted_tool_calls.tool_calls:
        logger.warning("Failed to extract any tool calls.")
        # Discard the malformed block: keeping it would make every later call
        # re-parse it and fail, swallowing the rest of the stream. Tool-call
        # state is left untouched because nothing was called.
        self._buffer = cur_text[block_end:]
        content = cur_text[:start_idx].rstrip("\n")
        return DeltaMessage(content=content) if content else None

    if self.current_tool_id == -1:
        self.current_tool_id = 0
        self.prev_tool_call_arr = []
        self.streamed_args_for_tool = []
    while len(self.prev_tool_call_arr) <= self.current_tool_id:
        self.prev_tool_call_arr.append({})
    while len(self.streamed_args_for_tool) <= self.current_tool_id:
        self.streamed_args_for_tool.append("")

    tool_call = extracted_tool_calls.tool_calls[0]
    self.prev_tool_call_arr[self.current_tool_id] = {
        "name": tool_call.function.name,
        "arguments": json.loads(tool_call.function.arguments),
    }
    self.streamed_args_for_tool[self.current_tool_id] = (
        tool_call.function.arguments
    )
    delta = DeltaMessage(
        content=extracted_tool_calls.content,
        tool_calls=[
            DeltaToolCall(
                index=self.current_tool_id,
                id=tool_call.id,
                type=tool_call.type,
                function=DeltaFunctionCall(
                    name=tool_call.function.name,
                    arguments=tool_call.function.arguments,
                ),
            )
        ],
    )
    self.current_tool_id += 1
    # Keep whatever followed the closing tag for the next call.
    self._buffer = cur_text[block_end:]
    return delta