#!/usr/bin/env python3

# Copyright (c) 2020-now by the Zeek Project. See LICENSE for details.
#
# Turn the output of spicy-doc into reST.

import argparse
import copy
import filecmp
import json
import os
import os.path
import re
import sys
import textwrap
from re import Pattern


def fatalError(message: str):
    """Report *message* on standard error, then exit the process with status 1."""
    print(message, file=sys.stderr)
    raise SystemExit(1)


def call(name):
    """Build a renderer for operators written like a function call.

    The returned callable takes an operator and renders it as
    ``name(<operand>)``, using the operator's first operand without type
    markup.
    """

    def render(op):
        operand = op.operands[0].rst(in_operator=True, markup=False)
        return f"{name}({operand})"

    return render


def keyword(name):
    """Build a renderer for operators introduced by a keyword.

    The returned callable takes an operator and renders it as
    ``name <sp> <operand>``, using the operator's first operand with type
    markup.
    """

    def render(op):
        operand = op.operands[0].rst(in_operator=True)
        return f"{name} <sp> {operand}"

    return render


def unary(prefix, postfix=""):
    """Build a renderer for unary operators.

    The returned callable takes an operator and renders its first operand
    between an ``op:<prefix>`` and an ``op:<postfix>`` token. Both tokens are
    always emitted, even when the corresponding symbol is empty.
    """

    def render(op):
        operand = op.operands[0].rst(in_operator=True)
        return f"op:{prefix} {operand} op:{postfix}"

    return render


def binary(token):
    """Build a renderer for infix operators using the symbol *token*.

    The returned callable takes an operator and renders its first two
    operands around ``op:<token>``. When the operator is flagged as
    commutative and the two rendered operands differ, a ``$commutative$``
    marker is appended.
    """

    def render(op):
        lhs = op.operands[0].rst(in_operator=True)
        rhs = op.operands[1].rst(in_operator=True)

        # Swapping identical operands would yield the same signature, so the
        # marker is only added when the two sides actually differ.
        marker = " $commutative$" if op.commutative and lhs != rhs else ""

        return f"{lhs} <sp> op:{token} <sp> {rhs}{marker}"

    return render


# Maps an operator kind to a callable that renders the signature of an
# operator of that kind. Each callable receives the operator object and reads
# its `operands`; simple shapes are produced by the `call`, `keyword`, `unary`
# and `binary` factories, everything else by a dedicated lambda.
Operators = {
    "Add": lambda op: f"add <sp> {op.operands[0].rst(in_operator=True)}[{op.operands[1].rst(in_operator=True, markup=False)}]",
    "Begin": call("begin"),
    "BitAnd": binary("&"),
    "BitOr": binary("|"),
    "BitXor": binary("^"),
    "End": call("end"),
    # First operand is the callee, all remaining operands are its arguments.
    "Call": lambda op: "{}({})".format(
        op.operands[0].rst(in_operator=True, markup=False),
        ", ".join(arg.rst(in_operator=True, markup=False) for arg in op.operands[1:]),
    ),
    # The second operand carries the target type as `type<T>`; only `T` is
    # shown inside the angle brackets of the cast.
    "Cast": lambda op: "cast<{}>({})".format(
        TypedType.sub("\\1", op.operands[1].rst(in_operator=True, markup=False)),
        op.operands[0].rst(in_operator=True, markup=False),
    ),
    "CustomAssign": lambda op: f"{op.operands[0].rst(in_operator=True)} = {op.operands[1].rst(in_operator=True)}",
    "Delete": lambda op: f"delete <sp> {op.operands[0].rst(in_operator=True)}[{op.operands[1].rst(in_operator=True, markup=False)}]",
    "Deref": unary("*"),
    "DecrPostfix": unary("", "--"),
    # Prefix decrement is `--`; this entry previously rendered `++`, making it
    # indistinguishable from `IncrPrefix`.
    "DecrPrefix": unary("--"),
    "Difference": binary("-"),
    "DifferenceAssign": binary("-="),
    "Division": binary("/"),
    "DivisionAssign": binary("/="),
    "Equal": binary("=="),
    "Greater": binary(">"),
    "GreaterEqual": binary(">="),
    "In": binary("in"),
    # Operator generated here; named so it is sorted after `In`.
    "InInv": binary("!in"),
    "HasMember": binary("?."),
    "TryMember": binary(".?"),
    "Member": binary("."),
    "Index": lambda op: f"{op.operands[0].rst(in_operator=True)}[{op.operands[1].rst(in_operator=True, markup=False)}]",
    "IndexAssign": lambda op: f"{op.operands[0].rst(in_operator=True)}[{op.operands[1].rst(in_operator=True, markup=False)}] = {op.operands[2].rst(in_operator=True, markup=False)}",
    "IncrPostfix": unary("", "++"),
    "IncrPrefix": unary("++"),
    "LogicalAnd": binary("&&"),
    "LogicalOr": binary("||"),
    "Lower": binary("<"),
    "LowerEqual": binary("<="),
    "Modulo": binary("%"),
    "Multiple": binary("*"),
    "MultipleAssign": binary("*="),
    "Negate": unary("~"),
    "New": keyword("new"),
    "Pack": keyword("pack"),
    "Power": binary("**"),
    "Unpack": keyword("unpack"),
    "Unset": lambda op: f"unset <sp> {op.operands[0].rst(in_operator=True)}.{op.operands[1].rst(in_operator=True, markup=False)}",
    "SignNeg": unary("-"),
    "Size": unary("|", "|"),
    "ShiftLeft": binary("<<"),
    "ShiftRight": binary(">>"),
    "Sum": binary("+"),
    "SumAssign": binary("+="),
    "Unequal": binary("!="),
}

# Namespaces that are documented under a different name. `namespace()` looks
# names up here and returns unlisted names unchanged. Both integer namespaces
# map to the same target, so their entries end up grouped together.
NamespaceMappings: dict[str, str] = {
    "signed_integer": "integer",
    "unsigned_integer": "integer",
    "struct_": "struct",
}

# Type names that `fmtType()` replaces wholesale (exact match on the full
# type string) with the name on the right-hand side.
TypeMappings: dict[str, str] = {
    "::hilti::rt::regexp::MatchState": "spicy::MatchState",
    "::hilti::rt::bytes::Side": "spicy::Side",
}

# Matches `__library_type("<name>")`; `fmtType()` keeps only the captured name.
LibraryType: Pattern = re.compile(r'__library_type\("(.*)"\)')
# Matches `type<...>`; the `Cast` renderer keeps only the captured inner part.
TypedType: Pattern = re.compile(r"type<(.*)>")


def namespace(ns):
    """Return the documentation namespace for *ns*.

    Names listed in ``NamespaceMappings`` are replaced by their mapped value;
    any other name is returned unchanged.
    """
    if ns in NamespaceMappings:
        return NamespaceMappings[ns]

    return ns


def rstHeading(title, level):
    """Return *title* as a reST heading, including the trailing newline.

    The underline character is picked by *level* from the sequence
    ``=``, ``=``, ``-``, ``~`` and repeated to the length of the title.

    Raises:
        IndexError: if *level* is outside the range of that sequence.
    """
    underline = "==-~"[level] * len(title)
    return f"{title}\n{underline}\n"


def fmtDoc(doc):
    """Format a documentation string as an indented reST body.

    *doc* is split into paragraphs at blank lines. Each paragraph is
    dedented, stripped, re-wrapped with ``textwrap.fill`` and indented by four
    spaces; paragraphs that end up empty are dropped. The remaining paragraphs
    are joined by a blank line.

    A missing (``None``) or empty *doc* yields an empty string, so entries
    without a ``doc`` field do not abort the run.
    """
    if not doc:
        return ""

    paragraphs = []
    for block in doc.split("\n\n"):
        text = textwrap.dedent(block).strip()
        wrapped = textwrap.indent(textwrap.fill(text), prefix="    ")
        if wrapped:
            paragraphs.append(wrapped)

    return "\n\n".join(paragraphs)


def fmtType(ty):
    """Turn a type string into the form used in the rendered signatures.

    Steps, in order: unwrap ``__library_type("...")``, apply the exact-match
    replacements from ``TypeMappings``, substitute ``<no-type>`` for a missing
    or empty type, special-case ``any`` as ``<any>``, strip ``<*>`` and
    ``const ``, rename ``hilti::`` to ``spicy::``, drop empty ``{ }`` bodies,
    rewrite ``type<X>`` to ``X-type``, and finally replace spaces with ``~``.
    """
    # Only run the string operations on an actual string: a missing type
    # arrives as None and must fall through to the "<no-type>" placeholder
    # below instead of raising inside the regex substitution.
    if ty is not None:
        ty = LibraryType.sub(r"\1", ty)
        ty = TypeMappings.get(ty, ty)

    if not ty:
        ty = "<no-type>"

    if ty == "any":
        return "<any>"

    ty = ty.replace("<*>", "")
    ty = ty.replace("const ", "")
    ty = ty.replace("hilti::", "spicy::")
    ty = re.sub(r"\s+{\s+}", "", ty)  # e.g., "enum { }" -> "enum"
    # e.g., type<enum> -> "enum-type"
    ty = re.sub(r"type<([^>0-9]+)>", r"\1-type", ty)
    return ty.replace(" ", "~")


class Operand:
    """One operand or method argument, built from a dictionary of fields."""

    def __init__(self, m):
        # Every field is optional in the input; absent ones are stored as None.
        self.id = m.get("id")
        self.type = m.get("type")
        self.doc = m.get("doc")
        self.kind = m.get("kind")
        self.const = m.get("const")
        self.default = m.get("default")
        self.optional = m.get("optional")

    def rst(self, in_operator=False, prefix="", markup=True):
        """Render this operand for use in a signature.

        Outside of operators the result is ``<id>: <type>[ = <default>]``.
        Inside operators only the type is shown, prefixed with ``t:`` when
        *markup* is set. *prefix* is prepended to the rendering, and optional
        operands are wrapped in ``[ ... ]``.
        """
        # A non-empty `doc` field takes the place of the type.
        shown_type = fmtType(self.doc or self.type)

        if in_operator:
            text = f"t:{shown_type}" if markup else f"{shown_type}"
        else:
            default_part = f" = {self.default}" if self.default else ""
            text = f"{self.id}: {shown_type}{default_part}".strip()

        text = f"{prefix}{text}"

        if self.optional:
            return f"[ {text} ]"

        return text


class Operator:
    """One operator entry, built from a dictionary of fields."""

    def __init__(self, m):
        self.doc = m.get("doc")
        self.kind = m.get("kind")
        self.namespace = namespace(m.get("namespace"))
        self.operands = [Operand(i) for i in m.get("operands")]
        self.operator = m.get("operator")
        self.rtype = m.get("rtype")
        self.commutative = m.get("commutative")

    def rst(self):
        """Render this operator as a ``.. spicy:operator::`` directive.

        Terminates the program with an error message if ``Operators`` has no
        renderer for this operator's kind.
        """
        # Look the renderer up separately from calling it, so that only a
        # genuinely unknown kind is reported as unsupported; an error raised
        # while rendering surfaces as itself.
        formatter = Operators.get(self.kind)
        if formatter is None:
            fatalError(
                f"error: operator {self.kind} not supported by spicy-doc-to-rst yet"
            )

        sig = formatter(self)
        result = fmtType(self.rtype)
        return (
            ".. spicy:operator:: "
            f"{self.namespace}::{self.kind} {result} {sig}\n\n{fmtDoc(self.doc)}"
        )

    def __lt__(self, other):
        # Sort by string representation to make sure
        # the rendered output has a fixed order.
        return self.rst() < other.rst()


class Method:
    """One method entry (kind ``MemberCall``), built from a dictionary of fields."""

    def __init__(self, m):
        self.args = [Operand(i) for i in m.get("args")]
        self.doc = m.get("doc")
        self.id = m.get("id")
        self.kind = m.get("kind")
        self.namespace = namespace(m.get("namespace"))
        self.rtype = m.get("rtype")
        # The object the method is called on, described like any other operand.
        self.self = Operand(m.get("self"))

    def rst(self):
        """Render this method as a ``.. spicy:method::`` directive."""

        def render_arg(a):
            # Plain "in" parameters carry no qualifier; any other kind is
            # shown in front of the argument.
            qualifier = "" if a.kind == "in" else a.kind + " "
            return a.rst(prefix=qualifier)

        rendered_args = ", ".join(render_arg(a) for a in self.args)
        # Emitted into the signature as the text "True" or "False".
        is_const = self.self.const == "const"
        self_type = fmtType(self.self.type)
        return_type = fmtType(self.rtype)

        return (
            ".. spicy:method:: "
            f"{self.namespace}::{self.id} {self_type} {self.id} {is_const} {return_type} ({rendered_args})\n\n{fmtDoc(self.doc)}"
        )

    def __lt__(self, other):
        # Compare the rendered text so that sorting gives the output a fixed,
        # reproducible order.
        return self.rst() < other.rst()


# Manages a single output file.
class OutputFile:
    """Destination for the rendered output of one namespace.

    With ``args.dir`` set, output is written to ``<dir>/<name>.rst.tmp`` and
    only moved over ``<dir>/<name>.rst`` by `close()` when the content
    differs, so unchanged files are left untouched. Without ``args.dir``,
    output goes to standard output and `close()` does nothing.
    """

    def __init__(self, ns, args):
        self.fd = sys.stdout
        self.fname = None
        self.fname_tmp = None

        if not args.dir:
            return

        # Derive the file name from the (mapped) namespace: drop one trailing
        # underscore, lower-case, and turn "::" and "_" into "-".
        stem = namespace(ns)
        if stem.endswith("_"):
            stem = stem[:-1]

        stem = stem.lower().replace("::", "-").replace("_", "-")
        self.fname = os.path.join(args.dir, stem + ".rst")
        self.fname_tmp = self.fname + ".tmp"
        self.fd = open(self.fname_tmp, "w+", encoding="utf-8")

    def close(self):
        """Finish the file: install the new content or discard it if unchanged."""
        if not self.fname:
            return

        assert self.fname_tmp
        self.fd.close()

        # Compare actual content (shallow=False). The default shallow mode
        # treats files with equal size and modification time as identical
        # without reading them, which could discard a real change.
        unchanged = os.path.exists(self.fname) and filecmp.cmp(
            self.fname, self.fname_tmp, shallow=False
        )

        if unchanged:
            os.unlink(self.fname_tmp)
        else:
            # os.replace overwrites an existing destination on all platforms.
            os.replace(self.fname_tmp, self.fname)


# Main


# Command-line interface: -d selects an output directory, -t restricts the
# output to the given types. The parsed result is kept in `args`.
parser = argparse.ArgumentParser(
    description="Converts the output of spicy-doc on stdin into reST"
)
parser.add_argument(
    "-d",
    action="store",
    dest="dir",
    metavar="DIR",
    help="create output for all types in given directory",
)
parser.add_argument(
    "-t",
    action="store",
    dest="types",
    metavar="TYPES",
    # The two literals are concatenated; the trailing space keeps the help
    # text from reading "types;without".
    help="create output for specified, comma-separated types; "
    "without -d, output goes to stdout",
)
args = parser.parse_args()

# At least one of the two output selectors has to be given.
if not (args.dir or args.types):
    fatalError("need -t <type> or -d <dir>")

# The operator/method descriptions are read as JSON from standard input.
try:
    meta = json.load(sys.stdin)
except ValueError as err:
    fatalError(f"cannot parse input: {err}")

# Entries grouped by their (mapped) namespace.
operators: dict[str, list[Operator]] = {}
methods: dict[str, list[Method]] = {}

for op in meta:
    # Member calls are documented as methods, everything else as operators.
    if op["kind"] == "MemberCall":
        member = Method(op)
        methods.setdefault(member.namespace, []).append(member)
        continue

    oper = Operator(op)
    operators.setdefault(oper.namespace, []).append(oper)

    # If we have a `in` operator automatically generate docs for `!in`.
    if oper.kind == "In":
        inverse = copy.copy(oper)
        inverse.kind = "InInv"
        inverse.doc = "Performs the inverse of the corresponding ``in`` operation."
        operators.setdefault(inverse.namespace, []).append(inverse)

# Namespaces to produce output for.
keys = set()

if args.dir:
    # Directory mode covers every namespace seen in the input.
    keys = set(operators) | set(methods)

    # An already existing directory is fine; any other failure is reported
    # right away instead of surfacing later when the first file is opened.
    try:
        os.makedirs(args.dir, exist_ok=True)
    except OSError as e:
        fatalError(f"cannot create output directory {args.dir}: {e}")

if args.types:
    # Each requested type selects all namespaces that start with it.
    known = operators.keys() | methods.keys()
    for k in args.types.split(","):
        keys.update(i for i in known if i.startswith(k))

# Render every selected namespace: methods first, then operators.
for ns in sorted(keys):
    prefix = ""

    if "::view" in ns:
        prefix = "View "

    if "::iterator" in ns:
        prefix = "Iterator "

    # In the following we remove duplicate entries where multiple items end up
    # turning into the same reST rendering. This happens when the difference is
    # only in typing issues that we don't track in the documentation. An
    # example is the vector's index operators for constant and non-constant
    # instances, respectively. Other duplications are coming from joining
    # namespaces for integers.
    already_recorded: set[str] = set()

    def print_unique(out, s):
        if s not in already_recorded:
            print(s + "\n", file=out.fd)
            already_recorded.add(s)

    out = OutputFile(ns, args)

    # Sorting compares rendered text, so each list is sorted exactly once.
    x1 = sorted(methods.get(ns, []))
    if x1:
        print(f".. rubric:: {prefix}Methods\n", file=out.fd)

        for method in x1:
            print_unique(out, method.rst())

    x2 = sorted(operators.get(ns, []))

    if x2:
        print(f".. rubric:: {prefix}Operators\n", file=out.fd)

        for operator in x2:
            rendered = operator.rst()
            print_unique(out, rendered)

            # Special case generic operators: write them into individual per-kind files as well.
            # NOTE(review): each operator gets a fresh per-kind file, so if
            # several generic operators share a kind only the last one written
            # remains in that file -- confirm this is intended.
            if operator.namespace == "generic" and out.fd != sys.stdout:
                out2 = OutputFile(f"{ns}_{operator.kind.lower()}", args)
                print(rendered + "\n", file=out2.fd)
                out2.close()

    out.close()
