#!/usr/bin/env python3
#
# SPDX-License-Identifier: MIT

"""
EI protocol parser

This parser is intended to be generically useful for language bindings
other than libei/libeis. If it isn't, please file a bug.

When used as ei-scanner, it renders a Jinja2 template with the
scanned protocol. Otherwise, use the `parse()` function to
parse the protocol and return its structure as Python dataclass
instances.

Opcodes for events and requests are assigned in the order in which they
appear in the XML file, starting at zero separately for each interface's
requests and events.
"""

from typing import Any, Dict, List, Optional, Tuple, Union
from pathlib import Path
from textwrap import dedent
from dataclasses import dataclass, field

import argparse
import jinja2
import jinja2.environment
import os
import sys
import xml.sax
import xml.sax.handler

# Mapping of allowed protocol types to the single-character signature strings
# used in the various code pieces (Argument.signature, Message.signature and
# the generated bindings).
PROTOCOL_TYPES = dict(
    uint32="u",
    int32="i",
    uint64="t",
    int64="x",
    float="f",
    fd="h",
    new_id="n",
    object="o",
    string="s",
)


def snake2camel(s: str) -> str:
    """
    Convert snake_case to CamelCase (strictly speaking, PascalCase).

    Underscores act as word separators. The string is title-cased and all
    separators (and any spaces already present) are dropped, so
    ``ei_pointer`` becomes ``EiPointer``.
    """
    titled = s.replace("_", " ").title()
    return "".join(titled.split(" "))


@dataclass
class Description:
    """
    Contents of a <description> element: the ``summary`` attribute plus the
    element's text body (the parser dedents the text once the element ends).
    """

    summary: str = field(default="")
    text: str = field(default="")


@dataclass
class Argument:
    """
    Argument to a request or an event.

    Construction validates ``protocol_type`` against PROTOCOL_TYPES and
    rejects an ``interface`` on anything but ``new_id``/``object`` arguments
    (both raise ValueError). The cross-reference fields below are filled in
    by the parser once the enclosing message has been read completely.
    """

    name: str
    protocol_type: str
    summary: str
    enum: Optional["Enum"]
    interface: Optional["Interface"]
    interface_arg: Optional["Argument"] = None
    """
    For an argument with "interface_arg", this field points to the argument that
    contains the interface name.
    """
    interface_arg_for: Optional["Argument"] = None
    """
    For an argument referenced by another argument through "interface_arg", this field
    points to the other argument that references this argument.
    """
    version_arg: Optional["Argument"] = None
    """
    For an argument with type "new_id", this field points to the argument that
    contains the version for this new object.
    """
    version_arg_for: Optional["Argument"] = None
    """
    For an argument referenced by another argument of type "new_id", this field
    points to the other argument that references this argument.
    """
    allow_null: bool = False
    """
    For an argument of type string, specify if the argument may be NULL.
    """

    def __post_init__(self):
        # The dataclass-generated __init__ only invokes a hook with exactly
        # this name (a double-underscore name without the trailing
        # underscores would be name-mangled and never called).
        if self.protocol_type is None or self.protocol_type not in PROTOCOL_TYPES:
            raise ValueError(f"Failed to parse protocol_type {self.protocol_type}")
        if self.interface is not None and self.signature not in ["n", "o"]:
            raise ValueError("Interface may only be set for object types")

    @property
    def signature(self) -> str:
        """
        The single-character signature for this argument, e.g. "u" for uint32.
        """
        return PROTOCOL_TYPES[self.protocol_type]

    @property
    def as_c_arg(self) -> str:
        """
        This argument as a C parameter declaration: "<c_type> <name>".
        """
        return f"{self.c_type} {self.name}"

    @property
    def c_type(self) -> str:
        """
        The C type used for this argument's protocol type.
        """
        return {
            "uint32": "uint32_t",
            "int32": "int32_t",
            "uint64": "uint64_t",
            "int64": "int64_t",
            "string": "const char *",
            "fd": "int",
            "float": "float",
            "object": "object_id_t",
            "new_id": "new_id_t",
        }[self.protocol_type]

    @classmethod
    def create(
        cls,
        name: str,
        protocol_type: str,
        summary: str = "",
        enum: Optional["Enum"] = None,
        interface: Optional["Interface"] = None,
        allow_null: bool = False,
    ) -> "Argument":
        """
        Construct an Argument; the cross-reference fields start as None.

        Raises:
            ValueError: if ``protocol_type`` is unknown, or ``interface`` is
                given for a type other than new_id/object.
        """
        return cls(
            name=name,
            protocol_type=protocol_type,
            summary=summary,
            enum=enum,
            interface=interface,
            allow_null=allow_null,
        )


@dataclass
class Message:
    """
    Parent class for a wire message (Request or Event).

    ``opcode`` is assigned by the parser as the message's zero-based index
    among its interface's requests (or events). Arguments are appended in
    XML order via add_argument().
    """

    name: str
    since: int
    opcode: int
    interface: "Interface"
    description: Optional["Description"] = None
    is_destructor: bool = False
    context_type: Optional[str] = None

    arguments: List["Argument"] = field(init=False, default_factory=list)

    def __post_init__(self):
        # Must be spelled __post_init__ for the generated __init__ to call it.
        # Note: values assigned to context_type after construction are not
        # re-checked here.
        if self.context_type not in [None, "sender", "receiver"]:
            raise ValueError(f"Invalid context type {self.context_type}")

    def add_argument(self, arg: "Argument") -> None:
        """
        Append ``arg`` to this message.

        Raises:
            ValueError: if an argument with the same name already exists.
        """
        if arg.name in [a.name for a in self.arguments]:
            raise ValueError(f"Duplicate argument name '{arg.name}'")
        self.arguments.append(arg)

    @property
    def num_arguments(self) -> int:
        """Number of arguments of this message."""
        return len(self.arguments)

    @property
    def signature(self) -> str:
        """Concatenated single-character signatures of all arguments."""
        return "".join([a.signature for a in self.arguments])

    @property
    def camel_name(self) -> str:
        """The message name converted with snake2camel()."""
        return snake2camel(self.name)

    def find_argument(self, name: str) -> Optional["Argument"]:
        """Return the argument called ``name``, or None if there is none."""
        for a in self.arguments:
            if a.name == name:
                return a
        return None


@dataclass
class Request(Message):
    """
    A request message: outgoing for the ``ei`` component and incoming for
    ``eis`` (see Interface.outgoing / Interface.incoming).
    """

    @classmethod
    def create(
        cls,
        name: str,
        opcode: int,
        interface: "Interface",
        since: int = 1,
        is_destructor: bool = False,
    ) -> "Request":
        """
        Construct a Request; description and context_type keep their defaults.
        """
        kwargs = dict(
            name=name,
            since=since,
            opcode=opcode,
            interface=interface,
            is_destructor=is_destructor,
        )
        return cls(**kwargs)

    @property
    def fqdn(self) -> str:
        """
        Fully-qualified name: <interface name>_request_<request name>.
        """
        iface_name = self.interface.name
        return f"{iface_name}_request_{self.name}"


@dataclass
class Event(Message):
    """
    An event message: incoming for the ``ei`` component and outgoing for
    ``eis`` (see Interface.outgoing / Interface.incoming).
    """

    @classmethod
    def create(
        cls,
        name: str,
        opcode: int,
        interface: "Interface",
        since: int = 1,
        is_destructor: bool = False,
    ) -> "Event":
        """
        Construct an Event; description and context_type keep their defaults.
        """
        kwargs = dict(
            name=name,
            since=since,
            opcode=opcode,
            interface=interface,
            is_destructor=is_destructor,
        )
        return cls(**kwargs)

    @property
    def fqdn(self) -> str:
        """
        Fully-qualified name: <interface name>_event_<event name>.
        """
        iface_name = self.interface.name
        return f"{iface_name}_event_{self.name}"


@dataclass
class Entry:
    """
    A single named value of an Enum.
    """

    name: str
    value: int
    enum: "Enum"
    summary: str
    since: int

    @classmethod
    def create(
        cls, name: str, value: int, enum: "Enum", summary: str = "", since: int = 1
    ) -> "Entry":
        """Construct an Entry belonging to ``enum``."""
        # Positional order matches the field declaration order above.
        return cls(name, value, enum, summary, since)

    @property
    def fqdn(self) -> str:
        """
        Fully-qualified name: <interface name>_<enum name>_<entry name>.
        """
        enum_prefix = self.enum.fqdn
        return f"{enum_prefix}_{self.name}"


@dataclass
class Enum:
    """
    An <enum> of an interface, holding its entries in XML order.

    For bitfield enums every entry value must be non-negative and must not
    have more than one bit set.
    """

    name: str
    since: int
    interface: "Interface"
    is_bitfield: bool = False
    description: Optional["Description"] = None

    entries: List["Entry"] = field(init=False, default_factory=list)

    @classmethod
    def create(
        cls,
        name: str,
        interface: "Interface",
        since: int = 1,
        is_bitfield: bool = False,
    ) -> "Enum":
        """Construct an empty Enum belonging to ``interface``."""
        return cls(name=name, since=since, interface=interface, is_bitfield=is_bitfield)

    def add_entry(self, entry: "Entry") -> None:
        """
        Append ``entry`` after validating it.

        Raises:
            ValueError: if its name or value duplicates an existing entry or,
                for bitfield enums, if its value is negative or has more than
                one bit set.
        """
        for e in self.entries:
            if e.name == entry.name:
                raise ValueError(f"Duplicate enum name '{entry.name}'")

            if e.value == entry.value:
                raise ValueError(f"Duplicate enum value '{entry.value}'")

        # Validate the value being added, so that every entry (including the
        # first and the last one of the enum) is checked exactly once.
        if self.is_bitfield:
            if entry.value < 0:
                raise ValueError("Bitmasks must not be less than zero")
            # bin() counts bits on every Python version, unlike
            # int.bit_count() which needs 3.10.
            if bin(entry.value).count("1") > 1:
                raise ValueError("Bitmasks must have exactly one bit set")

        self.entries.append(entry)

    @property
    def fqdn(self):
        """
        The full name of this Enum as <interface name>_<enum name>
        """
        return f"{self.interface.name}_{self.name}"

    @property
    def camel_name(self) -> str:
        """The enum name converted with snake2camel()."""
        return snake2camel(self.name)


@dataclass
class Interface:
    """
    One <interface> of the protocol with its requests, events and enums.

    The same class serves every component. ``mode`` ("ei", "eis" or "brei")
    selects how the name is mangled and which messages are treated as
    incoming or outgoing.
    """

    protocol_name: str  # name as in the XML, e.g. ei_pointer
    version: int
    requests: List["Request"] = field(init=False, default_factory=list)
    events: List["Event"] = field(init=False, default_factory=list)
    enums: List["Enum"] = field(init=False, default_factory=list)

    mode: str
    description: Optional["Description"] = None

    def __post_init__(self):
        # Must be spelled __post_init__ for the dataclass-generated __init__
        # to call it.
        if self.mode not in ["ei", "eis", "brei"]:
            raise ValueError(f"Invalid mode {self.mode}")

    @property
    def name(self) -> str:
        """
        Returns the mode-adjusted name of the interface, i.e. this may return
        "ei_pointer", "eis_pointer", "brei_pointer", etc. depending on the
        mode.
        """
        return Interface.mangle_name(self.protocol_name, self.mode)

    @property
    def plainname(self) -> str:
        """
        Returns the plain name of the interface, i.e. this returns
        "pointer", "handshake", etc. without the "ei_" prefix.
        """
        if self.protocol_name.startswith("ei_"):
            return f"{self.protocol_name[3:]}"
        return self.protocol_name

    @staticmethod
    def mangle_name(name: str, component: str) -> str:
        """
        Returns the mangled interface name with the component as prefix (e.g. eis_device).
        The XML only uses `ei_` as prefix, so a leading "ei" is replaced with
        ``component``; other names are returned unchanged.
        """
        if name.startswith("ei"):
            return f"{component}{name[2:]}"
        return name

    def add_request(self, request: "Request") -> None:
        """Append ``request``; raises ValueError on a duplicate name."""
        if request.name in [r.name for r in self.requests]:
            raise ValueError(f"Duplicate request name '{request.name}'")
        self.requests.append(request)

    def add_event(self, event: "Event") -> None:
        """Append ``event``; raises ValueError on a duplicate name."""
        if event.name in [r.name for r in self.events]:
            raise ValueError(f"Duplicate event name '{event.name}'")
        self.events.append(event)

    def add_enum(self, enum: "Enum") -> None:
        """Append ``enum``; raises ValueError on a duplicate name."""
        if enum.name in [r.name for r in self.enums]:
            raise ValueError(f"Duplicate enum name '{enum.name}'")
        self.enums.append(enum)

    def find_enum(self, name: str) -> Optional["Enum"]:
        """Return the enum called ``name``, or None if there is none."""
        for e in self.enums:
            if e.name == name:
                return e
        return None

    @property
    def outgoing(self) -> List["Message"]:
        """
        Returns the list of messages outgoing from this implementation.

        We use the same class for both ei and eis. To make the
        template simpler, the class maps requests/events to
        incoming/outgoing as correct relative to the implementation.

        Raises:
            NotImplementedError: for the "brei" mode.
        """
        if self.mode == "ei":
            return self.requests  # type: ignore
        elif self.mode == "eis":
            return self.events  # type: ignore
        else:
            raise NotImplementedError(
                f"Interface.outgoing is not supported for mode {self.mode}"
            )

    @property
    def incoming(self) -> List["Message"]:
        """
        Returns the list of messages incoming to this implementation.

        We use the same class for both ei and eis. To make the
        template simpler, the class maps requests/events to
        incoming/outgoing as correct relative to the implementation.

        Raises:
            NotImplementedError: for the "brei" mode.
        """
        if self.mode == "ei":
            return self.events  # type: ignore
        elif self.mode == "eis":
            return self.requests  # type: ignore
        else:
            raise NotImplementedError(
                f"Interface.incoming is not supported for mode {self.mode}"
            )

    @property
    def c_type(self) -> str:
        """C pointer type for this interface, e.g. "struct eis_pointer *"."""
        return f"struct {self.name} *"

    @property
    def as_c_arg(self) -> str:
        """C parameter declaration: "<c_type> <name>"."""
        return f"{self.c_type} {self.name}"

    @property
    def camel_name(self) -> str:
        """The mode-adjusted name converted with snake2camel()."""
        return snake2camel(self.name)

    @classmethod
    def create(cls, protocol_name: str, version: int, mode: str = "ei") -> "Interface":
        """Construct an Interface with no requests, events or enums yet."""
        assert mode in ["ei", "eis", "brei"]
        return cls(protocol_name=protocol_name, version=version, mode=mode)


@dataclass
class XmlError(Exception):
    """
    A semantic error in the protocol XML, reported with its line and column.
    """

    line: int
    column: int
    message: str

    def __str__(self) -> str:
        return f"line {self.line}:{self.column}: {self.message}"

    @classmethod
    def create(cls, message: str, location: Tuple[int, int] = (0, 0)) -> "XmlError":
        """Build an error from a message and a (line, column) pair."""
        line, column = location[0], location[1]
        return cls(line=line, column=column, message=message)


@dataclass
class Copyright:
    """
    Text of the <copyright> element. ``is_complete`` is set by the parser
    once the closing tag has been seen, so no further text is appended.
    """

    text: str = field(default="")
    is_complete: bool = field(init=False, default=False)


@dataclass
class Protocol:
    """
    Result of parse(): the copyright text (if any) and all parsed interfaces.
    """

    copyright: Optional[str] = field(default=None)
    interfaces: List[Interface] = field(default_factory=list)


@dataclass
class ProtocolParser(xml.sax.handler.ContentHandler):
    """
    SAX handler that turns the protocol XML into Interface/Enum/Message objects.

    parse() feeds the same document to the same handler three times, and
    ``_run_counter`` tracks which pass is running:

    1. create all interfaces,
    2. create the enums of every interface,
    3. create requests, events, arguments and enum entries, and collect
       descriptions and the copyright.

    Because every interface and enum exists before pass 3 starts, arguments
    may reference interfaces and enums declared later in the file. Errors in
    the XML are reported as XmlError carrying the current location.
    """

    component: str
    interfaces: List[Interface] = field(default_factory=list)
    copyright: Optional[Copyright] = field(init=False, default=None)

    current_interface: Optional[Interface] = field(init=False, default=None)
    current_message: Optional[Union[Message, Enum]] = field(init=False, default=None)
    current_description: Optional[Description] = field(init=False, default=None)
    # A dict of arg name to interface_arg name mappings
    current_interface_arg_names: Dict[str, str] = field(
        init=False, default_factory=dict
    )
    current_new_id_arg: Optional[Argument] = field(init=False, default=None)

    _run_counter: int = field(init=False, default=0, repr=False)

    @property
    def location(self) -> Tuple[int, int]:
        """The (line, column) the SAX parser is currently at."""
        line = self._locator.getLineNumber()  # type: ignore
        col = self._locator.getColumnNumber()  # type: ignore
        return line, col

    def interface_by_name(self, protocol_name: str) -> Interface:
        """
        Look up an interface by its protocol name (i.e. always "ei_foo", regardless of
        what we're generating).

        Raises:
            XmlError: if no interface with that name has been parsed.
        """
        matches = [
            iface for iface in self.interfaces if iface.protocol_name == protocol_name
        ]
        if not matches:
            raise XmlError.create(
                f"Unable to find interface {protocol_name}", self.location
            )
        # If the same name was declared more than once, the last one wins.
        return matches[-1]

    # ----- small helpers -------------------------------------------------

    def _required_attrs(self, element: str, attrs, *names: str) -> Tuple[str, ...]:
        """Return the values of the mandatory attributes ``names`` of ``element``."""
        try:
            return tuple(attrs[n] for n in names)
        except KeyError as e:
            raise XmlError.create(
                f"Missing attribute {e} in element '{element}'",
                self.location,
            ) from None

    def _int_value(self, element: str, attribute: str, value: Union[str, int]) -> int:
        """Convert an attribute value to int, reporting bad values as XmlError."""
        try:
            return int(value)
        except ValueError:
            raise XmlError.create(
                f"Invalid integer '{value}' for attribute '{attribute}' in element '{element}'",
                self.location,
            ) from None

    def _require_interface(self, element: str) -> Interface:
        """Return the current interface, or fail if ``element`` is outside one."""
        if self.current_interface is None:
            raise XmlError.create(
                f"Invalid element '{element}' outside an <interface>",
                self.location,
            )
        return self.current_interface

    def _add(self, adder, item) -> None:
        """Call ``adder(item)``, reporting its ValueError (e.g. duplicates) as XmlError."""
        try:
            adder(item)
        except ValueError as e:
            raise XmlError.create(str(e), self.location) from e

    def _find_enum(self, intf: Interface, enum_name: str) -> Enum:
        """Resolve an ``enum`` attribute: either ``name`` or ``ei_iface.name``."""
        if "." in enum_name:
            iname, enum_name = enum_name.split(".", 1)
            intf = self.interface_by_name(iname)
        enum = intf.find_enum(enum_name)
        if enum is None:
            raise XmlError.create(
                f"Failed to find enum '{intf.name}.{enum_name}'",
                self.location,
            )
        return enum

    # ----- SAX callbacks -------------------------------------------------

    def startDocument(self):
        """Called by SAX at the start of every pass; advances the pass counter."""
        self._run_counter += 1

    def startElement(self, element: str, attrs: dict):
        if element == "interface":
            self._start_interface(element, attrs)

        # first run only parses interfaces
        if self._run_counter <= 1:
            return

        if element == "enum":
            self._start_enum(element, attrs)

        # second run only parses enums
        if self._run_counter <= 2:
            return

        if element in ("request", "event"):
            self._start_message(element, attrs)
        elif element == "arg":
            self._start_arg(element, attrs)
        elif element == "entry":
            self._start_entry(element, attrs)
        elif element == "description":
            summary = attrs.get("summary", "")
            self.current_description = Description(summary=summary)
        elif element == "copyright":
            if self.copyright is not None:
                raise XmlError.create(
                    "Multiple <copyright> tags in file", self.location
                )
            self.copyright = Copyright()

    def _start_interface(self, element: str, attrs) -> None:
        """Handle <interface>: create it (pass 1) or look it up (later passes)."""
        if self.current_interface is not None:
            raise XmlError.create(
                f"Invalid element '{element}' inside interface '{self.current_interface.name}'",
                self.location,
            )
        protocol_name, version = self._required_attrs(
            element, attrs, "name", "version"
        )
        # We only create the interface on the first run, in subsequent runs we
        # re-use them so we can cross reference correctly
        if self._run_counter > 1:
            intf = self.interface_by_name(protocol_name)
        else:
            intf = Interface.create(
                protocol_name=protocol_name,
                version=self._int_value(element, "version", version),
                mode=self.component,
            )
            self.interfaces.append(intf)

        self.current_interface = intf

    def _start_enum(self, element: str, attrs) -> None:
        """Handle <enum>: create it (pass 2) or look it up (pass 3)."""
        intf = self._require_interface(element)
        if self.current_message is not None:
            raise XmlError.create(
                f"Invalid element '{element}' inside '{self.current_message.name}'",
                self.location,
            )
        name, since = self._required_attrs(element, attrs, "name", "since")

        bitfield = attrs.get("bitfield", "false")
        if bitfield not in ("true", "false"):
            raise XmlError.create(
                f"Invalid value '{bitfield}' for boolean bitfield attribute in '{element}'",
                self.location,
            )

        # We only create the enum on the second run, in subsequent runs
        # we re-use them so we can cross-reference correctly
        if self._run_counter > 2:
            enum = intf.find_enum(name)
            if enum is None:
                raise XmlError.create(
                    f"Invalid enum {name}. This is a parser bug",
                    self.location,
                )
        else:
            enum = Enum.create(
                name=name,
                since=self._int_value(element, "since", since),
                interface=intf,
                is_bitfield=bitfield == "true",
            )
            self._add(intf.add_enum, enum)
        self.current_message = enum

    def _start_message(self, element: str, attrs) -> None:
        """
        Handle <request> and <event>: create the message, with its opcode being
        its index among the interface's requests (or events), and make it current.
        """
        intf = self._require_interface(element)
        if self.current_message is not None:
            raise XmlError.create(
                f"Invalid element '{element}' inside '{self.current_message.name}'",
                self.location,
            )
        name, since = self._required_attrs(element, attrs, "name", "since")
        since_version = self._int_value(element, "since", since)
        is_destructor = attrs.get("type", "") == "destructor"

        # context_type is assigned after construction, so validate it here.
        context_type = attrs.get("context-type")
        if context_type not in (None, "sender", "receiver"):
            raise XmlError.create(
                f"Invalid context type {context_type}", self.location
            )

        if element == "request":
            request = Request.create(
                name=name,
                since=since_version,
                opcode=len(intf.requests),
                interface=intf,
                is_destructor=is_destructor,
            )
            request.context_type = context_type
            self._add(intf.add_request, request)
            self.current_message = request
        else:
            event = Event.create(
                name=name,
                since=since_version,
                opcode=len(intf.events),
                interface=intf,
                is_destructor=is_destructor,
            )
            event.context_type = context_type
            self._add(intf.add_event, event)
            self.current_message = event

    def _start_arg(self, element: str, attrs) -> None:
        """Handle <arg>: create the argument and append it to the current message."""
        intf = self._require_interface(element)
        message = self.current_message
        if not isinstance(message, Message):
            raise XmlError.create(
                f"Invalid element '{element}' must be inside <request> or <event>",
                self.location,
            )
        name, proto_type = self._required_attrs(element, attrs, "name", "type")
        if proto_type not in PROTOCOL_TYPES:
            raise XmlError.create(
                f"Invalid type '{proto_type}' for '{intf.name}.{message.name}::{name}'",
                self.location,
            )

        summary = attrs.get("summary", "")
        interface_name = attrs.get("interface", None)
        interface = (
            self.interface_by_name(interface_name)
            if interface_name is not None
            else None
        )

        # interface_arg is set to the name of some other arg that specifies the actual
        # interface name for this argument; it is resolved once the message ends.
        interface_arg_name = attrs.get("interface_arg", None)
        if interface_arg_name is not None:
            self.current_interface_arg_names[name] = interface_arg_name

        enum_name = attrs.get("enum", None)
        enum = self._find_enum(intf, enum_name) if enum_name is not None else None

        allow_null = attrs.get("allow-null", "false") == "true"
        try:
            arg = Argument.create(
                name=name,
                protocol_type=proto_type,
                summary=summary,
                enum=enum,
                interface=interface,
                allow_null=allow_null,
            )
        except ValueError as e:
            raise XmlError.create(str(e), self.location) from e
        self._add(message.add_argument, arg)

        if proto_type == "new_id":
            if self.current_new_id_arg is not None:
                raise XmlError.create(
                    f"Multiple args of type '{proto_type}' for '{intf.name}.{message.name}'",
                    self.location,
                )
            self.current_new_id_arg = arg

    def _start_entry(self, element: str, attrs) -> None:
        """Handle <entry>: create the entry and append it to the current enum."""
        self._require_interface(element)
        enum = self.current_message
        if not isinstance(enum, Enum):
            raise XmlError.create(
                f"Invalid element '{element}' must be inside <enum>",
                self.location,
            )
        name, value = self._required_attrs(element, attrs, "name", "value")
        entry = Entry.create(
            name=name,
            value=self._int_value(element, "value", value),
            enum=enum,
            summary=attrs.get("summary", ""),
            since=self._int_value(element, "since", attrs.get("since", 1)),
        )
        self._add(enum.add_entry, entry)

    def endElement(self, name):
        if name == "interface":
            assert self.current_interface is not None
            self.current_interface = None

        # first run only parses interfaces
        if self._run_counter <= 1:
            return

        if name == "enum":
            assert isinstance(self.current_message, Enum)
            self.current_message = None

        # second run only parses interfaces and enums
        if self._run_counter <= 2:
            return

        if name in ["request", "event"]:
            self._link_message_arguments()

        if name == "request":
            assert isinstance(self.current_message, Request)
            self.current_message = None
        elif name == "event":
            assert isinstance(self.current_message, Event)
            self.current_message = None
        elif name == "description":
            assert self.current_description is not None
            self.current_description.text = dedent(self.current_description.text)
            if self.current_message is None:
                assert self.current_interface is not None
                self.current_interface.description = self.current_description
            else:
                self.current_message.description = self.current_description
            self.current_description = None
        elif name == "copyright":
            assert self.copyright is not None
            self.copyright.text = dedent(self.copyright.text)
            self.copyright.is_complete = True

    def _link_message_arguments(self) -> None:
        """
        Populate ``interface_arg``/``interface_arg_for`` and
        ``version_arg``/``version_arg_for`` on the message that just ended,
        now that all of its arguments are known.
        """
        message = self.current_message
        intf = self.current_interface
        assert isinstance(message, Message)
        assert isinstance(intf, Interface)

        # obj is the argument of type object that the interface applies to
        # iname is the argument of type "interface_name" that specifies the interface
        for obj, iname in self.current_interface_arg_names.items():
            obj_arg = message.find_argument(obj)
            iname_arg = message.find_argument(iname)

            # every key was recorded for an argument that was then added
            assert obj_arg is not None
            if iname_arg is None:
                raise XmlError.create(
                    f"Unable to find interface_arg '{iname}' for {intf.plainname}.{message.name}::{obj}",
                    self.location,
                )

            obj_arg.interface_arg = iname_arg
            iname_arg.interface_arg_for = obj_arg
        self.current_interface_arg_names = {}

        arg = self.current_new_id_arg
        if arg is not None:
            version_arg = message.find_argument("version")
            if version_arg is None:
                # Sigh, protocol bug: ei_connection.sync one doesn't have a version arg
                if f"{intf.plainname}.{message.name}" != "connection.sync":
                    raise XmlError.create(
                        f"Unable to find a version argument for {intf.plainname}.{message.name}::{arg.name}",
                        self.location,
                    )
            else:
                arg.version_arg = version_arg
                version_arg.version_arg_for = arg
            self.current_new_id_arg = None

    def characters(self, content):
        """Accumulate text for the open <description> or the open <copyright>."""
        if self.current_description is not None:
            self.current_description.text += content
        elif self.copyright is not None and not self.copyright.is_complete:
            self.copyright.text += content

    @classmethod
    def create(cls, component: str) -> "ProtocolParser":
        """Construct a parser generating for ``component``."""
        h = cls(component=component)
        return h


def parse(protofile: Path, component: str) -> Protocol:
    """
    Parse the protocol XML at ``protofile`` for ``component`` ("ei", "eis"
    or "brei") and return the resulting Protocol.

    The same handler is run over the document three times: once to collect
    the interfaces, once for the enums, then for all remaining details, so
    that references to interfaces/enums declared later in the file resolve.

    Raises:
        XmlError: for semantic errors in the protocol description.
        xml.sax.SAXParseException: for malformed XML.
    """
    handler = ProtocolParser.create(component=component)
    source = os.fspath(protofile)
    for _pass in range(3):
        xml.sax.parse(source, handler)
    return Protocol(
        copyright=handler.copyright.text if handler.copyright else None,
        interfaces=handler.interfaces,
    )


def generate_source(
    proto: Protocol, template: str, component: str, extra_data: Optional[dict]
) -> jinja2.environment.TemplateStream:
    """
    Render ``template`` with the parsed protocol and return the output stream.

    The template sees the variables ``component``, ``interfaces`` and
    ``extra`` (``extra_data`` as given, possibly None), plus the filters
    ``c_type``, ``as_c_arg``, ``camel`` and ``ei_escape_names``.

    Args:
        proto: the parsed protocol.
        template: path to the Jinja2 template, or "-" to read it from stdin.
        component: one of "ei", "eis", "brei".
        extra_data: arbitrary data passed through to the template as ``extra``.
    """
    import re

    assert component in ["ei", "eis", "brei"]

    data: dict[str, Any] = {
        "component": component,
        "interfaces": proto.interfaces,
        "extra": extra_data,
    }

    loader: jinja2.BaseLoader
    if template == "-":
        loader = jinja2.FunctionLoader(lambda _: sys.stdin.read())
        filename = "<stdin>"
    else:
        path = Path(template)
        assert path.exists(), f"Failed to find template {path}"
        filename = path.name
        loader = jinja2.FileSystemLoader(os.fspath(path.parent))

    env = jinja2.Environment(
        loader=loader,
        trim_blocks=True,
        lstrip_blocks=True,
    )

    # jinja filter to convert foo into "struct foo *"
    def filter_c_type(name):
        return f"struct {name} *"

    # jinja filter to convert foo into "struct foo *foo"
    def filter_as_c_arg(name):
        return f"struct {name} *{name}"

    # Names to escape: <component>_foo or <component>-foo, optionally followed
    # by .member parts. The leading \b stops matches in the middle of a word
    # (e.g. "ei_x" inside "brei_x"), and each .member needs at least one word
    # character so a sentence-ending period stays outside the quotes.
    name_pattern = re.compile(rf"\b({re.escape(component)}[_-]\w*)((?:\.\w+)*)")

    # jinja filter: escape any <component>_foo.bar with markdown backticks
    # (or with the given quote string)
    def filter_ei_escape_names(text, quotes="`"):
        if not text:
            return text
        # A function replacement inserts `quotes` literally, even if it
        # contains backslashes.
        return name_pattern.sub(lambda m: f"{quotes}{m.group(0)}{quotes}", text)

    env.filters["c_type"] = filter_c_type
    env.filters["as_c_arg"] = filter_as_c_arg
    env.filters["camel"] = snake2camel
    env.filters["ei_escape_names"] = filter_ei_escape_names
    jtemplate = env.get_template(filename)
    return jtemplate.stream(data)


def scanner(argv: list[str]) -> None:
    """
    Command-line entry point of ei-scanner.

    Parses ``argv`` (without the program name), parses the protocol XML,
    loads the optional extra data and writes the rendered template to
    ``--output`` (stdout by default).

    Exits with status 1 on protocol errors or an unknown extra-data file
    format, and via argparse's error handling if the protocol file does not
    exist.
    """
    import json

    parser = argparse.ArgumentParser(
        description=dedent(
            """
    ei-scanner is a tool to parse the EI protocol description XML and
    pass the data to a Jinja2 template. That template can then be
    used to generate protocol bindings for the desired language.

    typical usages:
         ei-scanner --component=ei protocol.xml my-template.tpl
         ei-scanner --component=eis --output=bindings.rs protocol.xml bindings.rs.tpl

    Elements in the XML file are provided as variables with attributes
    generally matching the XML file. For example, each interface has requests,
    events and enums, and each of those has a name.

    ei-scanner additionally provides the following values to the Jinja2 templates:
        - interface.incoming and interface.outgoing: maps to the requests/events of
          the interface, depending on the component.
        - argument.signature: a single-character signature type mapping
          from the protocol XML type:
            uint32 -> "u"
            int32 -> "i"
            uint64 -> "t"
            int64 -> "x"
            float -> "f"
            fd -> "h"
            new_id -> "n"
            object -> "o"
            string -> "s"

    ei-scanner adds the following Jinja2 filters for convenience:
        {{foo|c_type}} ... resolves to "struct foo *"
        {{foo|as_c_arg}} ... resolves to "struct foo *foo"
        {{foo_bar|camel}} ... resolves to "FooBar"
        {{text|ei_escape_names}} ... wraps <component>_foo.bar names in backticks

    """
        ),
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    parser.add_argument(
        "--component", type=str, choices=["ei", "eis", "brei"], default="ei"
    )
    parser.add_argument(
        "--output", type=str, default="-", help="Output file to write to"
    )
    parser.add_argument("protocol", type=Path, help="The protocol XML file")
    parser.add_argument(
        "--jinja-extra-data",
        type=str,
        help="Extra data (in JSON format) to pass through to the Jinja template as 'extra'",
        default=None,
    )
    parser.add_argument(
        "--jinja-extra-data-file",
        type=Path,
        help="Path to file with extra data to pass through to the Jinja template as 'extra'",
        default=None,
    )
    parser.add_argument(
        "template", type=str, help="The Jinja2 compatible template file"
    )

    ns = parser.parse_args(argv)
    # Report a missing file as a usage error rather than an AssertionError
    # (asserts are stripped under `python -O`).
    if not ns.protocol.exists():
        parser.error(f"protocol file {ns.protocol} does not exist")

    try:
        proto = parse(
            protofile=ns.protocol,
            component=ns.component,
        )
    except xml.sax.SAXParseException as e:
        print(f"Parser error: {e}", file=sys.stderr)
        raise SystemExit(1)
    except XmlError as e:
        print(f"Protocol XML error: {e}", file=sys.stderr)
        raise SystemExit(1)

    extra_data = None
    if ns.jinja_extra_data is not None:
        extra_data = json.loads(ns.jinja_extra_data)
    elif ns.jinja_extra_data_file is not None:
        datafile = ns.jinja_extra_data_file
        if datafile.name.endswith((".yml", ".yaml")):
            import yaml  # optional dependency, only needed for YAML input

            with open(datafile) as fd:
                extra_data = yaml.safe_load(fd)
        elif datafile.name.endswith(".json"):
            with open(datafile) as fd:
                extra_data = json.load(fd)
        else:
            print("Unknown file format for jinja data", file=sys.stderr)
            raise SystemExit(1)

    stream = generate_source(
        proto=proto, template=ns.template, component=ns.component, extra_data=extra_data
    )

    if ns.output == "-":
        stream.dump(sys.stdout)
    else:
        # Close (and thereby flush) the output file deterministically.
        with open(ns.output, "w") as outfile:
            stream.dump(outfile)


if __name__ == "__main__":
    # Skip the program name; scanner() expects only the arguments.
    cli_args = sys.argv[1:]
    scanner(cli_args)
