# SPDX-License-Identifier: Apache-2.0

"""
https://gist.github.com/wataash/dc9f8352dc147b223382def86c3a592b

A Python implementation of "Packrat Parsers Can Support Left Recursion" by
Alessandro Warth, James R. Douglass, and Todd Millstein.
http://www.vpri.org/pdf/tr2007002_packrat.pdf

3.2 Supporting Direct Left Recursion
"""

import collections
import dataclasses
import inspect
import re
import typing as t


def n_stack() -> int:
    """
    Return the number of EVAL / GROW_LR frames on the call stack, minus one.

    Used only for debug output: the indentation of trace lines and the colour
    of AST brackets. Returns -1 when called outside any EVAL / GROW_LR.
    """
    # Walk the frames directly instead of calling inspect.stack(): that builds
    # a FrameInfo for every frame and, by default, reads a line of source
    # context for each one, which is costly since this runs for every trace line.
    depth = 0
    frame = inspect.currentframe()
    while frame is not None:
        if frame.f_code.co_name in ('EVAL', 'GROW_LR'):
            depth += 1
        frame = frame.f_back
    return depth - 1


"""
[R.body]:
In the APPLY-RULE() function in the paper, the EVAL() function is invoked as:

    EVAL(R.body)

where R is RULE. But I have no idea what the R.body (RULE's body) is --
instead, we define the global variable "BODY", and invoke EVAL() function as:

    EVAL(R)
"""
# Body: the input text being parsed (this type is not defined in the paper).
# The text itself is the module-global BODY, e.g. BODY = Body('1234-5').
Body = t.NewType('Body', str)

# POS: an index into BODY (POS in the paper). The current position is the
# module-global Pos, e.g. Pos = POS(0); the parser functions update it.
POS = t.NewType('POS', int)


@dataclasses.dataclass(frozen=True)
class AST:
    """
    AST node produced by the rules (this data structure is not defined in the paper).

    A frozen dataclass. Its ``repr`` renders a compact S-expression such as
    ``<term:1-2 <...> <...>>``, with ANSI-coloured angle brackets whose colour
    cycles with ``n_stack`` so nesting levels stand out in the trace.
    """
    type: str  # node label, e.g. 'term', 'fact', '<num>', '"+"', or 'FAIL'
    string: str  # the slice of input text this node covers
    leaf_len: int  # chars a leaf rule consumed (EVAL advances Pos by it); 0 for composite nodes
    children: list['AST']  # sub-nodes, in input order
    n_stack: int  # n_stack() at creation time; only used to pick the bracket colour

    def __repr__(self):
        reset = '\x1b[0m'
        color = ('\x1b[31m', '\x1b[32m', '\x1b[33m')[self.n_stack % 3]
        label = f'{self.type}:{self.string}'
        if self.children:
            label = label + ' ' + ' '.join(map(str, self.children))
        return f'{color}<{reset}{label}{color}>{reset}'


# FAIL (as in the paper): the single sentinel AST returned by any rule that
# does not match. Failure is detected by comparing against this object.
FAIL = AST(type='FAIL', string='', leaf_len=0, children=[], n_stack=0)


@dataclasses.dataclass
class LR:
    """
    LR: the paper's left-recursion marker.

    APPLY_RULE stores one in a fresh memo entry before evaluating a rule; if
    the rule re-enters itself at the same position it finds this marker and
    sets ``detected``.
    """
    detected: bool  # True once a left-recursive re-entry has been seen


@dataclasses.dataclass
class MemoEntry:
    """
    MEMOENTRY: the memo-table slot for one (rule, position) pair.

    APPLY_RULE first stores an LR marker in ``ans`` and later overwrites it
    with the rule's result; GROW_LR updates ``ans`` and ``pos`` in place while
    a left-recursive parse grows.
    """
    ans: t.Union[AST, LR]  # the rule's result (an AST or FAIL), or an LR marker while still evaluating
    pos: POS  # where Pos continues after ``ans`` (the start position while ``ans`` is an LR)


# The packrat memo table: (rule, start position) -> MemoEntry.
# An OrderedDict keeps insertion order, which makes debug printing easier to follow.
memo_: collections.OrderedDict[tuple['RULE', POS], MemoEntry]
memo_ = collections.OrderedDict()


# noinspection PyPep8Naming
def MEMO(R: 'RULE', P: POS) -> t.Optional[MemoEntry]:
    """
    MEMO :(RULE, POS) -> MEMOENTRY

    Look up the memo entry for rule ``R`` at input position ``P``.
    Returns None (the paper's NIL) if there is no entry yet.
    """
    # Read-only access to the module-level table: no `global` statement needed.
    # (Previously the `global` line preceded the docstring, so the function
    # had no docstring at all.)
    return memo_.get((R, P))


# noinspection PyPep8Naming
def MEMO_set(R: 'RULE', P: POS, m: MemoEntry):
    """
    MEMO(R, P) <- m

    Store memo entry ``m`` for rule ``R`` at position ``P``. Each (R, P) key
    is stored at most once; afterwards the stored MemoEntry object itself is
    updated in place (by APPLY_RULE and GROW_LR).
    """
    # Item assignment does not rebind memo_, so no `global` statement is needed.
    # Internal invariant: APPLY_RULE only creates an entry after MEMO() missed.
    assert (R, P) not in memo_, f'MEMO({R.__name__}, {P}) is already set'
    memo_[R, P] = m


# Last line printed by print_memo_if_changed() for each (rule, position) key,
# used to suppress repeated identical debug output.
memo_str_last: dict[tuple['RULE', POS], str]
memo_str_last = dict()


def print_memo_if_changed(R: 'RULE', P: POS, indent: str):
    """
    Debug helper: print the memo entry for (R, P), prefixed by ``indent``,
    unless the exact same line was already printed last time for that key.
    """
    key = (R, P)
    line = f'{indent}  MEMO({R.__name__}, {P}) <- {MEMO(R, P)}'
    if memo_str_last.get(key) != line:
        print(line)
        memo_str_last[key] = line


# def EVAL(body: Body) -> AST:  # [R.body]
# noinspection PyPep8Naming
def EVAL(R: 'RULE') -> AST:
    """
    EVAL

    Evaluate rule R at the current global ``Pos`` and print a trace (the rule
    itself is passed instead of the paper's ``R.body``).

    Leaf rules leave ``Pos`` untouched and report how much they consumed in
    ``ast.leaf_len``; composite rules move ``Pos`` through their own
    APPLY_RULE calls and report ``leaf_len == 0``. Either way, on success
    ``Pos`` ends just after the match. On FAIL, ``Pos`` must already be back
    at its starting value.
    """
    global Pos
    start = Pos
    pad = '  ' * n_stack()
    name = R.__name__
    print(f'\x1b[34m{pad}{name}: evaluate\x1b[0m')
    print(f'{pad}  {Pos} {BODY[Pos:]}')
    result = R()

    if result is not FAIL:
        end = Pos + result.leaf_len
        print(f'\x1b[34m{pad}{name}: matched: {BODY[start:end]}\x1b[0m')
        if end != Pos:
            # A leaf rule with leaf_len > 0: consume its characters now.
            print(f'{pad}  Pos:{Pos}->{end} BODY:{BODY[Pos:end]}>{BODY[end:]}')
        print(f'{pad}  ast:{result}')
        Pos = end
        return result

    assert Pos == start  # a failing rule must leave Pos where it started
    print(f'\x1b[34m{pad}{name}: failed\x1b[0m')
    print(f'{pad}  {Pos} {BODY[Pos:]}')
    return FAIL


# noinspection PyPep8Naming
def APPLY_RULE(R: 'RULE', P: POS) -> AST:
    """
    Figure 4.
    APPLY-RULE(R, P)

    Apply rule R at input position P with memoization and support for direct
    left recursion, printing a debug trace along the way.

    Callers in this file always pass P equal to the current global ``Pos``
    (EVAL evaluates R starting from ``Pos``, not from P).

    Returns the rule's AST, or FAIL. On return the global ``Pos`` is at the
    end of the returned result (``m.pos`` on a memo hit).
    """
    global Pos
    indent = '  ' * n_stack()
    m = MEMO(R, P)
    if m is None:
        # First application of R at P: seed the memo with an LR marker, so a
        # left-recursive call back into R at P (memo-hit branch below) fails
        # instead of recursing forever, and records that recursion happened.
        lr = LR(False)
        m = MemoEntry(lr, P)
        MEMO_set(R, P, m)
        print_memo_if_changed(R, P, indent)
        # ans = EVAL(R.body)  # [R.body]
        ans = EVAL(R)  # [R.body]
        # Replace the LR marker with the real result and where it ended.
        m.ans = ans
        m.pos = Pos
        print_memo_if_changed(R, P, indent)
        if lr.detected and ans is not FAIL:
            # R is left-recursive at P and produced a seed parse: try to grow it.
            print(f'\x1b[34m{indent}  GROW-LR\x1b[0m')
            return GROW_LR(R, P, m, None)
        else:
            return ans
    else:
        # Memo hit: reuse the stored result and jump Pos to where it ended.
        print(f'\x1b[34m{indent}{R.__name__}: using memo\x1b[0m')
        print(f'{indent}  Pos:{Pos}->{m.pos} BODY:{BODY[Pos:m.pos]}>{BODY[m.pos:]}')
        print(f'{indent}  m:{m}')
        Pos = m.pos
        if isinstance(m.ans, LR):
            # R is still being evaluated at P, i.e. it re-entered itself without
            # consuming input (left recursion): flag it and fail this branch.
            m.ans.detected = True
            print_memo_if_changed(R, P, indent)
            return FAIL
        return m.ans


# noinspection PyPep8Naming
def APPLY_RULE_or_rollback_Pos(R: 'RULE', P: POS, pos_rollback_to: POS) -> AST:
    """
    APPLY_RULE(R, P); if it fails, also reset the global Pos to pos_rollback_to.

    Not part of the paper. Sequence alternatives use it so that a failed
    alternative leaves Pos where that alternative began.
    """
    global Pos
    result = APPLY_RULE(R, P)
    if result is not FAIL:
        return result
    Pos = pos_rollback_to
    return result


# noinspection PyPep8Naming
def GROW_LR(R: 'RULE', P: POS, M: MemoEntry, H: 'unknown') -> AST:
    """
    Figure 3.
    GROW-LR

    Grow the parse of a directly left-recursive rule R at position P.

    M is the memo entry for (R, P); APPLY_RULE calls this after M already
    holds a seed parse. Each round restarts at P and re-evaluates R. The
    left-recursive reference to R inside it hits the memo and receives the
    current ``M.ans``, so a successful round can extend that parse. Rounds
    stop when R fails or does not get past ``M.pos``. The best parse found is
    returned, and the global ``Pos`` is left at its end.

    H is not used here (APPLY_RULE passes None).
    """
    global Pos
    # The "# ...  # line A/B/C" markers mirror elided lines of the paper's
    # figure; presumably they belong to the paper's later extensions (e.g.
    # indirect left recursion), which this version does not implement.
    # ...  # line A
    while True:
        # Every round starts again from the rule's start position.
        Pos = P
        # ...  # line B
        # ans = EVAL(R.body)  # [R.body]
        ans = EVAL(R)  # [R.body]
        if ans is FAIL or Pos <= M.pos:
            # Failed, or no longer than before: keep the previous best in M.
            break
        # Longer parse: memoize it so the next round's recursive call builds on it.
        M.ans = ans
        M.pos = Pos
    # ...  # line C
    Pos = M.pos
    return M.ans


# -----------------------------------------------------------------------------
# RULE

# The type is defined in the paper. Here a rule takes no arguments (it reads
# the globals BODY / Pos) and returns an AST, or FAIL when it does not match.
RULE = t.Callable[[], AST]

# Declarations for the rules that are referenced before their definitions
# below. Python resolves these names at call time, so the declarations are
# informational only. (The former `rule_expr` declaration named a rule that
# does not exist in this grammar.)
rule_fact: RULE
rule_num: RULE
rule_plus: RULE
rule_minus: RULE
rule_mul: RULE
rule_div: RULE


def rule_term() -> AST:
    """
    term ::= <term> "+" <fact>
           / <term> "-" <fact>
           / <fact>

    Ordered choice: the alternatives are tried in order from the current
    ``Pos`` and the first success is returned. Every sub-rule is applied via
    APPLY_RULE_or_rollback_Pos, so the left-recursive ``<term>`` reference goes
    through the memo / GROW-LR machinery, and a failed alternative leaves
    ``Pos`` where it started.

    The returned 'term' AST has ``leaf_len == 0``: its children have already
    advanced ``Pos``, so EVAL must not advance it again.
    """

    def sequence(*rules: RULE) -> AST:
        """Apply ``rules`` in order; FAIL (with Pos rolled back) if any of them fails."""
        pos_orig = Pos
        children = []
        for rule in rules:
            child = APPLY_RULE_or_rollback_Pos(rule, Pos, pos_orig)
            # `is`, not `==`: FAIL is a sentinel object, tested by identity as in EVAL.
            if child is FAIL:
                return FAIL
            children.append(child)
        return AST('term', BODY[pos_orig:Pos], 0, children, n_stack())

    alternatives = (
        (rule_term, rule_plus, rule_fact),   # <term> "+" <fact>
        (rule_term, rule_minus, rule_fact),  # <term> "-" <fact>
        (rule_fact,),                        # <fact>
    )
    for alternative in alternatives:
        ast = sequence(*alternative)
        if ast is not FAIL:
            return ast
    return FAIL


def rule_fact() -> AST:
    """
    fact ::= <fact> "*" <num>
           / <fact> "/" <num>
           / <num>

    Ordered choice: the alternatives are tried in order from the current
    ``Pos`` and the first success is returned. Every sub-rule is applied via
    APPLY_RULE_or_rollback_Pos, so the left-recursive ``<fact>`` reference goes
    through the memo / GROW-LR machinery, and a failed alternative leaves
    ``Pos`` where it started.

    The returned 'fact' AST has ``leaf_len == 0``: its children have already
    advanced ``Pos``, so EVAL must not advance it again.
    """

    def sequence(*rules: RULE) -> AST:
        """Apply ``rules`` in order; FAIL (with Pos rolled back) if any of them fails."""
        pos_orig = Pos
        children = []
        for rule in rules:
            child = APPLY_RULE_or_rollback_Pos(rule, Pos, pos_orig)
            # `is`, not `==`: FAIL is a sentinel object, tested by identity as in EVAL.
            if child is FAIL:
                return FAIL
            children.append(child)
        return AST('fact', BODY[pos_orig:Pos], 0, children, n_stack())

    alternatives = (
        (rule_fact, rule_mul, rule_num),  # <fact> "*" <num>
        (rule_fact, rule_div, rule_num),  # <fact> "/" <num>
        (rule_num,),                      # <num>
    )
    for alternative in alternatives:
        ast = sequence(*alternative)
        if ast is not FAIL:
            return ast
    return FAIL


# Compiled once. rule_num matches it with an explicit start position instead of
# slicing BODY. (\d also matches non-ASCII Unicode digits, as before.)
_NUM_RE = re.compile(r'\d+')


def rule_num() -> AST:
    """
    <num>

    Leaf rule: match one or more digits at the current global ``Pos``.
    It does not move ``Pos`` itself; the returned ``leaf_len`` tells EVAL how
    far to advance. Returns FAIL when no digit is at ``Pos`` (including at the
    end of input).
    """
    # Pattern.match(string, pos) is anchored at pos, so this neither copies
    # BODY[Pos:] on every call nor needs a '^' anchor.
    match = _NUM_RE.match(BODY, Pos)
    if match is None:
        return FAIL
    digits = match[0]
    return AST('<num>', digits, len(digits), [], n_stack())


def rule_plus() -> AST:
    """
    "+"

    Leaf rule: match a single '+' at the current global Pos. Pos itself is
    not moved; EVAL advances it by the returned leaf_len.
    """
    if Pos < len(BODY) and BODY[Pos] == '+':
        return AST('"+"', '+', len('+'), [], n_stack())
    return FAIL


def rule_minus() -> AST:
    """
    "-"

    Leaf rule: match a single '-' at the current global Pos. Pos itself is
    not moved; EVAL advances it by the returned leaf_len.
    """
    if Pos < len(BODY) and BODY[Pos] == '-':
        return AST('"-"', '-', len('-'), [], n_stack())
    return FAIL


def rule_mul() -> AST:
    """
    "*"

    Leaf rule: match a single '*' at the current global Pos. Pos itself is
    not moved; EVAL advances it by the returned leaf_len.
    """
    if Pos < len(BODY) and BODY[Pos] == '*':
        return AST('"*"', '*', len('*'), [], n_stack())
    return FAIL


def rule_div() -> AST:
    """
    "/"

    Leaf rule: match a single '/' at the current global Pos. Pos itself is
    not moved; EVAL advances it by the returned leaf_len.
    """
    if Pos < len(BODY) and BODY[Pos] == '/':
        return AST('"/"', '/', len('/'), [], n_stack())
    return FAIL


# -----------------------------------------------------------------------------
# main

if __name__ == '__main__':
    import sys

    # Input text: the first command-line argument, or a built-in example.
    BODY = Body(sys.argv[1] if len(sys.argv) > 1 else '1-2/3')
    Pos = POS(0)  # Pos
    pos_orig = Pos
    ast = APPLY_RULE_or_rollback_Pos(rule_term, Pos, pos_orig)
    print()
    print(ast)
    # A PEG rule can succeed on a prefix of the input only; make any
    # unconsumed remainder visible instead of silently ignoring it.
    if ast is not FAIL and Pos < len(BODY):
        print(f'unparsed input at {Pos}: {BODY[Pos:]}')