import re
from cStringIO import StringIO
import lexer

# ___________________________________________________________________________
# This is the AST used during execution.  It is not the same as the AST
# produced by the parser.  It only represents the body of functions.

class Ast(object):
    # Base class for execution-time AST nodes.
    #
    # NOTE: subclass docstrings are load-bearing -- __repr__ below uses the
    # first line of a node's class docstring as its rendering template -- so
    # Ast itself intentionally carries no class docstring (a class with no
    # docstring renders as "ClassName()").

    # Splits a __repr__ template into identifier chunks and the single
    # non-word characters between them; the capture group makes split()
    # keep those separators in its result.
    _TEMPLATE_SPLIT = re.compile(r"(\W)")

    def __init__(self, loc):
        # Any falsy loc (Ast.new passes None) falls back to the lexer's
        # placeholder location.
        self.loc = loc if loc else lexer.DUMMY_LOC

    @classmethod
    def from_parse_node(cls, procf, node):
        """Build an instance of cls from a parser node.

        For each name in cls.FIELDS, the same-named attribute of `node` is
        passed through `procf` and stored on the new instance.  The new
        instance takes `node.loc` as its location.  Only usable by subclasses
        that define FIELDS and keep the single-argument (loc) constructor.
        """
        res = cls(node.loc)
        for fld in cls.FIELDS:
            setattr(res, fld, procf(getattr(node, fld)))
        return res

    @classmethod
    def new(cls, **fields):
        """Generic constructor: build a node with no location, then set every
        keyword argument as an attribute of the same name."""
        res = cls(None)
        for name, value in fields.items():
            setattr(res, name, value)
        return res

    def dump(self, out, indent=""):
        """Write this node to the file-like `out`.

        The default writes `indent` followed by repr(self), with no trailing
        newline; subclasses override it for multi-line output.
        """
        out.write(indent)
        out.write(repr(self))

    def __repr__(self):
        """Render the node using the first line of its class docstring.

        Every identifier-like chunk of that line that names an attribute of
        this node is replaced by repr() of the attribute; all other chunks
        and every separator character are copied verbatim.  A class without
        a docstring renders as "ClassName()".
        """
        doc = self.__doc__
        if not doc:
            return "%s()" % (self.__class__.__name__,)
        template = doc.split('\n')[0]
        pieces = []
        # With a capturing group, split() yields identifier chunks at even
        # indices (possibly empty) and separator characters at odd indices.
        for i, chunk in enumerate(self._TEMPLATE_SPLIT.split(template)):
            if i % 2 == 0 and chunk and hasattr(self, chunk):
                pieces.append(repr(getattr(self, chunk)))
            else:
                pieces.append(chunk)
        return "".join(pieces)

# ___________________________________________________________________________

class FunctionAst(Ast):
    "def name[uarg_bounds](args): ..."
    # The docstring's first line is the repr() template read by
    # Ast.__repr__; rewording it changes repr() output.

    def __init__(self, loc, name, uarg_bounds, args, body_stmt):
        Ast.__init__(self, loc)
        # uarg_bounds is a dict: str -> [value.Unit]
        (self.name, self.uarg_bounds, self.args, self.body_stmt) = (
            name, uarg_bounds, args, body_stmt)

    def dump(self, out, indent=""):
        """Write the 'def' header line, then dump the body statement at the
        same indent."""
        header = '%sdef %s%r%r:\n' % (
            indent, self.name, self.uarg_bounds, self.args)
        out.write(header)
        self.body_stmt.dump(out, indent=indent)
        
# ___________________________________________________________________________

class Arg(Ast):
    "name=pattern"
    # The docstring is the repr() template read by Ast.__repr__.

    def __init__(self, loc, name, pattern):
        Ast.__init__(self, loc)
        # name: the variable being assigned/declared;
        # pattern: the Pattern the incoming value must match.
        self.name, self.pattern = name, pattern
        
class Pattern(Ast):
    # Base class for the pattern nodes an Arg's value must match.
    # It has no docstring, so Ast.__repr__ renders a bare instance as
    # "Pattern()"; adding one would change that output.
    pass
    
class PatternApply(Pattern):
    "func(args)"
    # Docstring above is the repr() template (Ast.__repr__).
    # FIELDS names the attributes Ast.from_parse_node copies from the parse node.
    FIELDS = ["func", "args"]
    
class PatternTuple(Pattern):
    "[ args ]"
    # Docstring above is the repr() template (Ast.__repr__).
    # FIELDS names the attributes Ast.from_parse_node copies from the parse node.
    FIELDS = ['args']

class PatternDict(Pattern):
    "[ items ]"
    # Docstring above is the repr() template (Ast.__repr__).
    # NOTE(review): this template uses the same brackets as PatternTuple's
    # while ExprDict uses "{ ... }", so a dict pattern's repr() looks like a
    # tuple pattern's -- confirm whether "{ items }" was intended.
    # FIELDS names the attributes Ast.from_parse_node copies from the parse node.
    FIELDS = ['items']
    
class PatternItem(Pattern):
    "w_key = arg"
    # Docstring above is the repr() template (Ast.__repr__).
    # FIELDS names the attributes Ast.from_parse_node copies from the parse node.
    FIELDS = ['w_key', 'arg']
    
class PatternExpectedValue(Pattern):
    "w_value"
    # Pattern holding a fixed value, w_value (presumably the value an
    # argument must equal -- confirm against the matcher).  The docstring
    # is the repr() template, so repr() shows repr(w_value).

    def __init__(self, loc, w_value):
        Pattern.__init__(self, loc)
        self.w_value = w_value  # the expected value this pattern carries

# ___________________________________________________________________________

class Stmt(Ast):
    # Base class for statement nodes.  No docstring, so Ast.__repr__
    # renders a bare instance as "Stmt()".
    pass

class StmtSuite(Stmt):
    "stmts"
    # The docstring is the repr() template: a suite renders as repr(stmts).

    def __init__(self, loc, stmts):
        Stmt.__init__(self, loc)
        self.stmts = stmts  # child statement nodes, dumped in order

    def dump(self, out, indent=""):
        """Dump each child statement two spaces deeper than `indent`,
        writing a newline after every child."""
        nested = indent + "  "
        for child in self.stmts:
            child.dump(out, nested)
            out.write("\n")

class StmtAssign(Stmt):
    "lhs := rhs"
    # Docstring above is the repr() template (Ast.__repr__).
    # DECL_VAR is False for plain assignment; StmtDeclVar overrides it to True.
    DECL_VAR = False
    # FIELDS names the attributes Ast.from_parse_node copies from the parse node.
    FIELDS = ['lhs', 'rhs']

class StmtDeclVar(StmtAssign):
    "var lhs := rhs"
    # The "var" form of StmtAssign: same FIELDS, but flagged as declaring
    # the variable.  Docstring above is the repr() template.
    DECL_VAR = True
        
class StmtIf(Stmt):
    "if (cond_expr) then_stmt else else_stmt"
    # The docstring's first line is the repr() template read by
    # Ast.__repr__: each word naming an attribute is replaced by that
    # attribute's repr(), so it must spell "else_stmt" exactly (a former
    # typo, "elst_stmt", made repr() print the literal word instead).

    def __init__(self, loc, cond_expr, then_stmt, else_stmt):
        Stmt.__init__(self, loc)
        self.cond_expr = cond_expr
        self.then_stmt = then_stmt
        self.else_stmt = else_stmt

    def dump(self, out, indent=""):
        """Write an if/else dump to `out`.

        Both the 'if' and 'else' header lines are prefixed with `indent`,
        matching Ast.dump, so the statement lines up with its siblings when
        dumped inside a StmtSuite.  Both branches are dumped at `indent`.
        """
        out.write("%sif (%r):\n" % (indent, self.cond_expr))
        self.then_stmt.dump(out, indent)
        out.write("%selse:\n" % (indent,))
        self.else_stmt.dump(out, indent)

class StmtWhile(Stmt):
    "while (cond) body"
    # Docstring above is the repr() template (Ast.__repr__).
    # FIELDS names the attributes Ast.from_parse_node copies from the parse node.
    FIELDS = ['cond', 'body']

    def dump(self, out, indent=""):
        """Write a 'while' header line (prefixed with `indent`, like
        Ast.dump, so it aligns inside a StmtSuite), then dump the body at
        the same indent."""
        out.write("%swhile (%r):\n" % (indent, self.cond))
        self.body.dump(out, indent)

class StmtFor(Stmt):
    "for (arg in expr) body"
    # The docstring's first line is the repr() template read by
    # Ast.__repr__, which substitutes words naming attributes.  It must use
    # the FIELDS names (arg/expr/body); the previous wording
    # (var/list_expr/body_stmt) named no attributes, so repr() never showed
    # the node's contents.
    # FIELDS names the attributes Ast.from_parse_node copies from the parse node.
    FIELDS = ['arg', 'expr', 'body']

class StmtReturn(Stmt):
    "return expr"
    # Docstring above is the repr() template (Ast.__repr__).
    # FIELDS names the attributes Ast.from_parse_node copies from the parse node.
    FIELDS = ['expr']

class StmtNoop(Stmt):
    "noop"
    # Statement with no fields; the docstring (the repr() template) makes
    # repr() return "noop".

# ___________________________________________________________________________

class Expr(Ast):
    # Base class for expression nodes.  No docstring, so Ast.__repr__
    # renders a bare instance as "Expr()".
    pass

class ExprNone(Expr):
    "None"
    # Field-less expression; the docstring (the repr() template) makes
    # repr() return "None".
    pass

class ExprLambda(Expr):
    """\
    ( | w_func )

    Creates a closure of a nested function.  Generated both for named
    and anonymous functions.
    """
    # NOTE(review): the backslash-newline makes "    ( | w_func )" the
    # docstring's first line, which Ast.__repr__ uses as the repr()
    # template -- including its 4 leading spaces, so repr() output starts
    # with them.  Confirm whether that indentation is intended.
    # FIELDS names the attributes Ast.from_parse_node copies from the parse node.
    FIELDS = ['w_func']

class ExprTuple(Expr):
    "[ items ]"
    # Docstring above is the repr() template (Ast.__repr__).
    # FIELDS names the attributes Ast.from_parse_node copies from the parse node.
    FIELDS = ['items']

class ExprLocalVar(Expr):
    "name"
    # Reference to a local variable by name; the docstring (the repr()
    # template) makes repr() show repr(name).

    def __init__(self, loc, name):
        Expr.__init__(self, loc)
        self.name = name  # the local variable's name
        
class ExprDict(Expr):
    "{ key_exprs=value_exprs }"
    # Dict-building expression.  Keys and values are held separately
    # (presumably parallel sequences -- confirm against the evaluator).
    # The docstring is the repr() template.

    def __init__(self, loc, key_exprs, value_exprs):
        Expr.__init__(self, loc)
        self.key_exprs, self.value_exprs = key_exprs, value_exprs

class ExprApply(Expr):
    "func (args)"
    # Docstring above is the repr() template (Ast.__repr__).
    # FIELDS names the attributes Ast.from_parse_node copies from the parse node.
    FIELDS = ['func', 'args']

class ExprConstant(Expr):
    "w_value"
    # Expression wrapping an already-constructed value; the docstring (the
    # repr() template) makes repr() show repr(w_value).

    def __init__(self, loc, w_value):
        Expr.__init__(self, loc)
        self.w_value = w_value  # the constant carried by this node
        
class ExprUnitEq(Expr):
    "num / denom"
    # Unit expression rendered by repr() as "num / denom" (the docstring
    # is the repr() template).

    def __init__(self, pos, num, denom):
        Expr.__init__(self, pos)
        # num and denom are lists of Unit ast nodes, each list mul'd together.
        self.num, self.denom = num, denom

    def invert(self):
        """Return a new ExprUnitEq with numerator and denominator swapped,
        at the same location; self is left unchanged."""
        flipped_num, flipped_denom = self.denom, self.num
        return ExprUnitEq(self.loc, flipped_num, flipped_denom)
# ___________________________________________________________________________

class Unit(Expr):
    # Base class for unit AST nodes (UnitFixed, UnitRaisedInf, UnitVar).
    # No docstring, so Ast.__repr__ renders a bare instance as "Unit()".
    pass
    
class UnitFixed(Unit):
    "w_unit"
    # A concrete unit; the docstring (the repr() template) makes repr()
    # show repr(w_unit).

    def __init__(self, loc, w_unit):
        Unit.__init__(self, loc)
        self.w_unit = w_unit  # a value.Unit

    def can_elim(self, unit):
        """Return True iff `unit` is also a UnitFixed and its w_unit
        compares <= this node's w_unit (the ordering is whatever value.Unit
        defines for <=); otherwise False."""
        if not isinstance(unit, UnitFixed):
            return False
        return bool(unit.w_unit <= self.w_unit)

class UnitRaisedInf(Unit):
    "unit^"
    # Wraps another unit node; repr() renders as repr(unit) followed by
    # "^" (the docstring is the repr() template).

    def __init__(self, loc, unit):
        Unit.__init__(self, loc)
        self.unit = unit  # the wrapped Ast node

class UnitVar(Unit):
    "$name"
    # A unit variable referenced by name; repr() renders as "$" followed
    # by repr(name) (the docstring is the repr() template).

    def __init__(self, loc, name):
        Unit.__init__(self, loc)
        self.name = name  # the unit variable's name

