from cStringIO import StringIO
import sys
import lowir, util

# Per-argument match states used by _VariantState.  Every slot starts out
# as _U (see _VariantState.initial) and is only ever changed to _P or _F;
# nothing is set back to _U (_VariantState.set asserts this).
# NOTE(review): the original note said the values were "chosen
# specifically" but broke off mid-sentence, so the reason is not recorded
# here -- confirm before renumbering them.
_F = 0  # Fail
_P = 1  # Pass
_U = 2  # Unknown

def debug (string, *arguments):

    """ Formats string with the % operator, using arguments as the
    values, and writes the result to standard error.  Nothing is
    appended to the message, so callers supply their own newline. """

    message = string % arguments
    sys.stderr.write (message)

def _select_class (bits):

    """ Chooses the soft argument class that can stand in for an
    argument whose allowed-type mask is bits.  A mask containing the
    REG flag yields lowir.SoftRegArgument; otherwise a mask containing
    the MEM flag yields lowir.SoftStackArgument.  Raises
    util.PyntoException when the mask contains neither flag. """

    flags = lowir.argflags
    if bits & flags.REG:
        chosen = lowir.SoftRegArgument
    elif bits & flags.MEM:
        chosen = lowir.SoftStackArgument
    else:
        raise util.PyntoException ("Cannot wrap bits %x" % bits)
    return chosen

class _VariantState:

    def __init__ (self, variantinfo):
        self.variantinfo = variantinfo
        self.totalargs = len (variantinfo[0]) + len (variantinfo[1])
        pass

    def initial (self):
        self.totalpass = 0
        self.totalfail = 0
        self.disqualified = 0
        self.sources = [_U] * len (self.variantinfo[1])
        self.dests = [_U] * len (self.variantinfo[0])
        pass

    def clone (self):
        res = _VariantState (self.variantinfo)
        res.totalpass = self.totalpass
        res.totalfail = self.totalfail
        res.sources = list (self.sources)
        res.dests = list (self.dests)
        res.disqualified = self.disqualified
        return res

    def code_gen_func (self):
        return self.variantinfo[2]

    def set (self, argidx, value):

        """ Indicates that we applyed argument #argidx and it was found
        to match (if value == _P) or not to match (if value == _F).

        Note that argidx follows our wacky convention: positive numbers
        are sources, negatives are dests, and 0 is invalid. """

        assert value == _P or value == _F
        
        if argidx < 0:
            idx = -argidx-1
            if self.dests[idx] == value: return
            self.dests[idx] = value
        else:
            idx = argidx-1
            if self.sources[idx] == value: return
            self.sources[idx] = value
            pass

        if value == _P: self.totalpass += 1
        else:
            self.totalfail += 1

            # Watch out: we may need to set the disqualified flag if we
            # failed a apply a against a constant
            if (argidx > 0 and 
                self.get_apply_bits (argidx) & (lowir.argflags.CONST|
                                                lowir.argflags.LARGECONST)):
                self.disqualified = 1000 # set to large constant
                pass
            
            pass
                
        return

    def completely_matched (self):

        """ Returns true if we have completely applyed all possibilities and
        all were found to match. """
        
        return self.totalpass == self.totalargs

    def completely_applyed (self):

        """ Returns true if we have completely applyed all possibilities
        regardless of whether result was match or fail. """
        
        return self.totalpass + self.totalfail >= self.totalargs

    def failed_wrappers (self):

        """ Returns a list of tuples.  Each tuple has two values: an
        integer indicating the argument index (positive == source,
        negative == dest, offset by 1 so that 0 is invalid) and a
        class indicating an Argument class to instantiate and move the
        value into.  i.e., if the first source must be temporarily
        placed in a register, then (1, SoftRegArgument) will be
        returned. """

        results = []
        idx = 1
        for val, bits in zip (self.sources, self.variantinfo[1]):
            if val == _F:
                classfunc = _select_class (bits)
                results.append ( (idx, classfunc) )
                pass
            idx += 1
            pass
        idx = -1
        for val, bits in zip (self.dests, self.variantinfo[0]):
            if val == _F:
                classfunc = _select_class (bits)
                results.append ( (idx, classfunc) )
                pass
            idx -= 1
            pass
        return results

    def find_next_arg_index_to_apply (self):

        """ Assuming we have not completely applyed everything, finds
        and returns the next argument to apply as well as the bits to
        apply for.  Uses convention that
        positive indices indicate sources, negatives destinations, and
        0 is invalid. """
        
        idx = 1
        for val in self.sources:
            if val == _U: return (idx, self.variantinfo[1][idx-1])
            idx += 1
            pass
        idx = -1
        for val in self.dests:
            if val == _U: return (idx, self.variantinfo[0][-idx-1])
            idx -= 1
            pass
        raise util.PyntoException ("Nothing left to apply")

    def get_apply_bits (self, argidx):
        if argidx > 0: return self.variantinfo[1][argidx-1]
        return self.variantinfo[0][-argidx-1]

    def dump (self, out):
        return

    def detail (self):
        strs = { _U:" U", _F:" F", _P:" P" }
        
        name = StringIO ()
        name.write ("(")
        first = True
        for bits in self.variantinfo[0]:
            if not first: name.write (", ")
            name.write (lowir.argflags.string (bits))
            first = False
            pass
        name.write (") <- (")

        first = True
        for bits in self.variantinfo[1]:
            if not first: name.write (", ")
            name.write (lowir.argflags.string (bits))
            first = False
            pass
        name.write (")")

        res = StringIO ()
        res.write ("{%08x:%-40s" % (id (self), name.getvalue()))
        del name
        res.write ("[")
        for val in self.dests: res.write (strs[val])
        res.write (" ] <- [")
        for val in self.sources: res.write (strs[val])
        res.write (" ]")
        res.write (" (%dp, %df)}" % (self.totalpass, self.totalfail))

        return res.getvalue ()

    def __str__ (self):
        return "VS{%08x}" % id (self)

    pass

def _clone_variant_state (varstate):

    """ Returns a new list holding the result of clone() for every
    entry of varstate, in the same order. """

    clones = []
    for state in varstate:
        clones.append (state.clone ())
    return clones

def _build_emit_apply_tree (log, varstate):

    """ Recursively builds the decision tree for one group of variants.

    log must provide indent(), mid_enabled(), mid() and low().
    varstate is a list of _VariantState objects describing what is
    known so far about each variant.

    Returns an _EmitTreeResult when a decision has been reached, or an
    _EmitTreeApply whose pass and fail subtrees were built recursively.
    Raises util.PyntoException when every variant is fully examined and
    all of them are disqualified (this includes an empty varstate). """

    indent = log.indent ("Build Emit Apply Tree")

    # We always examine the variant that is closest to completion; that
    # is, closest to providing a complete match with no failures.  The
    # number of failed arguments (plus the disqualification penalty) is
    # the primary sort key (lower is better) and the number of matched
    # arguments is the secondary key (higher is better).
    #
    # The key deliberately leaves out the state object itself: sorting
    # is stable, so equally ranked variants keep their table order
    # instead of being ordered by comparing the objects.
    varsort = sorted (varstate,
                      key=lambda vs: (vs.totalfail + vs.disqualified,
                                      - vs.totalpass))

    if log.mid_enabled():
        for vs in varsort: log.mid (vs.detail ())

    # Walk the variants in desirability order and stop at the first one
    # that either matches completely or still has something to examine.
    for vs in varsort:

        log.low ("vs=%s" % vs)

        if vs.completely_matched():
            node = _EmitTreeResult (vs.code_gen_func(),
                                    ())  # exact match: no wrappings
            log.low ("completely matched node=%s", node)
            del indent
            return node
        elif not vs.completely_applyed():
            argidx, desired = vs.find_next_arg_index_to_apply ()
            applynode = _EmitTreeApply (argidx, desired)

            log.low ("not completely applyed: node=%s", applynode)

            # Create the states in which that check passes or fails.
            # Other variants that constrain the same argument may be
            # affected by the outcome as well.
            passstate = _clone_variant_state (varstate)
            failstate = _clone_variant_state (varstate)
            for idx, setvs in enumerate (varstate):
                applybits = setvs.get_apply_bits (argidx)
                if applybits == desired:
                    passstate[idx].set (argidx, _P)
                    failstate[idx].set (argidx, _F)
                elif (applybits & desired) == 0:
                    # No overlap with the mask that passed, so this
                    # variant cannot match on the pass side.
                    passstate[idx].set (argidx, _F)
            del varstate # our own state is no longer needed

            # Now recurse
            applynode.set_pass (_build_emit_apply_tree (log, passstate))
            applynode.set_fail (_build_emit_apply_tree (log, failstate))
            del indent
            return applynode
        log.low ("completely applyed, not a match")

    # At this point no complete match was found and everything has been
    # examined, so we use the best ranked variant that is not
    # disqualified, together with the wrappings needed for the sources
    # and destinations that failed.
    for vs in varsort:
        if not vs.disqualified:
            assert not vs.completely_matched ()
            assert vs.completely_applyed ()
            break
    else:
        raise util.PyntoException ("Could not find matching instruction")
    node = _EmitTreeResult (vs.code_gen_func(), vs.failed_wrappers())
    log.low ("using imperfect vs %s", vs)
    log.low ("created wrappings etc for node %s" % node)
    del indent
    return node

class _EmitTreeApply:

    """ An _EmitTreeApply is a node in the decision tree that describes
    a check to perform on the type of one of the arguments of the
    Operation.  The index of the argument is given together with a bit
    mask of acceptable Argument types.

    When the check is performed and the argument is of one of the types
    in desired_types we continue at self.passnode, otherwise at
    self.failnode.

    Argument indices > 0 indicate sources, < 0 indicate destinations.
    They are offset by 1 because 0 is not a valid argument index. """

    def __init__ (self, argidx, desired_types):
        # listidx selects the list inside the (dests, sources) pair
        # handed to apply(); idx is the zero-based position in it.
        # attr names that list; its first letter labels the argument in
        # __str__ and dump.
        if argidx > 0:
            self.listidx = 1
            self.idx = argidx-1
            self.attr = "sources"
        else:
            self.listidx = 0
            self.idx = -argidx-1
            self.attr = "dests"

        self.desired_types = desired_types
        self.passnode = None
        self.failnode = None

    def set_pass (self, node):

        """ Sets the node to continue at when the check succeeds. """

        self.passnode = node

    def set_fail (self, node):

        """ Sets the node to continue at when the check fails. """

        self.failnode = node

    def apply (self, args):

        """ Routine that performs the check at runtime.  args is the
        pair (dests, sources).  Returns whatever the apply() of the
        selected child node returns. """

        arg = args[self.listidx][self.idx]
        if arg.is_assigned_to_type (self.desired_types):
            return self.passnode.apply (args)
        return self.failnode.apply (args)

    def __str__ (self):
        return "Apply(%08x,%s%d,%s)" % \
               (id(self),
                self.attr[0],
                self.idx,
                lowir.argflags.string (self.desired_types))

    def dump (self, res, indent):

        """ Writes this node and then, indented by two more columns,
        both of its subtrees to res. """

        indentstr = " " * indent
        res.write (indentstr)
        res.write ("Apply: idx %s%d types %s\n" %
                   (self.attr[0],
                    self.idx,
                    lowir.argflags.string (self.desired_types)))
        res.write (indentstr)
        res.write ("Pass:\n")
        self.passnode.dump (res, indent+2)
        res.write (indentstr)
        res.write ("Fail:\n")
        self.failnode.dump (res, indent+2)

class _EmitTreeResult:

    """ Leaf of the decision tree: the point at which a decision has
    been reached.  It carries a code generation function together with
    a list of wrappers.  Every wrapper is a pair of an argument index
    and a class constructor, meaning that the source or destination
    named by the index must be wrapped in an instance of that class
    before it can be used.

    For example, when a source argument is a MemArgument but a
    RegArgument is required, the wrapper is (1, SoftRegArgument), where
    '1' names the first source argument.

    Argument indices > 0 indicate sources, < 0 indicate destinations.
    They are offset by 1 because 0 is not a valid argument index. """

    def __init__ (self, cgfunc, wrappers):
        self.cgfunc, self.wrappers = cgfunc, wrappers

    def __str__ (self):
        ident = id (self)
        count = len (self.wrappers)
        return "Res(%08x,%s,%d wraps)" % (ident, self.cgfunc, count)

    def apply (self, args):

        """ Runtime entry point: ignores args and hands back the pair
        (cgfunc, wrappers) stored in this leaf. """

        return (self.cgfunc, self.wrappers)

    def dump (self, res, indent):

        """ Writes a description of this leaf to res, one line per
        write, every line prefixed with indent spaces. """

        pad = " " * indent
        res.write (pad + "Result:\n")
        res.write (pad + "  CG Func: %s\n" % self.cgfunc)
        for wrap in self.wrappers:
            line = "  Wrap arg %d in %s\n" % (wrap[0], wrap[1].__name__)
            res.write (pad + line)

class EmitTable:

    """ Lookup table from (opcode, number of dests, number of sources)
    to the variant that should be used to emit an Operation.

    Groups of variants with at most MAX_TREE_ARGS arguments in total are
    compiled into a decision tree (self.trees); larger groups are kept
    as plain lists (self.nontrees) and searched linearly on lookup. """

    # Above this many arguments (dests + sources) no decision tree is
    # built, because the trees become too large.
    MAX_TREE_ARGS = 8

    def __init__ (self, log, emithash):

        """ log is passed on to the tree builder and is also used by
        the list search.  emithash maps each opcode to a sequence of
        variants, where a variant is indexed as (dest masks, source
        masks, cgfunc). """

        self.trees = {}
        self.nontrees = {}
        self.log = log

        for opcode, allvariants in emithash.items():

            # Group the variants of this opcode into pseudo-entries
            # keyed by the opcode plus the number of dests and sources.
            varsbykey = {}
            for var in allvariants:
                key = self._make_key (opcode, var[0], var[1])
                varsbykey.setdefault (key, []).append (var)

            for key, variants in varsbykey.items():

                # All variants in a group have the same argument counts,
                # so the first one is representative.
                totargs = len(variants[0][0]) + len(variants[0][1])
                if totargs > self.MAX_TREE_ARGS:
                    self.nontrees[key] = variants
                    continue

                # Create the initial state, build the tree, store it.
                varstate = [ _VariantState (var) for var in variants ]
                for vs in varstate: vs.initial ()
                self.trees[key] = _build_emit_apply_tree (log, varstate)

    def _make_key (self, opcode, dests, sources):

        """ Returns the table key for opcode with the given dest and
        source sequences; only their lengths are used. """

        return "%s_%d_%d" % (opcode, len (dests), len (sources))

    def apply_to_oper (self, oper):

        """ convenience function.  see apply """

        return self.apply (oper.opcode, oper.dests, oper.sources)

    def apply (self, opcode, dests, sources):

        """ Given an Operation's opcode, destination, and source
        arguments, looks it up in the table.  Returns a triple
        (cgfunc, destwraps, sourcewraps).

        cgfunc is the code generation function for this Operation.

        destwraps and sourcewraps are sequences of tuples (argidx,
        tempclass).  argidx is a zero-based index into the list of
        destination or source arguments, and tempclass is the class of
        the kind of argument that needs to be substituted there, as
        chosen by _select_class.  Both are empty tuples when nothing
        needs to be wrapped.

        The caller is responsible for assuring that values are moved
        into instances of the appropriate class and that those instances
        are referenced in the Operation.

        Raises util.PyntoException when the table has no entry for this
        opcode with this number of dests and sources. """

        key = self._make_key (opcode, dests, sources)

        # Test membership explicitly instead of catching KeyError around
        # the lookup, so that a KeyError raised while walking the tree
        # or searching the list is not mistaken for a missing entry.
        if key in self.trees:
            cgfunc, wraps = self.trees[key].apply ( (dests, sources) )
        elif key in self.nontrees:
            cgfunc, wraps = self._interpret_list (dests,
                                                  sources,
                                                  self.nontrees[key])
        else:
            raise util.PyntoException (
                "No entry in emit table for %s w/ %d dsts and %d srcs"
                % (opcode, len (dests), len (sources)))

        if not wraps:
            return (cgfunc, (), ()) # nothing to do

        # Split the signed, offset-by-one indices into zero-based
        # indices for the source list and for the dest list.
        sourcerets = []
        destrets = []
        for argidx, wrapclass in wraps:
            if argidx > 0:
                sourcerets.append ( (argidx - 1, wrapclass) )
            else:
                destrets.append ( (-argidx - 1, wrapclass) )

        return ( cgfunc, destrets, sourcerets )

    def _interpret_list (self, odests, osources, variants):

        """ If the number of arguments is too large, we don't build
        a decision tree.  Instead, we store each possibility and use
        this algorithm, which walks the list and finds the closest
        match.

        Returns the same format as the tree: a cgfunc and a list of
        (argidx, WrapperClass) pairs, where argidx is positive for
        sources and negative for dests, offset by one.  A perfect match
        returns an empty tuple of wrappers.

        Raises util.PyntoException when no variant can be used. """

        bestmatch = None
        bestmatchcount = -1   # so a usable variant with 0 matches counts

        numdsts = len (odests)
        numsrcs = len (osources)

        for var in variants:

            dsts, srcs, cgfunc = var

            assert len (dsts) == numdsts
            assert len (srcs) == numsrcs

            matches = 0
            mismatches = 0
            constfailure = False

            for spec, arg in zip (dsts, odests):
                if arg.is_assigned_to_type (spec): matches += 1
                else: mismatches += 1

            # Note: we can't convert anything to a constant, so remember
            # if we had to do something like that.
            # NOTE(review): _VariantState.set also disqualifies masks
            # containing LARGECONST; only an exact CONST mask is
            # recognised here -- confirm whether that is intended.
            for spec, arg in zip (srcs, osources):
                if arg.is_assigned_to_type (spec): matches += 1
                elif spec == lowir.argflags.CONST: constfailure = True
                else: mismatches += 1

            if constfailure:
                # Unusable: this must not be reported as a perfect match
                # just because nothing else mismatched.
                continue

            if not mismatches:
                # A perfect match!
                self.log.mid ('_emit perfect match')
                return (cgfunc, ())

            if matches > bestmatchcount:
                bestmatch = var
                bestmatchcount = matches

        if bestmatch is None:
            raise util.PyntoException ("Could not find matching instruction")

        # This was not a perfect match, so determine which arguments
        # must be substituted and what with.  The index advances for
        # every argument, not only for the ones that need wrapping.
        dsts, srcs, cgfunc = bestmatch

        wrappers = []
        argidx = 0
        for spec, arg in zip (srcs, osources):
            argidx += 1
            if not arg.is_assigned_to_type (spec):
                wrappers.append ( (argidx, _select_class (spec)) )

        argidx = 0
        for spec, arg in zip (dsts, odests):
            argidx -= 1
            if not arg.is_assigned_to_type (spec):
                wrappers.append ( (argidx, _select_class (spec)) )

        return (cgfunc, wrappers)

    def dump (self, out):

        """ Writes a description of all trees and all variant lists
        to out. """

        out.write ("Emit Table: {\n")
        out.write ("  Emit Tree: {\n")
        for key, tree in self.trees.items ():
            out.write ("    %s:\n" % key)
            tree.dump (out, 6)
        out.write ("  }\n")
        out.write ("  Emit Lists: {\n")
        for key, variants in self.nontrees.items ():
            out.write ("    %s:\n" % key)
            for dsts, srcs, cgfunc in variants:
                out.write ("      * Dsts: %s\n" %
                           [ lowir.argflags.string(x) for x in dsts ])
                out.write ("        Srcs: %s\n" %
                           [ lowir.argflags.string(x) for x in srcs ])
                out.write ("        CGF : %s\n" % cgfunc)
        out.write ("  }\n")
        out.write ("}\n")
