import types
import value, ast, lexer, error

def wrap_val(pyval):
    """
    Convert a plain Python value into the corresponding value.W_Value.

    Values that are already value.W_Value instances are returned
    unchanged.  Lists and tuples become a value.W_ValueTuple and dicts
    become a value.W_ValueDict, with their contents wrapped
    recursively.  bool, int, float and string values map to their
    W_Value* counterparts, and None maps to value.W_VALUE_NONE.

    Raises error.CannotWrap for any other kind of object.
    """
    if isinstance(pyval, value.W_Value):
        # Already wrapped: hand it back untouched.
        wrapped = pyval
    elif isinstance(pyval, (list, tuple)):
        wrapped = value.W_ValueTuple([wrap_val(item) for item in pyval])
    elif isinstance(pyval, bool):
        # bool must be tested before int: bool is a subclass of int.
        wrapped = value.W_ValueBoolean(pyval)
    elif isinstance(pyval, int):
        wrapped = value.W_ValueInt(pyval)
    elif isinstance(pyval, float):
        wrapped = value.W_ValueFloat(pyval)
    elif isinstance(pyval, basestring):
        wrapped = value.W_ValueString(pyval)
    elif isinstance(pyval, dict):
        # Keys are wrapped before their values, one item at a time.
        pairs = [(wrap_val(key), wrap_val(val))
                 for key, val in pyval.items()]
        wrapped = value.W_ValueDict(dict(pairs))
    elif pyval is None:
        wrapped = value.W_VALUE_NONE
    else:
        raise error.CannotWrap(pyval)
    return wrapped

class UnitVarDefn(object):
    """
    Definition of a unit variable, used by wrap_func.

    name: the name of the unit variable.
    bounds_ast: a list of ast.ExprUnitEq nodes (possibly empty).
    """

    def __init__(self, name, bounds_ast):
        assert isinstance(bounds_ast, list)
        for bound in bounds_ast:
            assert isinstance(bound, ast.ExprUnitEq)
        self.name = name
        self.bounds_ast = bounds_ast

    def use(self, arg_pattern):
        """ Return a UnitVarUse tying this definition to arg_pattern. """
        return UnitVarUse(self, arg_pattern)

class UnitVarUse(object):
    """
    A unit variable definition paired with the argument pattern it is
    applied to.  Created by UnitVarDefn.use().
    """

    def __init__(self, var_defn, arg_pattern):
        self.var_defn, self.arg_pattern = var_defn, arg_pattern

def wrap_func(arg_patternss, wrap_return=True):
    """
    Decorator factory that turns a Python function into a
    value.W_ValueMultiFunction.

    arg_patternss: a sequence of overloads.  Each overload is itself a
    sequence of argument patterns; one ast.Arg is built per pattern,
    numbered by its position.  An argument pattern is one of:

      * None -- the ast.Arg is built with no pattern;
      * a value.W_ValueMultiFunction -- wrapped in an ast.ExprConstant
        and applied with no arguments;
      * a UnitVarUse (see UnitVarDefn.use) -- the unit variable applied
        to the nested argument pattern, which may be any of these
        forms.  The variable's bounds are recorded under its name.

    wrap_return: when true (the default) the Python function's result
    is passed through wrap_val() before being returned; when false the
    Python function is used as-is.

    The first whitespace-separated word of the decorated function's
    docstring becomes the `expose_as` name of the returned multi
    function.

    Raises TypeError for an argument pattern of any other kind, and
    ValueError if the decorated function has no docstring to take the
    exposed name from.

    See stdlib.py for examples.
    """

    loc = lexer.DUMMY_LOC

    def wrap(func):
        # The exposed name is taken from the docstring; fail with a clear
        # message up front rather than an AttributeError/IndexError later.
        doc_words = (func.__doc__ or "").split()
        if not doc_words:
            raise ValueError(
                "wrap_func: %r needs a docstring whose first word is "
                "the name to expose" % (func,))
        nm = doc_words[0]

        # __code__ is available on Python 2.6+ and Python 3 alike.
        code = func.__code__
        func_loc = lexer.FileLoc(code.co_filename, code.co_firstlineno)

        # NOTE(review): this single dict is shared by every overload built
        # below, so each PythonFunction sees the unit variables of all
        # overloads.  Kept as-is; confirm that this is intended.
        uvar_bounds = {}

        def make_ast_pattern(arg_pattern):
            """ Translate one argument pattern into its AST form. """
            if arg_pattern is None:
                return None
            if isinstance(arg_pattern, value.W_ValueMultiFunction):
                expr_const = ast.ExprConstant(loc, arg_pattern)
                return ast.PatternApply.new(loc=loc,
                                            func=expr_const,
                                            args=[])
            if isinstance(arg_pattern, UnitVarUse):
                var_defn = arg_pattern.var_defn
                uvar_bounds[var_defn.name] = var_defn.bounds_ast
                uvar_ast = ast.UnitVar(loc, var_defn.name)
                inner_pat = make_ast_pattern(arg_pattern.arg_pattern)
                return ast.PatternApply.new(loc=loc,
                                            func=uvar_ast,
                                            args=[inner_pat])
            # User error: invalid wrap_func() call.
            raise TypeError(
                "wrap_func: invalid argument pattern %r" % (arg_pattern,))

        if wrap_return:
            def exec_func(*args):
                return wrap_val(func(*args))
            exec_func.__doc__ = func.__doc__
        else:
            exec_func = func

        funcs = []
        for arg_patterns in arg_patternss:
            args = [ast.Arg(loc, idx, make_ast_pattern(arg_pattern))
                    for idx, arg_pattern in enumerate(arg_patterns)]
            funcs.append(value.PythonFunction(
                func_loc, uvar_bounds, args, exec_func))

        multif = value.W_ValueMultiFunction(funcs)
        multif.expose_as = nm
        return multif

    return wrap

def wrap_extend(multif1):
    """
    Decorator factory used to add additional cases, backed by a
    different Python function, to a value.W_ValueMultiFunction that
    already exists.

    The decorated object (multif2) must itself already be wrapped and
    must have the same `expose_as` name as multif1.  Every function in
    multif2.funcs is added to multif1, and multif1 is returned.  See
    stdlib.py for examples.
    """
    def extend(multif2):
        # Both multi functions must be exposed under the same name.
        assert multif1.expose_as == multif2.expose_as
        for extra_case in multif2.funcs:
            multif1.add_func(extra_case)
        return multif1
    return extend
