"""
Some python functions which constitute the "standard library".

Conventions:

A global named XYZ with a type of value.Unit defines a unit :std:XYZ.

A global named XYZ with a type of value.PythonFunction defines a
function :std:XYZ.  Use the decorators in 'wrap' to create those from
Python functions.
"""
from wrap import wrap_func, wrap_val, wrap_extend, UnitVarDefn
import value, ast, error
from lexer import DUMMY_LOC

# ___________________________________________________________________________
# Standard Units

# Unit atoms provided by the standard library, created in the order
# arith, option, some, none.  A generator (not a list comprehension) is
# used so that no loop variable leaks into the module namespace under
# Python 2.  NOTE(review): presumably unit_atom_w returns a value.Unit,
# making these :std: units per the module docstring -- confirm.
w_arith, w_option, w_some, w_none = (
    value.unit_atom_w(atom_name)
    for atom_name in ("arith", "option", "some", "none"))

# ___________________________________________________________________________
# Standard Pattern Matching Funcs

@wrap_func([[None]])
def w_int(scope, w_v):
    "int"
    # Type predicate: true for integer values (W_ValueInt instances).
    int_cls = value.W_ValueInt
    return isinstance(w_v, int_cls)

@wrap_func([[None]])
def w_float(scope, w_v):
    "float"
    # Type predicate: true for floating-point values (W_ValueFloat).
    float_cls = value.W_ValueFloat
    return isinstance(w_v, float_cls)

@wrap_func([[None]])
def w_string(scope, w_v):
    "string"
    # Type predicate: true for string values (W_ValueString).
    string_cls = value.W_ValueString
    return isinstance(w_v, string_cls)

@wrap_func([[None]])
def w_dict(scope, w_v):
    "dict"
    # Type predicate: true for dictionary values (W_ValueDict).
    dict_cls = value.W_ValueDict
    return isinstance(w_v, dict_cls)

@wrap_func([[None]])
def w_tuple(scope, w_v):
    "tuple"
    # Type predicate: true for tuple values (W_ValueTuple).
    tuple_cls = value.W_ValueTuple
    return isinstance(w_v, tuple_cls)

@wrap_func([[None]])
def w_func(scope, w_v):
    "func"
    # Type predicate: true for callable multi-functions
    # (W_ValueMultiFunction), such as the ones @wrap_func produces.
    func_cls = value.W_ValueMultiFunction
    return isinstance(w_v, func_cls)

@wrap_func([[None]])
def w_unit(scope, w_v):
    "unit"
    # Type predicate: true when the argument is itself a unit (W_Unit).
    unit_cls = value.W_Unit
    return isinstance(w_v, unit_cls)

# ___________________________________________________________________________
# Stuff for Wrapping Functions

# arith^:  -- the fixed "arith" unit wrapped in ast.UnitRaisedInf.
# NOTE(review): presumably "arith raised to any power"; confirm in ast.
ast_arith_r = ast.UnitRaisedInf(
    None,
    ast.UnitFixed(None, w_arith))
# arith^/arith^:  -- an ExprUnitEq with ast_arith_r on both sides.
ast_any_arith = ast.ExprUnitEq(
    None,
    [ast_arith_r],
    [ast_arith_r])

# Two distinct unit variables, each constrained by ast_any_arith and
# created in the order a1, a2.  Binary operators pass [a1, a1] to
# combine() when both operands must share one unit, and [a1, a2] when
# the operands' units may differ.  A generator keeps the loop variable
# out of the module namespace under Python 2.
a1, a2 = (UnitVarDefn(var_name, [ast_any_arith])
          for var_name in ('a1', 'a2'))

# Operand-type signatures for binary numeric operators: every pairing
# of int and float.
int_float = [
    [w_int, w_int],
    [w_int, w_float],
    [w_float, w_float],
    [w_float, w_int],
]

# Operand-type signatures for sequence operators: two strings or two
# tuples (the two kinds are never mixed).
str_list = [
    [w_string, w_string],
    [w_tuple, w_tuple],
]

def combine(arith_vars, arg_typess):
    """Attach unit variables to every operand-type signature.

    For each signature ``arg_types`` in ``arg_typess``, the i-th type is
    passed through ``arith_vars[i].use(...)``, and the results form a new
    signature.  Pairing is done with ``zip``, so surplus entries on
    either side are silently dropped.

    Returns a list containing one new signature (a list) per input
    signature, in the original order.
    """
    return [[arith_var.use(arg_type)
             for arg_type, arith_var in zip(arg_types, arith_vars)]
            for arg_types in arg_typess]

# ___________________________________________________________________________
# Standard Functions
#
# The @wrap_func decorator converts these Python functions into
# value.W_ValueMultiFunction instances.
# The first word of the docstring is the name the function is exposed
# as.  @wrap_extend() is used to merge two Python functions.

# "+" on two unitless operands: numbers (any int/float pairing), two
# strings, or two tuples.  The raw Python result is returned as-is.
@wrap_func(int_float+str_list)
def w_add(scope, w_l, w_r):
    "+"
    total = w_l.value + w_r.value
    return total

# Add any two items of the same (or compatible) types that carry the same
# arithmetic unit, and propagate that unit to the result.
@wrap_extend(w_add)
@wrap_func(combine([a1, a1], int_float+str_list))
def w_add(scope, w_l, w_r):
    "+"
    # Both operands are bound to the single unit variable a1, so their
    # units are expected to agree by the time this body runs.
    assert w_l.w_unit == w_r.w_unit
    # The raw Python sum has no ``with_unit`` method.  Re-wrap it with
    # wrap_val first, as the w_sub/w_mul/w_div extensions do, before
    # attaching the unit.  NOTE(review): presumably wrap_val also handles
    # string/tuple results from str_list -- confirm.
    w_res = wrap_val(w_l.value + w_r.value)
    return w_res.with_unit(w_l.w_unit)

# "-" on two unitless numbers (any int/float pairing).  The raw Python
# difference is returned as-is.
@wrap_func(int_float)
def w_sub(scope, w_l, w_r):
    "-"
    difference = w_l.value - w_r.value
    return difference

# "-" on two numbers that share the same arithmetic unit; the result
# keeps that unit.
@wrap_extend(w_sub)
@wrap_func(combine([a1, a1], int_float))
def w_sub(scope, w_l, w_r):
    "-"
    # Both operands are bound to unit variable a1, so their units are
    # expected to agree here.
    assert w_l.w_unit == w_r.w_unit
    w_difference = wrap_val(w_l.value - w_r.value)
    shared_unit = w_l.w_unit
    return w_difference.with_unit(shared_unit)

# "*" on two unitless numbers (any int/float pairing); the product is
# wrapped with wrap_val.
@wrap_func(int_float)
def w_mul(scope, w_l, w_r):
    "*"
    return wrap_val(w_l.value * w_r.value)

# "*" on two numbers whose arithmetic units may differ (unit variables
# a1 and a2).  The result's unit is the product of the operands' units.
@wrap_extend(w_mul)
@wrap_func(combine([a1, a2], int_float))
def w_mul(scope, w_l, w_r):
    "*"
    w_product = wrap_val(w_l.value * w_r.value)
    product_unit = w_l.w_unit * w_r.w_unit
    return w_product.with_unit(product_unit)

# "/" on two unitless numbers (any int/float pairing); the quotient is
# wrapped with wrap_val.  Under Python 2 semantics int / int floors.
@wrap_func(int_float)
def w_div(scope, w_l, w_r):
    "/"
    return wrap_val(w_l.value / w_r.value)

# "/" on two numbers whose arithmetic units may differ (unit variables
# a1 and a2).  The result's unit is the left operand's unit divided by
# the right operand's unit.
@wrap_extend(w_div)
@wrap_func(combine([a1, a2], int_float))
def w_div(scope, w_l, w_r):
    "/"
    w_quotient = wrap_val(w_l.value / w_r.value)
    quotient_unit = w_l.w_unit / w_r.w_unit
    return w_quotient.with_unit(quotient_unit)

# Print most values.
@wrap_func([[w_int], [w_float], [w_string], [w_tuple], [w_dict]])
def w_print(scope, w_arg):
    "print"
    if w_arg.w_unit == value.W_UNITLESS:
        print w_arg.value
    else:
        print w_arg.value, w_arg.w_unit

# apply: invoked when the user tries to call a non-function.  The default
# implementation always raises ApplicationToNonFunction.  A user could
# extend this multi-function to make some other kind of value callable.
@wrap_func([[None, w_tuple]])
def w_apply(scope, w_f, w_args):
    "apply"
    failure = error.ApplicationToNonFunction(DUMMY_LOC, w_f, w_args)
    raise failure

# iter (named after O'Caml's List.iter): call w_f once per element of the
# tuple, in order, discarding each call's result.
@wrap_func([[None, w_tuple]])
def w_iter(scope, w_f, w_lst):
    "iter"
    elements = w_lst.value
    for w_element in elements:
        w_f.apply_to(DUMMY_LOC, [w_element])
    return None

# "[]": fetch the element at integer position w_idx from a tuple, using
# ordinary Python subscripting on the underlying value.
@wrap_func([[w_tuple, w_int]])
def w_index(scope, w_lst, w_idx):
    "[]"
    elements = w_lst.value
    position = w_idx.value
    return elements[position]

