"""

Machine-Independent Instruction Definitions: See codegen.py

It can be somewhat confusing what is supposed to be exported
by this file, so here is an explanation.

First, all the classes (Operand, CGInstruction, CodeGenTemplate,
InstrDefinition) are public.

Second, the two enumerations (optypes, opflags) are public.

Finally, the variable defn, which is an instance of InstrDefinition,
is public and serves as the basis for all the machine-specific
instruction definitions.

"""

import sys

import opcode

import instrs, types, chelpers, util, library

from lowir import \
     Operation, \
     TempArgument, \
     AddrArgument, \
     MemArgument, \
     ValueArgument, \
     ConstArgument, \
     ConstantValueArgument, \
     LengthValueArgument, \
     ModuleArgument, \
     InstrDefinition, \
     CodeGenTemplate, \
     tartypes

     
# Public global variable: the InstrDefinition into which every template
# below is registered (via start_instr_type).
defn = InstrDefinition ()

# Shorthand aliases for the lowir argument constructors:
temparg, addrarg, memarg, constarg = (TempArgument, AddrArgument,
                                      MemArgument, ConstArgument)
valarg, cvalarg, lvalarg = (ValueArgument, ConstantValueArgument,
                            LengthValueArgument)

# Shorthand aliases for the target kinds used as the first element of the
# ``target=`` tuples passed to oper():
LIB, LABEL, BLOCKEXIT, VALUE = (tartypes.LIB, tartypes.LABEL,
                                tartypes.BLOCKEXIT, tartypes.VALUE)

# Helper routines to make instruction definition readable and concise.
#
# Module-level state shared by the helpers below: ``curtype`` is the
# instruction class whose templates are currently being defined and
# ``curtmpl`` is the CodeGenTemplate that oper() and set_condition()
# write into.
curtype = None
curtmpl = None
def start_instr_type (instrtype):

    """ Open a new CodeGenTemplate for ``instrtype`` and make it current.

    The template is appended to ``defn.defs[instrtype]``; that list is
    created the first time a template is registered for the class.
    Subsequent oper()/set_condition() calls apply to this template, and
    ``curtype`` lets the definitions name the class's operands. """

    global curtype, curtmpl

    curtmpl = CodeGenTemplate ()
    try:
        registered = defn.defs[instrtype]
    except KeyError:
        defn.defs[instrtype] = [curtmpl]
    else:
        registered.append (curtmpl)
    curtype = instrtype

def set_condition (condition):

    """ Store ``condition`` as the ``condfunc`` of the current template.

    The conditions used in this file take the instruction instance and
    return a boolean; presumably codegen uses it to decide whether this
    template applies to a given instruction. """

    curtmpl.condfunc = condition

def oper (*operargs, **operkw):
    """ Build an Operation from the given arguments, emit it into the
    current template and return whatever the template's emit() returns. """
    template = curtmpl
    return template.emit (Operation (*operargs, **operkw))

# Every class object defined in the instrs module.  Assuming ``types`` is
# the standard-library module, ``types.ClassType`` matches old-style
# classes only.  Kept as a real list because the unary and binary loops
# below both iterate over it.
instrclasses = [candidate for candidate in instrs.__dict__.values()
                if type (candidate) is types.ClassType]

# Helpers for common instruction sequences used in the definitions of
# many Instruction objects:
def handle_exception ():

    """ Emit the standard exception-handling sequence for the current
    instruction type.  The generated code corresponds to this C code::

       Pynto_ErrFetch (&exctb, &exctyp, &excval);
       goto EXCEPT;

    Here ``exc*`` are the instruction's EXCTB/EXCTYP/EXCVAL values, and
    EXCEPT is the basic block connected to the 'EXCEPT' terminal.

    Pynto_ErrFetch is used instead of PyErr_Fetch because an unused
    exctb/exctyp/excval is passed through addrarg (AddrArgument) as a
    NULL address.  Pynto_ErrFetch copes with that by creating a temporary
    and releasing the reference later, whereas PyErr_Fetch would crash. """

    tb, typ, val = (valarg (curtype.EXCTB),
                    valarg (curtype.EXCTYP),
                    valarg (curtype.EXCVAL))
    fetch_sources = tuple (addrarg (v) for v in (tb, typ, val))

    oper ('CALL', sources=fetch_sources, target=(LIB,library.errfetch))
    oper ('BR', target=(BLOCKEXIT,curtype.EXCEPT))

def inc_ref (op):

    """ Emit an ADD of 1 to the reference count of ``op``.  The count is
    addressed as a memory argument at chelpers.ref_count() on ``op``. """

    refcnt = memarg (op, chelpers.ref_count())
    one = constarg (1)
    oper ('ADD', dests=(refcnt,), sources=(refcnt, one),
          comment='Inc ref count')

def dec_ref (op):

    """ Emit a SUB of 1 from the reference count of ``op``.  If the count
    is then zero, ``op`` is passed to the library.freeobject helper;
    otherwise execution continues at the local label 'decref_notfreed'.

    NOTE: the label name is fixed, so presumably dec_ref() can be used at
    most once per template. """

    refcnt = memarg (op, chelpers.ref_count())
    oper ('SUB', dests=(refcnt,), sources=(refcnt, constarg (1)))

    # Still referenced elsewhere: skip the free.
    oper ('BNE', sources=(refcnt, constarg (0)),
          target=(LABEL,'decref_notfreed'))
    oper ('CALL', sources=(op,), target=(LIB,library.freeobject))
    oper ('NOOP', label='decref_notfreed')

def load_attr_slow (obj, attrname, tar):

    """ Emit code that loads attribute ``attrname`` of ``obj`` into
    ``tar`` through the generic library.getattr helper (the slow path).

    A NULL result from the helper is treated as an exception and
    handled by handle_exception().  Otherwise the result is moved into
    ``tar`` and the reference count of ``tar`` is incremented. """

    result = temparg ()
    oper ('CALL', dests=(result,), sources=(obj, attrname),
          target=(LIB,library.getattr))

    # Non-NULL means success: jump over the exception sequence.
    oper ('BNE', sources=(result, constarg (0)), target=(LABEL,'okay'),
          comment='Test for exception')
    handle_exception()

    oper ('MOVE', dests=(tar,), sources=(result,), label='okay',
          comment='No exception, so load return value')
    # NOTE(review): if library.getattr returns a new reference (as
    # PyObject_GetAttr does), this extra increment leaks one reference
    # per load; the unary/binary templates do not increment such results.
    # Confirm the helper's reference semantics before changing this.
    inc_ref (tar)

def store_attr_slow (obj, attrname, val):

    """ Emit code that sets attribute ``attrname`` of ``obj`` to ``val``
    through the generic library.setattr helper (the slow path).

    A zero status continues at the FALLTHROUGH exit.  Any other status
    is treated as an exception and routed through handle_exception(). """

    status = temparg ()
    oper ('CALL', dests=(status,), sources=(obj, attrname, val),
          target=(LIB,library.setattr))
    oper ('BEQ', sources=(status, constarg (0)),
          target=(BLOCKEXIT,curtype.FALLTHROUGH))
    handle_exception()

# ======================================================================
# Actual Instruction Definitions

# Unary Instructions
# ==================

# Each concrete UnaryInstruction subclass becomes a call to the C
# function named by its ``funcname`` attribute (for example a PyNumber_*
# function).  The UnaryInstruction base class itself is skipped.
for instrtype in [klass for klass in instrclasses
                  if issubclass (klass, instrs.UnaryInstruction)
                  and klass is not instrs.UnaryInstruction]:
    start_instr_type (instrtype)
    result = valarg (curtype.OUTPUT)
    operand = valarg (curtype.INPUT)
    call_result = temparg ()
    helper = library.LibraryWrapper (curtype.funcname)

    # Call the function; a NULL return value signals an exception.
    oper ('CALL', dests=(call_result,), sources=(operand,),
          target=(LIB,helper))
    oper ('BNE', sources=(call_result, constarg (0)), target=(LABEL,'okay'))
    handle_exception ()

    # Success: these functions already return a new reference, so the
    # result is simply moved into OUTPUT before falling through.
    oper ('MOVE', dests=(result,), sources=(call_result,), label='okay')

# Binary and Inplace Instructions
# ===============================

# BinaryCompareInstructions
# ===============================

# Map each comparison operator string to its position in opcode.cmp_op.
# The conditions below compare these positions with an instruction's
# ``compare_op`` attribute.
cmp_ops = dict ((opname, position)
                for position, opname in enumerate (opcode.cmp_op))

def opindex (opstr):
    """ Return the index of ``opstr`` in opcode.cmp_op (KeyError if the
    operator is unknown). """
    return cmp_ops[opstr]

for copstr in ('<', '<=', '==', '!=', '>', '>='):

    cop = opindex (copstr)

    # ``cop`` is bound as a default argument so that each condition
    # captures the value from its own iteration.  A plain closure would
    # look ``cop`` up only when the condition runs, after the loop has
    # finished, so every condition would test against the last operator.
    start_instr_type (instrs.BinaryCompareInstruction)
    set_condition (lambda ins, copval=cop: ins.compare_op == copval)

    # The operator index is passed to the helper unchanged.  This assumes
    # library.richcomp behaves like PyObject_RichCompare, whose Py_LT..Py_GE
    # codes (0..5) coincide with these opcode.cmp_op positions.
    temp = temparg ()
    richcmp_sources = (valarg (curtype.LEFT),
                       valarg (curtype.RIGHT),
                       constarg (cop))
    oper ('CALL', dests=(temp,), sources=richcmp_sources,
          target=(LIB, library.richcomp))

    # A NULL result means an exception is pending.
    oper ('BNE', sources=(temp, constarg (0)), target=(LABEL,'okay'))
    handle_exception ()
    oper ('MOVE', dests=(valarg (curtype.OUTPUT),), sources=(temp,),
          label='okay')

# Handle the in operator
#
# NOTE(review): the helper result is NULL-checked and moved straight into
# OUTPUT as an object.  If library.contains wraps PySequence_Contains
# (which returns an int 1/0/-1 and takes the container as its first
# argument), both the check and the argument order are wrong.  Confirm
# that the helper returns a new object reference or NULL.

cop = opindex ('in')
start_instr_type (instrs.BinaryCompareInstruction)
set_condition (lambda ins, copval=cop: ins.compare_op == copval)
dest = valarg (curtype.OUTPUT)
# Allocate this template's own temporary.  Reusing the ``temp`` left
# over from the last rich-comparison loop iteration would share one
# temporary between two templates.
temp = temparg ()

oper ('CALL',
       dests=(temp,),
       sources=(valarg (curtype.LEFT), valarg (curtype.RIGHT)),
       target=(LIB, library.contains))
oper ('BNE', sources=(temp, constarg (0)), target=(LABEL,'okay'))
handle_exception ()
oper ('MOVE', dests=(dest,), sources=(temp,), label='okay')

# Handle the is operator
#
# Identity needs no helper call: a pointer comparison of LEFT and RIGHT
# selects between the library.truth and library.falsehood results.

start_instr_type (instrs.BinaryCompareInstruction)
set_condition (lambda ins, copval=opindex("is"): ins.compare_op == copval)
dest = valarg (curtype.OUTPUT)
left, right = valarg (curtype.LEFT), valarg (curtype.RIGHT)

# Same object: jump to the code that produces the true result.
oper ('BEQ', sources=(left, right), target=(LABEL,'okay'))
# Different objects: produce the false result and leave the block.
oper ('CALL', dests=(dest,), target=(LIB,library.falsehood))
oper ('BR', target=(BLOCKEXIT,curtype.FALLTHROUGH))
oper ('CALL', dests=(dest,), target=(LIB,library.truth), label='okay')

# Handle the is not operator
#
# Mirror image of the ``is`` template: when LEFT and RIGHT are different
# objects, the library.truth result is produced; otherwise the
# library.falsehood result is produced.

start_instr_type (instrs.BinaryCompareInstruction)
set_condition (lambda ins, copval=opindex("is not"): ins.compare_op == copval)
dest = valarg (curtype.OUTPUT)

# 'okay' is a label local to this template, so the branch target kind
# must be LABEL (not LIB, which is for library helpers).
oper ('BNE',
      sources=(valarg (curtype.LEFT), valarg (curtype.RIGHT)),
      target=(LABEL,'okay'))
oper ('CALL', dests=(dest,), target=(LIB,library.falsehood))
oper ('BR', target=(BLOCKEXIT,curtype.FALLTHROUGH))
oper ('CALL', dests=(dest,), target=(LIB,library.truth), label='okay')

# Define the binary instructions as calls to the C function named by each
# class's ``funcname`` (such as PyNumber_Add).
for instrtype in [klass for klass in instrclasses
                  if issubclass (klass, instrs.BinaryInstruction)
                  and klass is not instrs.BinaryInstruction]:

    # Comparisons get the dedicated templates above.  The other three are
    # intermediate classes of the hierarchy and are skipped by identity
    # only; their subclasses still get templates here.
    if instrtype in (instrs.BinaryCompareInstruction,
                     instrs.InplaceInstruction,
                     instrs.InplaceArithmeticInstruction,
                     instrs.BinaryArithmeticInstruction):
        continue

    start_instr_type (instrtype)
    dest = valarg (curtype.OUTPUT)
    left = valarg (curtype.LEFT)
    right = valarg (curtype.RIGHT)
    temp = temparg ()
    helper = library.LibraryWrapper (curtype.funcname)

    # Call the function; a NULL return value signals an exception.
    oper ('CALL', dests=(temp,), sources=(left, right), target=(LIB,helper))
    oper ('BNE', sources=(temp, constarg (0)), target=(LABEL,'okay'))
    handle_exception ()

    # Success: the function already returned a new reference, so the
    # result is simply moved into OUTPUT before falling through.
    oper ('MOVE', dests=(dest,), sources=(temp,), label='okay')

# CallInstruction
# ===============

# knode FUNC target.
# Note that knode targets are the
# only targets that are constants
# -------------------------------
#
# Guarded by a condition that the FUNC operand is a constant.  The callee
# is then the target of the CALL itself (a VALUE target), with the ARGS
# values as its sources.

def _call_func_is_constant (ins):
    return ins.get_first_value (instrs.CallInstruction.FUNC).is_constant()

start_instr_type (instrs.CallInstruction)
set_condition (_call_func_is_constant)
temp = temparg ()
direct_args = (valarg (curtype.ARGS),)

oper ('CALL', dests=(temp,), sources=direct_args,
      target=(VALUE, curtype.FUNC))
# A NULL result means an exception is pending.
oper ('BNE', sources=(temp, constarg (0)), target=(LABEL,'okay'))
handle_exception()
oper ('MOVE', dests=(valarg (curtype.TARGET),), sources=(temp,),
      label='okay')

# Temporary Operand Target
# ------------------------

start_instr_type (instrs.CallInstruction)
argtuple = temparg ()
temp = temparg ()

# Build the tuple with the arguments.  Note that the C helper function
# takes a variable number of arguments after the length.
oper ('CALL',
      dests=(argtuple,),
      sources=(lvalarg (curtype.ARGS), valarg (curtype.ARGS)),
      target=(LIB,library.buildtuple))

# Building the tuple can fail.  This assumes library.buildtuple, like
# PyTuple_Pack, returns NULL with an exception set.  Without this check a
# NULL argtuple would be handed to the call helper and then dereferenced
# by dec_ref below.
oper ('BNE', sources=(argtuple, constarg (0)), target=(LABEL,'argsokay'),
      comment='Test for exception while building argument tuple')
handle_exception()

# Now call the standard python function for calling a callable object.
# Note that we pass NULL as the keyword arguments for now.
oper ('CALL',
      dests=(temp,),
      sources=(valarg (curtype.FUNC), argtuple, constarg (0)),
      target=(LIB,library.callobject),
      label='argsokay')

# After we use the argtuple, release the memory associated with it
dec_ref (argtuple)

# Check for and handle any exceptions raised by the call itself.
oper ('BNE', sources=(temp, constarg (0)), target=(LABEL,'okay'))
handle_exception()
oper ('MOVE', dests=(valarg (curtype.TARGET),), sources=(temp,), label='okay')

# BuildTupleInstruction
# =====================

# The helper takes the element count followed by a variable number of
# element arguments.
# NOTE(review): the helper's result is written straight into TARGET
# without a NULL check.  Adding one requires this instruction class to
# provide the EXCTB/EXCTYP/EXCVAL/EXCEPT operands used by
# handle_exception() — confirm before adding.
start_instr_type (instrs.BuildTupleInstruction)
tuple_dest = valarg (curtype.TARGET)
tuple_sources = (lvalarg (curtype.SOURCE), valarg (curtype.SOURCE))
oper ('CALL', dests=(tuple_dest,), sources=tuple_sources,
      target=(LIB,library.buildtuple))

# BuildInstanceInstruction
# ========================

# This helper function takes a variable number of arguments terminated by
# a NULL pointer.
# NOTE(review): no operations are emitted yet (see TODO), so an empty
# template with no condition is registered for BuildInstanceInstruction.
# Confirm how codegen treats an empty template before relying on it.
start_instr_type (instrs.BuildInstanceInstruction)
# TODO

# StoreAttrInstruction
# ====================

# TODO --- faster version that directly dereferences the pointer
# -------------------------------------------------------------

# Generic Version; use the PyObject_SetAttr() function,
# which returns 0 on success, or -1 on exception
# -----------------------------------------------------------

start_instr_type (instrs.StoreAttrInstruction)
attr_owner = valarg (curtype.PTR)
attr_name = valarg (curtype.ATTRNAME)
attr_value = valarg (curtype.SOURCE)
store_attr_slow (attr_owner, attr_name, attr_value)

# LoadAttrInstruction
# ===================

# TODO --- faster version that directly dereferences the pointer
# -------------------------------------------------------------

# Generic Version; use the PyObject_GetAttr() function, which returns
# the attribute object on success, or NULL when an exception was raised
# (load_attr_slow tests the result against NULL accordingly)
# ---------------------------------------------------------------------

start_instr_type (instrs.LoadAttrInstruction)
load_attr_slow (valarg (curtype.PTR),
                valarg (curtype.ATTRNAME),
                valarg (curtype.TARGET))

# MoveInstruction
# ===============
# Copy SOURCE into TARGET.  TARGET becomes an additional reference, so
# its reference count is incremented.

start_instr_type (instrs.MoveInstruction)
move_dest = valarg (curtype.TARGET)
move_src = valarg (curtype.SOURCE)
oper ('MOVE', dests=(move_dest,), sources=(move_src,))
inc_ref (move_dest)

# JumpIfTrueInstruction
# =====================
# library.objistrue yields 0 (false) or 1 (true).  Any other value
# (presumably -1, as with PyObject_IsTrue) is treated as an exception.

start_instr_type (instrs.JumpIfTrueInstruction)
truthval = temparg ()
oper ('CALL', dests=(truthval,), sources=(valarg (curtype.TEST),),
      target=(LIB,library.objistrue))
for expected, block_exit in ((0, curtype.FALLTHROUGH),
                             (1, curtype.BRANCH)):
    oper ('BEQ', sources=(truthval, constarg (expected)),
          target=(BLOCKEXIT,block_exit))
handle_exception()

# JumpIfFalseInstruction
# ======================
# Same test as JumpIfTrueInstruction with the two exits swapped: 0 takes
# BRANCH, 1 takes FALLTHROUGH, and anything else is an exception.

start_instr_type (instrs.JumpIfFalseInstruction)
truthval = temparg ()
oper ('CALL', dests=(truthval,), sources=(valarg (curtype.TEST),),
      target=(LIB,library.objistrue))
for expected, block_exit in ((0, curtype.BRANCH),
                             (1, curtype.FALLTHROUGH)):
    oper ('BEQ', sources=(truthval, constarg (expected)),
          target=(BLOCKEXIT,block_exit))
handle_exception()

# LoadParamInstruction
# ====================
# Load parameter number INDEX (a constant) into TARGET, then increment
# TARGET's reference count.

start_instr_type (instrs.LoadParamInstruction)
param_dest = valarg (curtype.TARGET)
param_index = cvalarg (curtype.INDEX)
oper ('LOADPARAM', dests=(param_dest,), sources=(param_index,))
inc_ref (param_dest)

# ReturnInstruction: return the SOURCE value
# ==========================================

start_instr_type (instrs.ReturnInstruction)
return_value = valarg (curtype.SOURCE)
oper ('RET', sources=(return_value,))

# AbortInstruction: return NULL (constant 0)
# ==========================================

start_instr_type (instrs.AbortInstruction)
null_result = constarg (0)
oper ('RET', sources=(null_result,))

# LoadGlobalInstruction
# =====================
# A global is an attribute of the current module, so this is the generic
# attribute load applied to ModuleArgument().

start_instr_type (instrs.LoadGlobalInstruction)
module_obj = ModuleArgument ()
load_attr_slow (module_obj, valarg (curtype.NAME), valarg (curtype.TARGET))

# StoreGlobalInstruction
# ======================
# Set attribute NAME of the current module to SOURCE via library.setattr.
# NOTE(review): the status returned into ``status`` is never tested, so a
# failed store continues with an exception pending.  StoreAttrInstruction
# uses store_attr_slow() for the checked version.  Switching to it here
# requires StoreGlobalInstruction to provide FALLTHROUGH/EXCEPT and the
# EXC* operands — confirm in instrs before changing.

start_instr_type (instrs.StoreGlobalInstruction)
status = temparg ()
setattr_sources = (ModuleArgument (),
                   valarg (curtype.NAME),
                   valarg (curtype.SOURCE))
oper ('CALL', dests=(status,), sources=setattr_sources,
      target=(LIB,library.setattr))

# ForIterInstruction
# ==================
# NOTE(review): no operations are emitted yet (see TODO), so an empty
# template with no condition is registered for ForIterInstruction.
# Confirm how codegen treats an empty template before relying on it.

start_instr_type (instrs.ForIterInstruction)
# TODO

# DecRefInstruction: expanded with the dec_ref() helper
# =====================================================

start_instr_type (instrs.DecRefInstruction)
released = valarg (curtype.TARGET)
dec_ref (released)
