import sys, types
from spark import GenericParser
import lexer

def debug(msg, *args):
    """Tracing hook, currently disabled; uncomment the print to enable."""
    # print msg % args
    pass

# ___________________________________________________________________________
# Ast: this is the tree that results from parsing.  It's fairly generic.

from cStringIO import StringIO

def indent(str, indent="  "):
    """Return str with every line prefixed by the indent string.

    str must be newline-terminated; the trailing newline is preserved
    and is not followed by a dangling indent.
    """
    # Bug fix: str[-1] raised IndexError on empty input; endswith gives
    # a clean AssertionError for '' as well as non-terminated text.
    assert str.endswith('\n'), "indent() requires newline-terminated text"
    return indent + str[:-1].replace('\n', '\n'+indent) + '\n'

def dump(item):
    """Debug-render item: Ast nodes via their dump(), lists one entry per
    line, anything else via repr().

    NOTE(review): Ast defines no dump() method in this file, so the Ast
    branch looks like it would raise AttributeError -- confirm intent.
    """
    if isinstance(item, Ast):
        return item.dump()
    if isinstance(item, list):
        rendered = [dump(entry) for entry in item]
        return "\n".join(rendered)
    return repr(item)

class Ast(object):
    """Generic parse-tree node.

    Fields are attached dynamically via set().  Child Ast nodes (and
    lists) are tracked in ast_keys, plain data values in other_keys, so
    traversal and repr can tell structure apart from data.
    """

    def __init__(self, type, loc, **kwargs):
        # type: node kind (usually the grammar nonterminal name).
        # loc:  source location of the first token, or None.
        self.type = type
        self.loc = loc
        self.ast_keys = set()
        self.other_keys = set(['type'])
        for k, v in kwargs.items():
            self.set(k, v)

    def set(self, key, value):
        """Store value as attribute key, classifying it as child or data.

        Tokens are unwrapped to their .val.  NOTE(review): a *list* of
        tokens is classified as a child list and its elements are not
        unwrapped -- confirm that is intended.
        """
        if isinstance(value, Ast) or isinstance(value, list):
            self.ast_keys.add(key)
        else:
            self.other_keys.add(key)
        if isinstance(value, lexer.Token):
            value = value.val
        setattr(self, key, value)

    def get(self, key):
        """Read back an attribute previously stored with set()."""
        return getattr(self, key)

    def __repr__(self):
        # Keys are sorted so the repr is deterministic (other_keys is a
        # set, whose iteration order is otherwise arbitrary).
        return 'Ast(%s) @ %s' % (
            ",".join(["%s=%r" % (k, self.get(k),) for k in sorted(self.other_keys)]),
            self.loc)

    def descendants(self, not_of_type=()):
        """Yield all transitive child nodes, pruning any subtree whose
        root type is listed in not_of_type.

        The default was a shared mutable list; a tuple avoids the
        mutable-default pitfall and membership tests are unchanged.
        """
        for _, nodes in self.children():
            for node in nodes:
                if node.type in not_of_type: continue
                yield node
                for d in node.descendants(not_of_type=not_of_type):
                    yield d

    def children(self):
        """Yield (key, children) for each structural field; a lone node
        is wrapped in a one-element list."""
        for k in self.ast_keys:
            v = self.get(k)
            if not isinstance(v, list): v = [v]
            yield (k, v)

    def tree_repr(self):
        """Multi-line, indented rendering of the subtree rooted here."""
        res = StringIO()
        res.write('* %r:\n' % self)
        for nm, val in self.children():
            res.write('  %s:\n' % nm)
            reprs = [indent(n.tree_repr(), "  ") for n in val]
            res.write("".join(reprs))
        return res.getvalue()

class Visitor(object):
    """Dispatch-by-type walker over Ast trees.

    Subclasses define one method per node type; proc() routes each node
    to the method named after ast.type, falling back to default().
    """

    def proc(self, ast, **kwargs):
        """Process one node; lists are mapped element-wise and None
        passes through as None.

        Extra keyword arguments are forwarded to the handler.  (proc_all
        already forwarded **kwargs here, but proc accepted none, so any
        keyword use raised TypeError -- now fixed.)
        """
        if isinstance(ast, list):
            return self.proc_all(ast, **kwargs)

        if ast:
            handler = getattr(self, ast.type, None)
            if handler is not None:
                return handler(ast, **kwargs)
            return self.default(ast, **kwargs)
        return None

    def default(self, ast, **kwargs):
        """Fallback for node types without a dedicated handler.

        Bug fix: the original did 'raise NotImplemented', which raises
        TypeError (NotImplemented is a value, not an exception class);
        raise the proper NotImplementedError instead.
        """
        raise NotImplementedError

    def proc_all(self, asts, **kwargs):
        """Process each element of asts, returning the list of results."""
        return [self.proc(a, **kwargs) for a in asts]

# ___________________________________________________________________________
# Parser Rules Metadata

# Conventions:
#
#   Nonterminals are the keys of rules.
#
#   If the value is a list, then we choose args[-1] from any of the strings
#   listed.
#
#   If the value is a dictionary, then the keys are the productions
#   and the values are actions.  Actions can be:
#   * a dictionary: create an Ast object with the type given by the
#   non-terminal, and the fields taken from the dictionary (whose values
#   are processed recursively).
#   * an integer n: convert to args[n].
#
#   Suffixes:
#   * 0 or more
#   + 1 or more
#   ? optional (0 or 1)
#   , comma-separated list (possibly empty)
#
#   Values:
#   [...]: convert to arguments, where n refers to args[n]
#   n: args[n]
#

def normalize_plus(rules):
    """Finds all references to items with plus suffix and expands them.

    For each item foo+, define 2 rules:

      foo+ ::= foo          -> [foo]
      foo+ ::= foo foo+     -> [head] + tail
    """
    for item in suffixed(rules, '+'):
        def lonelyitem(self, args):
            return [args[0]]
        def additem(self, args):
            return [args[0]] + args[-1]
        item_plus = item + "+"
        rules[item_plus] = {
            '%s' % item: lonelyitem,
            '%s %s' % (item, item_plus): additem,
            }
    return rules

def normalize(rules):
    """Run the full normalization pipeline over a raw rule table.

    Bug fix: the grammar uses the '+' (one-or-more) suffix ('ElseIf+' in
    IfStmt) and the conventions comment documents it, but no expansion
    step existed, leaving 'ElseIf+' an undefined nonterminal.
    normalize_plus (above) now expands it analogously to '*'.
    """
    rules = normalize_lists(rules)
    rules = normalize_actions(rules)
    rules = normalize_star(rules)
    rules = normalize_plus(rules)
    rules = normalize_comma(rules)
    rules = normalize_optional(rules)
    return rules

def normalize_lists(rules):
    "Finds all items with list values and makes them dicts."
    normalized = {}
    for nonterm, prods in rules.items():
        if not isinstance(prods, list):
            normalized[nonterm] = prods
            continue
        # Each listed production becomes a pass-through action selecting
        # its last RHS symbol (index = symbol count - 1).
        normalized[nonterm] = dict(
            (prod, len(prod.split()) - 1) for prod in prods)
    return normalized

def loc(val):
    """Best-effort source location: Ast nodes and tokens carry .loc;
    a non-empty list delegates to its first element; otherwise None."""
    if isinstance(val, (Ast, lexer.Token)):
        return val.loc
    if isinstance(val, list) and val:
        return loc(val[0])
    return None

def normalize_actions(rules):
    "Converts all actions from dicts or numbers etc into a lambda"
    # Wraps each raw action spec in a (self, args) function with the
    # signature spark expects for rule methods.  nm is the nonterminal
    # name; it becomes the Ast type for dict actions.
    def action_func(nm, initial_action):
        def func(self, args):
            # Recursively evaluate the action spec against the parsed
            # right-hand-side values in args.
            def interpret(args, action):
                if action is None or isinstance(action, basestring):
                    # Literals (None, strings) pass through unchanged.
                    return action
                elif isinstance(action, bool):
                    # bool is tested before int (bool subclasses int).
                    return action
                elif isinstance(action, dict):
                    # Build an Ast typed after the nonterminal; location
                    # is taken from the first RHS value when present.
                    res = Ast(nm, loc(args[0]) if args else None)
                    for key, value in action.items():
                        res.set(key, interpret(args, value))
                    return res
                elif isinstance(action, int):
                    # An integer n selects args[n].
                    return args[action]
                elif isinstance(action, list) or isinstance(action, tuple):
                    return [interpret(args, a) for a in action]
                elif isinstance(action, types.FunctionType):
                    return action(args)
                assert False, "Uknown action: %r" % (action,)
                raise SystemExit # XXX undefined action
            # NOTE(review): 'action' below is a free variable that binds
            # to normalize_actions' loop variable (so, its final value),
            # not this rule's spec -- harmless while debug() is a no-op,
            # but the trace would be misleading if re-enabled.  (Also
            # note the 'Uknown' typo in the runtime message above.)
            debug("%s: interpret(%r, %r)", nm, args, action)
            i = interpret(args, initial_action)
            debug("  yields %r", i)
            return i
        return func
    new_rules = {}
    for nonterm, prods in rules.items():
        new_prods = {}
        for prod, action in prods.items():
            new_prods[prod] = action_func(nonterm, action)
        new_rules[nonterm] = new_prods
    return new_rules

def suffixed(rules, suffix):
    """Collect base names X such that some production mentions X<suffix>
    (the bare suffix token itself is ignored)."""
    found = set()
    for nonterm, prods in rules.items():
        assert isinstance(prods, dict)
        for prod in prods:
            for token in prod.split():
                if token != suffix and token[-1] == suffix:
                    found.add(token[:-1])
    return found

def normalize_star(rules):
    "Finds all references to items with star suffix and expands them."
    for base in suffixed(rules, '*'):
        # Expand each reference 'foo*' into two productions:
        #
        #   foo* ::=              -> []
        #   foo* ::= foo foo*     -> [head] + tail
        def empty(self, args):
            return []
        def cons(self, args):
            debug("additem(%r,%r)", args[0], args[-1])
            return [args[0]] + args[-1]
        starred = base + "*"
        rules[starred] = {
            '': empty,
            '%s %s' % (base, starred): cons,
            }

    return rules

def normalize_comma(rules):
    "Finds all references to items with comma suffix and expands them."
    for base in suffixed(rules, ','):
        # Expand each reference 'foo,' into three productions:
        #
        #   foo, ::=              -> []
        #   foo, ::= foo          -> [foo]
        #   foo, ::= foo , foo,   -> [head] + tail
        def empty(self, args):
            return []
        def single(self, args):
            return [args[0]]
        def cons(self, args):
            return [args[0]] + args[-1]
        listed = base + ","
        rules[listed] = {
            '': empty,
            '%s' % base: single,
            '%s , %s' % (base, listed): cons,
            }

    return rules

def normalize_optional(rules):
    "Finds all references to items with question-mark suffix and expands them."
    for base in suffixed(rules, '?'):
        # Expand each reference 'foo?' into two productions:
        #
        #   foo? ::=          -> None
        #   foo? ::= foo      -> foo
        def absent(self, args):
            return None
        def present(self, args):
            return args[0]
        optional = base + "?"
        rules[optional] = {
            '': absent,
            base: present,
            }

    return rules

def dump_rules(rules):
    """Print the grammar, one 'nonterm ::= production' line per rule."""
    for nonterm, prods in rules.items():
        for prod in prods:
            print ("%s ::= %s" % (nonterm, prod))

# Grammar table, keyed by nonterminal (see the conventions comment
# above).  Integer values in action dicts are 0-based indices into the
# production's right-hand-side symbols.
rules = dict(
    Program={ 'Decl*':dict(decls=0) },

    Decl=[ 'ImportDecl', 'FuncDecl', 'UnitDecl' ],

    ImportDecl={
    'import FullId nl':dict(id=1, open_ended=False),
    'import FullId : nl':dict(id=1, open_ended=True),
    },

    FuncDecl={
    'def FullId ( Arg, ) : Suite':
    dict(id=1, unit_args=[], args=3, body=6),
    'def FullId [ UnitArg, ] ( Arg, ) : Suite':
    dict(id=1, unit_args=3, args=6, body=9),
    },

    UnitArg={
    # NOTE(review): 'Unit' is not defined as a nonterminal anywhere in
    # this table -- possibly meant 'UnitId'; confirm.
    'id':dict(name=0, super_units=[]),
    'id ( Unit, )':dict(name=0, super_units=2),
    },

    Arg={
    'id':dict(name=0, pat=None),
    'id = Pattern':dict(name=0, pat=2),
    '= Pattern':dict(name=None, pat=1),
    'PatternSyntax':dict(name=None, pat=0),
    '( Arg )':1,
    },

    PatternSyntax={
    #'Arg ExprId Arg':dict(type='PatternApply', func=1, args=[0,2]),
    '[ Arg, ]':dict(type='PatternTuple', args=1),
    '[ PatternItem, ]':dict(type='PatternDict', items=1),
    'Number':dict(type='PatternExpectedValue', value=0),
    'String':dict(type='PatternExpectedValue', value=0),
    'PatternFunc ( Arg, )':dict(type='PatternApply',
                                func=0, args=2),
    },

    PatternItem={
    'id = Arg':dict(name=0, arg=2),
    },

    Pattern={
    'PatternFunc':dict(type='PatternApply', decl=False, func=0, args=[]),
    'PatternSyntax':0,
    },

    PatternFunc={
    'ExprId':dict(func=0),
    'ExprUnitEq':dict(func=0),
    },

    UnitDecl={
    # Fixed: field was 'defin' here but 'defn' in UnitDef (reached via
    # the 'UnitDef':0 pass-through below); unified to 'defn' so
    # consumers see one consistent attribute name.
    # NOTE(review): 'UnitEq' is not defined as a nonterminal -- possibly
    # meant 'ExprUnitEq'; confirm.
    'unit id nl':dict(name=1, super_units=[], defn=None),
    'unit id ( UnitEq, ) nl':dict(name=1, super_units=3, defn=None),
    'UnitDef':0,
    },

    UnitDef={
    'unit id = UnitEq nl':dict(name=1, defn=3),
    },

    Suite={"nl indent Stmt* dedent":dict(stmts=2)},
    Stmt={
    "Stmtnl nl":0,
    "FuncDecl":dict(type='StmtFuncDecl', decl=0),
    },
    Stmtnl={
    'var Arg := Expr':dict(type='StmtDeclVar', lhs=1, rhs=3),
    'Arg := Expr':dict(type='StmtAssign', lhs=0, rhs=2),
    'Expr := Expr':dict(type='StmtAssign', lhs=0, rhs=2),
    'return Expr':dict(type='StmtReturn', expr=1),
    'IfStmt':0,
    'while ( Expr ) : Suite':dict(type='StmtWhile', cond=2, body=5),
    'Expr':0,
    },

    IfStmt={
    # Fixed: els was 8, which indexes the ':' token; the else-branch
    # Suite is symbol 9 in 'if ( Expr ) : Suite ElseIf+ else : Suite'.
    'if ( Expr ) : Suite ElseIf+ else : Suite':
    dict(type="If", cond=2, then=5, elsifs=6, els=9),
    'if ( Expr ) : Suite ElseIf+':
    dict(type="If", cond=2, then=5, elsifs=6, els=None),
    'if ( Expr ) : Suite':
    dict(type="If", cond=2, then=5, elsifs=[], els=None),
    },

    ElseIf={
    'else if ( Expr ) : Suite':[3, 6]
    },

    Expr={
    '( Expr )': 1,
    '( Arg, | Expr )': dict(type='ExprLambda', args=1, body=3),
    '[ Expr, ]': dict(type='ExprTuple', items=1),
    '[ DictItem, ]': dict(type='ExprDict', items=1),
    'ExprId': 0,
    'Number': 0,
    'String': 0,
    'Expr . ExprPost': dict(type='ExprApply', func=2, args=[0]),
    'ExprUnitEq': 0,
    'Expr ( Expr, )': dict(type='ExprApply', func=0, args=2),
    'Expr [ Expr, ]': dict(type='ExprIndex', obj=0, args=2),
    'Expr ExprId Expr': dict(type='ExprApply', func=1, args=[0,2]),
    },

    DictItem={
    'Expr = Expr': dict(key=0, value=2),
    },

    ExprPost=[
    'ExprId',
    'ExprUnitEq'
    ],

    ExprUnitEq={
    # Fixed: num/denom previously indexed the '{' and '/' tokens (0 and
    # 2); the UnitFac* lists are symbols 1 and 3.
    '{ UnitFac* }':dict(num=1, denom=[]),
    '{ UnitFac* / UnitFac* }':dict(num=1, denom=3),
    },

    UnitFac={
    'UnitId':0,
    'UnitId ^ PosInt':dict(type='UnitRaised', unit=0, n=2),
    'UnitId ^ ':dict(type='UnitRaisedInf', unit=0),
    },

    UnitId={
    'FullId':dict(id=0),
    },

    ExprId={
    'FullId':dict(id=0),
    'OpId':dict(id=0),
    },

    OpId={
    'AnyOp':dict(type='Id', ns=None, id=0),
    'FullIdBase AnyOp':dict(type='Id', ns=0, id=1),
    },

    AnyOp=['op', '/'],

    FullId={
    'id':dict(type='Id', ns=None, id=0),
    'FullIdBase id':dict(type='Id', ns=0, id=1),
    },

    FullIdBase={
    ':':dict(type='Id', ns=None, id=""),
    'id :':dict(type='Id', ns=None, id=0),
    'FullIdBase id :':dict(type='Id', ns=0, id=1),
    },

    Number={
    'PosInt':0,
    'float':dict(type='Float', num=0),
    },

    PosInt={
    'pos_int':dict(int=0),
    },

    String={
    'string':dict(type='String', value=0),
    },
    )
# Expand the suffix shorthands and compile every action into a rule method.
rules = normalize(rules)

# ___________________________________________________________________________
# Parser Class

class MultiParser(GenericParser):
    """Spark parser over the normalized rule table; start symbol is
    'Program'.  The p_* rule methods are attached by install_rules().
    """
    def __init__(self):
        # Direct base-class call kept (spark classes may be old-style).
        GenericParser.__init__(self, 'Program')

    def error(self, token):
        # Spark invokes this on a syntax error; report location and abort.
        message = "%s: Syntax error at or near %s\n" % (token.loc, token)
        sys.stderr.write(message)
        raise SystemExit

    def typestring(self, token):
        # Lets spark match grammar symbols directly against token types.
        return token.type

def install_rules(rules):
    """Install every grammar action on MultiParser as a spark 'p_*'
    rule method; spark finds them by name prefix and reads the
    production from the docstring."""
    for nonterm, prods in rules.items():
        idx = 0
        for prod, action in prods.items():
            func_name = 'p_%s_%d' % (nonterm, idx)
            # A 'p'-prefixed name means this function object was already
            # installed once (actions are mutated in place).
            assert action.__name__[0] != 'p'
            action.__name__ = func_name
            action.__doc__ = "%s ::= %s" % (nonterm, prod)
            setattr(MultiParser, func_name, action)
            idx += 1
install_rules(rules)

def parse(filenm, text=None):
    """Tokenize the named file (or the given text) and parse it,
    returning the root Ast."""
    parser = MultiParser()
    return parser.parse(lexer.tokenize(filenm, text))

if __name__ == "__main__":
    dump_rules(rules)
    for filenm in sys.argv[1:]:
        ast = parse(filenm)
        print ast.tree_repr()

