"""

Refinement
==========

The Refinement pass actually performs a number of local, standard
optimizations.  See docs/design.txt for more information.

"""

import sys
import types

import knodes
import values
import instrs
import dump
import util
import blocks

def refine_ir (comp, funcnode, jitted):

    """ Entry point of the Refinement pass for a single function.

    Builds an IRRefiner for (comp, funcnode, jitted) and runs its three
    phases in order: seed the worklist, strip unreachable predecessors,
    then drain the worklist.  Returns None. """

    # Kept in a local for the duration of the pass; presumably the log
    # indentation ends when this object is released -- TODO confirm
    # against the log implementation.
    log_scope = comp.log.indent ('Refinement')

    ir_refiner = IRRefiner (comp, funcnode, jitted)

    # Passing None makes each walk start from its default block.
    ir_refiner.seed_worklist (None)
    ir_refiner.strip_unreachable (None)
    ir_refiner.drain_worklist ()

class IRRefiner:

    """ Worklist-driven driver for the Refinement pass over one function.

    Typical use (see refine_ir): construct, then call seed_worklist(None),
    strip_unreachable(None) and drain_worklist() in that order.  The
    refine() methods of instructions/blocks receive this object and call
    back into it (clear_type_info, set_type_info, add_uses_to_worklist,
    note_removed_use) and use the ELIMINATE / REPLACE opcodes below. """

    # Opcodes for the action tuples returned by an item's refine()
    # method; the tuple layouts are described in IRRefiner.refine.
    ELIMINATE = 0
    REPLACE   = 1

    def __init__(self, comp, funcnode, jitted):

        """ Captures the compiler, function node and jitted code object,
        allocates two visit tags (one per graph walk) and creates an
        empty worklist. """

        self.jitted = jitted
        self.startblk = jitted.get_start_block()
        self.stopblk = jitted.get_stop_block()
        self.comp = comp
        self.log = comp.log
        self.funcnode = funcnode
        self.root = comp.get_root()
        self.module = funcnode.get_module_node()
        # visit1 tags blocks seen by the forward walk (seed_worklist),
        # visit2 tags blocks seen by the reverse walk (strip_unreachable).
        self.visit1 = comp.get_visit_tag()
        self.visit2 = comp.get_visit_tag()
        self.worklist = util.TransientWorkSet((instrs.Instruction,
                                               blocks.BasicBlock))

        # Adding the module as an attribute of the refiner object
        # allows it to be accessed via refine_methods without directly
        # importing the instrs module.  This is important because the
        # instrs module imports refine_methods!
        self.instrs = instrs

    def seed_worklist(self, blk):

        """ Forward walk from blk (the start block when blk is false)
        that puts every instruction of every reachable block on the
        worklist, drops write operands whose value has no other use, and
        unlinks blocks that contain no instructions.

        Blocks are tagged with self.visit1 so each is processed once. """

        if not blk: blk = self.startblk

        # use the while loop to achieve tail recursion
        while blk:
            # Avoid visiting the same block more than once.
            if blk.visited(self.visit1): return
            blk.visit(self.visit1)

            # Iterate over snapshots: op.remove() below may (depending
            # on the operand implementation) mutate the containers we
            # are walking, and a copy is harmless if it does not.
            for instr in list(blk.get_instrs()):
                self.log.mid('Adding %s to worklist' % instr)

                # A written value whose only use is this write is never
                # read, so the write operand can be removed.
                for op in list(instr.get_targets()):
                    if len(op.get_value().get_uses()) == 1:
                        op.remove()

                self.worklist += instr

            # The stop block ends the walk.
            if blk is self.stopblk: return

            # If the block has no instructions, then it can be
            # removed!  Note that even if we remove the block, we
            # should still walk to its successor, which remove_block
            # returns.
            if not blk.get_instrs():
                blk = self.remove_block(blk)
                continue

            # recurse to process all the successors but the last one
            succs = blk.get_successors()
            for sblk in succs[:-1]:
                self.seed_worklist(sblk)

            # tail recurse on the last successor
            blk = succs[-1]

    def strip_unreachable(self, blk):

        """ After the initial walk to seed the worklist, we do a reverse
        walk to strip any unreachable blocks from pred lists.

        Starts at blk (the stop block when blk is false) and follows
        predecessor links.  A predecessor tagged with neither visit1 nor
        visit2 is removed from the block's predecessor list and from the
        SOURCE entries of the block's PHI instructions. """

        if not blk: blk = self.stopblk

        while blk:
            if blk.visited(self.visit2): return
            blk.visit(self.visit2)

            indent = self.log.indent("Unreachable walk: Block %s" % blk)

            if blk.get_instrs():
                self.log.mid("First instruction: %s" % blk.get_instrs()[0])

            # Accumulate list of indices of preds that were never visited
            # in the initial depth first search; these cannot be reached
            # through the code
            strip_blks = []
            for idx, (pblk, pidx) in enumerate(blk.get_predecessors()):
                if not pblk.visited(self.visit1) and \
                       not pblk.visited(self.visit2):
                    strip_blks.append(idx)
                    self.log.mid("Unreachable pred: %s" % pblk)

            if strip_blks:
                self.log.mid("Stripping preds: %s" % repr(strip_blks))
                blk.strip_predecessors(strip_blks)
                for instr in blk.get_instrs():
                    if isinstance(instr, instrs.PhiInstruction):
                        instr.strip_named(instr.SOURCE, strip_blks)

            del indent

            # The start block ends the walk.
            if blk is self.startblk: return

            # Re-read the list: it may just have been stripped.  If
            # nothing is left (e.g. a stop block that is itself
            # unreachable) there is nowhere further to walk.
            pblks = blk.get_predecessors()
            if not pblks: return

            # Recurse on the pred blocks, except for the first one
            # for which we tail recurse
            for pblk, pidx in pblks[1:]:
                self.strip_unreachable(pblk)
            blk = pblks[0][0]

    def drain_worklist(self):

        """ Refines every item the worklist yields.  Items added during
        refinement are handled by the worklist's own iteration. """

        for item in self.worklist: self.refine(item)

    def remove_block(self, blk):

        """ Removes a block from the program by unlinking it; note that
        it returns the successor.

        blk must contain no instructions.  Its predecessors are
        redirected to its FALLTHROUGH successor, the PHI instructions at
        the head of that successor are adjusted, and the start block is
        updated if blk was the start block. """

        indent = self.log.indent("Removing block %s" % blk)
        assert not blk.get_instrs()
        sidx = instrs.Instruction.FALLTHROUGH
        succ = blk.get_successors()[sidx]
        self.log.mid("Succ blk: %s" % succ)

        # redirect our predecessors to our successor.  Work from a
        # snapshot: set_successor manipulates predecessor records (see
        # its use below), so the live list may change under us.
        preds = list(blk.get_predecessors())
        dupcnt = len(preds) # used below in PHI code
        for pblk, pidx in preds:
            self.log.mid("Pred blk %s:%d redirected" % (pblk, pidx))
            pblk.set_successor(pidx, succ)

        # Find our predecessor record in the successor, then remove us!
        # Remember what index we held in the pred list
        oldpidxtpl = (blk.set_successor(sidx, None),)

        # Go through the PHI nodes of the successor we are replacing
        # ourselves with and adjust their param list.  We need to
        # remove the entry that corresponded to us, and then add it back
        # again once for each pred of ours that we redirected
        for instr in succ.get_instrs():
            if not isinstance(instr, instrs.PhiInstruction): break
            val = instr.strip_named(instr.SOURCE, oldpidxtpl)
            self.log.low("Adjusting PHI %s for value %s:",
                         instr, val)
            for _ in range(dupcnt):
                instr.add_named(instr.SOURCE, val)

        # Update our start block if we just purged the first block in
        # the program
        if blk is self.startblk:
            self.startblk = succ
            self.jitted.set_start_block(succ)

        return succ

    def refine(self, item):

        """ Calls item.refine(self) and carries out the action it
        requests, if any.

        Raises util.PyntoException for an unknown action opcode. """

        self.log.mid('Refining item: %s' % item)

        # Call the refine method for the instruction.  It returns None,
        # or some sort of action.  The action to be taken is specified in
        # the form of a tuple; the first part of the tuple is an opcode.
        # Here are the legal tuples:
        #
        # (self.ELIMINATE, terminalidx, ( (oldvalue, newvalue)* ))
        #
        #    This eliminates the current instruction; if the instruction is
        #    the final instruction in the block, it connects the block's
        #    FALLTHROUGH terminal to the old terminal numbered ``terminalidx``.
        #    For each tuple in the second component, it replaces all uses
        #    of oldvalue with newvalue.
        #
        # (self.REPLACE, (newinstr*))
        #
        #    This removes the current instruction and replaces it with
        #    a sequence of instructions (newinstr*).  All of the instructions
        #    up to the last one must have only one successor, clearly.

        refres = item.refine(self)

        if not refres: return

        # The opcodes are plain integers, so compare by value rather
        # than by identity.
        opcode = refres[0]
        if opcode == self.ELIMINATE:
            self.handle_eliminate_code(item, refres[1], refres[2])
        elif opcode == self.REPLACE:
            self.handle_replace(item, refres[1])
        else:
            raise util.PyntoException("Illegal refinment opcode %s" %
                                      refres[0])

    def handle_eliminate_code(self, instr, terminal, rvalues):

        """ Carries out an ELIMINATE action: removes instr from its
        block and rewrites uses of replaced values.

        terminal is the successor index the block is short-circuited to
        when instr was the block's last instruction; rvalues is a
        sequence of (oldvalue, newvalue) pairs of values.Value. """

        indent = self.log.indent('Elimination')

        self.log.low('handle_eliminate_code instr=%s terminal=%s rvalues=%s',
                     instr, terminal, rvalues)

        # We can eliminate the instruction 'instr' from its
        # block.  If the instruction had multiple terminals, then
        # we should connect the block to what was previously located
        # at successor #terminal, and rvalues will be a list
        # of (old, new) tuples containing Values to replace.
        block = instr.get_block()
        islast = block.last_instr() is instr

        # Unlink the instruction and remove it from the blocks list.
        self.log.mid('Unlinking instruction %s (islast=%d)', instr, islast)
        srcvalues = instr.unlink()
        block.get_instrs().remove(instr)

        # Replace all uses of oldval with newval, and add those
        # instructions whose operands were modified to the worklist.
        #
        # Note that we must use safe_iter() as we are modifying the list
        # as we walk it.
        replacement = self.log.indent('replacement')
        for oldval, newval in rvalues:
            assert isinstance(oldval, values.Value)
            assert isinstance(newval, values.Value)
            self.log.mid('Replacing %s with %s' % (oldval, newval))
            for use in oldval.get_uses().safe_iter():
                self.log.low('  Operand: %s', use)
                use.set_value(newval)
                useinstr = use.get_instruction()
                self.log.mid('Adding %s to worklist' % useinstr)
                assert instr is not useinstr
                self.worklist += useinstr
        del replacement

        # Ensure the block is still connected properly
        if islast:
            block.shortcircuit(terminal)
            self.log.low('Shortcircuited block %s, now connected to %s',
                         block, block.get_successors())

        # For each operand we used to read, note that we removed
        # a use.  This may uncover further dead code.  Note that
        # we must do this after the replacements have occurred because
        # otherwise we can get ourselves in weird situations in which
        # there are operands with no writes and the like.
        for src in srcvalues:
            self.log.mid("Noting removal of use from %s" % src)
            self.note_removed_use(src)

        # Now, we may have just emptied the block, in which case it
        # can be removed.
        if not block.get_instrs(): self.remove_block(block)

    def handle_replace(self, instr, newinstrs):

        """ Carries out a REPLACE action: substitutes the sequence
        newinstrs for instr in the instruction list of instr's block.

        Raises util.PyntoException if instr is not in that list. """

        indent = self.log.indent("Replacement")

        # Find instr in the block's list of instructions and replace
        # it.  A plain list signals a miss with ValueError, never -1;
        # both forms are mapped to the same PyntoException.
        block = instr.get_block()
        blkinstrs = block.get_instrs()
        try:
            index = blkinstrs.index(instr)
        except ValueError:
            index = -1
        if index == -1: raise util.PyntoException("Instruction not in block")
        blkinstrs[index:index+1] = newinstrs

        # TODO --- reroute successors etc when the number changes

    def clear_type_info(self, value):

        """ Called by the refine() methods of various items.  Clears the
        type info from value, and adds any affected instructions to the
        worklist. """

        assert isinstance(value, values.TemporaryValue)
        # Uses are requeued only when clear() returns a true result.
        if value.get_types().clear():
            self.add_uses_to_worklist(value, instrs.DEREF)

    def set_type_info(self, value, typeset):

        """ Called by the refine() methods of various items.  Sets the
        type info of value to typeset, and adds any affected
        instructions to the worklist. """

        assert isinstance(value, values.TemporaryValue)
        # Uses are requeued only when copy_from() returns a true result.
        if value.get_types().copy_from(typeset):
            self.log.mid('updated type of %s to %s' % (value, typeset))
            self.add_uses_to_worklist(value, instrs.DEREF)

    def add_uses_to_worklist(self, value, flag):

        """ Adds all instructions that use the given value in the
        specified manner to the worklist.  flag indicates what kind of
        uses to add (reads, writes, or derefs) """

        indent = self.log.indent('Adding uses of %s to worklist' % value)
        for use in value.get_uses():
            if use.check_flag(flag):
                self.log.mid('Adding %s' % use.get_instruction())
                self.worklist += use.get_instruction()

    def note_removed_use(self, value):

        """ Called when a use is removed from the given value.  If
        that was the last place this value was used, it is removed
        from its generating instruction and the generating instruction
        is placed on the worklist.

        Non-temporary values are ignored. """

        if not value.is_temporary(): return

        numuses = len(value.get_uses())
        self.log.mid("note_removed_use value=%s numuses=%d", value, numuses)

        # Exactly one remaining use must be the write that produces the
        # value (asserted below), so the value is dead.
        if numuses == 1:
            instr = value.get_instr()
            use = value.get_uses()[0]
            self.log.low("value=%s instr=%s use=%s" % (value,
                                                       instr,
                                                       use))
            assert use.check_flag(instrs.WRITE)
            use.remove()
            del use
            assert len(value.get_uses()) == 0
            self.worklist += instr
