"""

Pynto Code Generator
--------------------

See docs/codegen.txt for more info

"""

import util, blocks, instrs, cgraph, values
from instrs import PhiInstruction, LoadParamInstruction
from cStringIO import StringIO
from lowir import tartypes, Operation, LowLevelBlock
import ra_local
import chelpers, library

def generate_jitted (comp, modobj, md, funcnode, jitted):

    """ Top-level code generation driver for one jitted version.

    Runs the whole pipeline for 'jitted', in this order:

      1. create a FunctionInfo through md and attach it to 'jitted';
      2. lower the function to Low Level IR (generate_low_ir);
      3. turn the emitted operations into a CFG (build_cfg);
      4. run the local allocator (ra_local), which performs register /
         stack slot allocation and code generation;
      5. fetch the generated code object from the FunctionInfo and
         install it on 'jitted'.

    When this returns, assembly for the entire routine has been
    generated and linked, with the exception of calls to other
    generated code: those stay incomplete because that other code may
    not have been generated yet.

    If comp.monitor is set, it is notified at each intermediate stage,
    always with 'funcnode' as its first argument.  Returns None. """

    # The indent object is held for the duration of the routine and
    # released explicitly at the very end.
    scope = comp.log.indent('CodeGen %s' % jitted)

    def tell_monitor(hook, *extra):
        # The monitor is optional; do nothing when none is installed.
        watcher = comp.monitor
        if watcher:
            getattr(watcher, hook)(funcnode, *extra)

    # The FunctionInfo carries the state of the function being generated.
    funcinfo = md.create_function_info()
    jitted.set_function_info(funcinfo)

    # Initial Low Level IR.  The monitor hook right after it is a handy
    # debugging point: code exists, but no resource allocation or other
    # optimization has occurred yet.
    lowir_emitter = generate_low_ir(comp, modobj, md, jitted)
    tell_monitor('low_level_emitter', lowir_emitter)

    # Once the CFG exists the emitter is dropped to save memory.
    entry = build_cfg(comp, md, jitted, lowir_emitter)
    del lowir_emitter
    tell_monitor('low_level_cfg', entry)

    # The allocator is hardcoded for now; it would be nice for this to
    # become more flexible in the future.
    tell_monitor('low_level_start_emitting')
    ra_local.allocate_registers(entry, funcinfo, comp)
    tell_monitor('low_level_end_emitting')

    comp.log.high("Loading generated code object")

    # Installing the final Generated Code object on the JittedVersion
    # also clears out any temporary data.
    jitted.set_generated_code(funcinfo.get_generated_code())

    del scope

# ----------------------------------------------------------------------

def generate_low_ir (comp, modobj, md, jitted):

    """ Generates the low IR for the function 'jitted' from Mid IR.

    The Mid IR blocks are walked depth first from
    jitted.get_start_block (), descending into each block's successors
    in list order, so a not-yet-generated first successor is emitted
    directly after the block itself.  For every block this emits:

      * an entrance: when the block has Phi instructions, one labelled
        group of MOVEs per predecessor, each group followed by a branch
        to the label 'MAIN'; otherwise a single labelled NOOP;
      * the operations of the code generation template that md selects
        for each non-Phi instruction, preceded by calls to
        library.trace when comp.trace_enabled is set;
      * a closing BR to the first successor when control must not (or
        cannot) simply run into the code emitted next, or an EPILOGUE
        when the block has no successors.

    comp   -- supplies the log, the trace settings, instruction ids and
              the visit tag used to mark generated blocks
    modobj -- handed unchanged to md.create_emitter
    md     -- supplies the emitter, the code generation templates and
              integer constants
    jitted -- the version being compiled; its FunctionInfo is read with
              get_function_info ()

    Returns the emitter holding the generated operations. """

    finfo = jitted.get_function_info ()
    emitter = md.create_emitter (finfo, modobj)

    def entrance_label (tarblock, predblock):
        """ Returns a label for the entrance point to tarblock when
        coming from predblock.  If tarblock has no PHI instructions,
        then the value of predblock is irrelevant. """
        if not tarblock.has_phis ():
            return "_%x" % (id (tarblock))
        return "_%x_%x" % (id (tarblock), id (predblock))

    def instantiate_args (arglist, instr):
        """ Instantiates the template arguments in 'arglist' for
        'instr'.  This converts ValueArguments appropriately and that
        sort of thing.  Each argument's instantiate () yields a
        sequence; the results are concatenated. """
        alistlen = len (arglist)
        if not alistlen:
            return ()
        if alistlen == 1:
            return arglist[0].instantiate (emitter, instr)
        res = []
        for d in arglist: res += d.instantiate (emitter, instr)
        return res

    def emit_trace (instr):
        """ Emits calls to library.trace announcing 'instr' and, when
        comp.trace_enabled is greater than 1, each of its sources. """
        instrid = comp.next_instruction_id ()
        descptr = chelpers.ptr (finfo.add_constant (
            "About to execute #0x%08x: %s" % (instrid, instr)))
        emitter.emit (Operation (
            'CALL',
            sources=[md.get_integer_constant (instrid),
                     md.get_integer_constant (descptr),
                     md.get_integer_constant (0)],
            comment='Trace',
            target=(tartypes.LIB, library.trace)))

        if comp.trace_enabled > 1:
            # Dump its parameters too if in super-dooper-dump mode
            for srcobj in instr.get_sources ():
                paramptr = chelpers.ptr (finfo.add_constant (
                    "  Parameter %-10s: " %
                    instr.op_names[srcobj.name]))
                emitter.emit (Operation (
                    'CALL',
                    sources=[md.get_integer_constant (instrid),
                             md.get_integer_constant (paramptr),
                             emitter.translate_value (srcobj.value)],
                    comment='Trace operand',
                    target=(tartypes.LIB, library.trace)))

    def emit_phi_entrances (block, inslist):
        """ Emits the entrance code of a block that has Phi
        instructions: for each predecessor, the MOVEs that copy that
        predecessor's Phi sources into the Phi targets, then a branch
        to 'MAIN'; finally the NOOP carrying the 'MAIN' label. """
        phiindent = comp.log.indent ('Phi Instructions Present')

        # The position of a predecessor in get_predecessors () selects
        # which SOURCE operand of every Phi belongs to it.
        for idx, (pblk, _pidx) in enumerate (block.get_predecessors ()):

            # Only the MOVE for the first Phi is marked with the
            # entrance label of this predecessor.
            label = entrance_label (block, pblk)

            for instr in inslist:
                # Stop after the last Phi
                if not isinstance (instr, instrs.PhiInstruction): break

                emitter.set_creation_instr (instr)

                # Obtain the storage for the source:
                srcval = instr.get_named (instr.SOURCE)[idx].get_value()
                srcarg = emitter.translate_value (srcval)

                # Obtain the storage for the target:
                tarval = instr.get_named (instr.TARGET)[0].get_value ()
                tararg = emitter.translate_value (tarval)

                comp.log.mid ('Label %s Index = %3d: %s (%s) <= %s (%s)',
                              label, idx, tarval, tararg, srcval, srcarg)

                emitter.emit (Operation ('MOVE',
                                         dests=(tararg,),
                                         sources=(srcarg,),
                                         label=label))
                label = None

            # After all the Phi code has been generated, do a branch
            # to the label 'MAIN'
            emitter.emit(Operation('BR', target=(tartypes.LABEL, 'MAIN')))

        # Finally, we emit a NOOP with the 'MAIN' label that can
        # be branched to from the phi's
        emitter.emit (Operation ('NOOP', label='MAIN'))

        del phiindent

    def generate_block (vtag, block):
        """ Emits the low IR for 'block', marks it visited with 'vtag'
        and then recurses into its unvisited successors. """

        block.visit (vtag)

        # Held until this call returns, so the log output of successor
        # blocks generated by the recursion below is nested inside it.
        bbindent = comp.log.indent ('generate_block %s' % block)

        inslist = block.get_instrs ()

        # Process the Phi instructions (if any)
        if block.has_phis ():
            emit_phi_entrances (block, inslist)
        else:
            # If there are no PHI instructions, we emit a NOOP with the
            # branch label that people can branch to.
            emitter.emit (Operation ('NOOP',
                                     label=entrance_label (block, None),
                                     comment='new block without phi'))

        # Now process the remainder.  'notlastinstr' counts the
        # instructions still to come and is zero (false) exactly while
        # the last instruction of the block is being handled.
        notlastinstr = len (inslist)
        for instr in inslist:

            notlastinstr -= 1

            if isinstance (instr, instrs.PhiInstruction): continue

            emitter.set_creation_instr (instr)

            instrind = comp.log.indent ('Instruction %s' % instr)
            fallthroughtargeted = False

            # Select code generation template
            cgtmpl = md.select_cg_template (instr)

            # If we are tracing emit some interesting information about
            # the instruction we are executing
            if comp.trace_enabled and \
                   not isinstance (instr, LoadParamInstruction):
                emit_trace (instr)

            # Emit each of the instructions in the code gen template in turn
            for oper in cgtmpl.opers:

                comp.log.mid ('oper: %s' % oper)

                dests = instantiate_args (oper.dests, instr)
                sources = instantiate_args (oper.sources, instr)

                # Process the target if required and substitute a label for
                # any point w/in this function.
                if oper.tartype == tartypes.BLOCKEXIT:
                    tartype = tartypes.LABEL
                    if notlastinstr:
                        # If this is not the last instruction in the block,
                        # then we can only branch to fallthrough which falls
                        # through to the next instruction.  We simulate this
                        # by just inserting a NOOP at the end to target.
                        assert oper.target == instr.FALLTHROUGH
                        fallthroughtargeted = True
                        target = 'FALLTHROUGH'
                    else:
                        # If this IS the last instruction in the
                        # block, then we branch to another block.
                        tblk = block.get_successors() [oper.target]
                        target = entrance_label (tblk, block)

                    comp.log.low ("BlockExit Target '%s' translated to '%s'",
                                  oper.target, target)

                elif oper.tartype == tartypes.VALUE:
                    # This indicates we are calling a constant
                    # function.  We should have chosen a particular
                    # jitted version (which encodes the set of
                    # assumptions we make about the parameters etc)
                    value = instr.get_named (oper.target)[0].value
                    assert value.is_constant ()
                    knode = value.value
                    jversion = knode.select_jitted_version (comp, instr)
                    tartype = tartypes.GEN
                    target = jversion
                else:
                    tartype = oper.tartype
                    target = oper.target

                # Emit the operation w/ the instantiated arguments.
                emitter.emit (Operation (oper.opcode,
                                         dests=dests,
                                         sources=sources,
                                         label=oper.label,
                                         comment=oper.comment,
                                         target=(tartype,target)))

            if fallthroughtargeted:
                emitter.emit (Operation ('NOOP', label='FALLTHROUGH'))

            del instrind

        succs = block.get_successors ()

        emitter.set_creation_instr (None)

        if succs:
            # Finally, after all the instructions, control continues
            # with the first successor.  An explicit BRanch is needed
            # in two cases:
            #
            #  - the successor has already been generated, so it will
            #    not be the code emitted next;
            #
            #  - the successor has Phi instructions.  Its code then
            #    starts with the MOVE group of whichever predecessor
            #    get_predecessors () lists first, so running straight
            #    into it would execute the Phi moves of that
            #    predecessor, not necessarily ours.  Branching to our
            #    own entrance label is always right, and build_cfg
            #    strips the branch again when nothing else targets the
            #    label.
            tblk = succs [0]
            generated = tblk.visited (vtag)
            if generated or tblk.has_phis ():
                target = entrance_label (tblk, block)
                emitter.emit (Operation ('BR',target=(tartypes.LABEL, target)))

                if generated:
                    comp.log.mid ('Fallthrough to block that has been '+
                                  'generated, creating branch to label %s',
                                  target)
                else:
                    comp.log.mid ('Fallthrough to block with Phi '+
                                  'instructions, creating branch to label %s',
                                  target)
        else:
            # If there are no successors, then this is the last block,
            # and we need an EPILOGUE instruction.  There should only be
            # one block that fits this characteristic.
            # NOTE: a trace call to library.trace_undent used to be
            # sketched here; it is disabled together with the matching
            # trace_indent call at function entry.
            emitter.emit (Operation ('EPILOGUE'))

        # Continue the instruction definition depth first search:
        for succ in succs:
            if not succ.visited (vtag): generate_block (vtag, succ)

        return

    # The first thing we do is emit a PROLOGUE instruction
    emitter.emit (Operation ('PROLOGUE'))

    # NOTE: a trace call to library.trace_indent announcing the start of
    # 'jitted' used to be sketched here; it is disabled.

    # Generate the initial low level IR for the function.
    generate_block (comp.get_visit_tag (), jitted.get_start_block ())

    return emitter

# ----------------------------------------------------------------------
# Construction of the Control Flow Graph
#
# The CFG is composed of low level blocks (lowir.LowLevelBlock), each of
# which has a set of operations.  Any unconditional branches are
# removed.

def build_cfg (comp, md, jitted, emitter):

    """ Splits the emitter's linear operation list into a control flow
    graph of LowLevelBlocks and returns the block that begins at
    emitter.opers[0].

    Operations are walked through their em_next links.  A block ends
    when the walk reaches an operation that already belongs to a block,
    an operation with more than one incoming reference (oper.targeted),
    an unconditional BR whose target has other incoming references, an
    operation that optionally branches to another operation, or the
    EPILOGUE.  Operations at which further blocks have to begin are
    queued on a work set and processed after the first block.

    comp    -- supplies the log
    md      -- not used here; kept for the callers' sake
    jitted  -- not used here; kept for the callers' sake
    emitter -- the emitter whose 'opers' list is partitioned """

    # Operations from which a new block still has to be started.
    # NOTE(review): the loop at the bottom appends to this set while
    # iterating over it, so WorkSet presumably supports that and hands
    # out each operation once -- confirm in util.
    tostartfrom = util.WorkSet (Operation)

    cfgindent = comp.log.indent ('build_cfg')

    def start_block_from (oper):
        """ Builds and returns the LowLevelBlock that starts at 'oper',
        queueing on tostartfrom any operation at which another block
        must begin. """

        comp.log.high ('** start_block_from oper=%s', oper)

        # Create the block to receive the operations we walk now:
        curblk = LowLevelBlock ()
        comp.log.low ('  curblk created: %s', curblk)

        while oper:

            comp.log.mid ('  Considering Oper=%s' % oper)

            # This operation should not have been added to any other
            # block yet.  If it has, it must be reachable from more
            # than one place, and the block we are building simply
            # falls through to it.
            if oper.block:
                assert oper.targeted > 1
                curblk.set_fallthrough (oper)
                comp.log.mid ('    Ending Block: this oper already in block')
                return curblk

            if oper.targeted > 1 and curblk.opers:
                # If this operation is targeted by another operation
                # besides the one that fell through to it, and it is
                # not the first in the block, then that ends this
                # basic block because there is more than one way to
                # get here: fallthrough from the last instruction, and
                # the branch from somewhere else.
                comp.log.mid ('    Ending Block: Operation targeted by %d',
                              oper.targeted)
                tostartfrom.append (oper)
                curblk.set_fallthrough (oper)
                return curblk

            elif oper.opcode == 'BR':
                # Otherwise, we handle unconditional branches differently.
                # If the target of the branch is only targeted from here,
                # we just strip the branch altogether and keep walking
                # at the target.  Otherwise the target becomes the
                # fallthrough of this block and starts a block of its
                # own.  Either way the BR itself is not added.

                nextoper = oper.target
                if nextoper.targeted == 1:
                    oper = nextoper
                    comp.log.low ('    Next operation only targeted by us: %s',
                                  nextoper)
                else:
                    comp.log.mid ('    Ending Block: Next oper %s has %d tars',
                                  nextoper, nextoper.targeted)
                    curblk.set_fallthrough (nextoper)
                    tostartfrom.append (nextoper)
                    return curblk

            elif oper.opcode == 'EPILOGUE':
                # This is the end of the function.  This block has no
                # successor.
                curblk.add_oper (oper)
                return curblk

            else:
                curblk.add_oper (oper)

                # If we optionally branch to another operation, then
                # that must be the end of this basic block.  Set the
                # fallthrough and return.  Both the branch target and
                # the fallthrough start blocks of their own.
                if oper.tartype == tartypes.OPER:
                    comp.log.mid ('    Ending Block: branch %s, fallthru %s',
                                  oper.target, oper.em_next)
                    curblk.set_fallthrough (oper.em_next)
                    tostartfrom.append (oper.target)
                    tostartfrom.append (oper.em_next)
                    return curblk

                # Otherwise, move to the next operation and repeat
                # around the loop
                oper = oper.em_next
                comp.log.low ('    Continuing on to next instruction %s', oper)

        # The operation chain ended without a BR or an EPILOGUE.  Hand
        # back what was collected, as a block without fallthrough,
        # instead of implicitly returning None to the caller.
        return curblk

    startblk = start_block_from (emitter.opers[0])
    for op in tostartfrom: start_block_from (op)

    del cfgindent

    return startblk

