"""

knodes.py
=========

Contains the definition of the knodes.  The knodes are so called because
they take part in the knowledge tree, or ktree for short.  The ktree encodes
everything that we know --- or think we know! --- about the running
python environment.  This knowledge allows us to handle attribute and globals
lookups and convert them into more efficient operations.

Each knode has a set of children; these are always the children that
would be obtained by doing an attribute lookup on the object that the knode
represents.  So, for example, if a particular knode ``a`` represents a
module, then any children that ``a`` might have represent global variables
in that module.

Knowledge on the knowledge tree is not permanent; if, for example, module
``a`` has a binding that says that ``g`` is a constant function, but
``g`` is overwritten, then the ktree has changed.  To prevent this from
fouling us up, each edge on the ktree also tracks which pieces of jitted
code depend on the assumption represented by that edge (that something
with a particular label is of the type pointed at by the edge) so they can
be thrown out when the assumption changes.

There are five kinds of knodes, each with different properties in addition
to children:

ModuleKNode
-----------

Corresponds to a loaded module.  Has no extra attributes, only the set
of children (or global vars).

ClassKNode
----------

Corresponds to a class object.  In addition to the normal children,
which are attributes of the class object (and hence all instances, at
least at first), it also contains a list of children that correspond
to attributes we have found on instances of this class.

FunctionKNode
-------------

Corresponds to a function or unbound method object.  In addition to
the normal children, contains a standard python CodeObject and a set
of jitted code.  Also contains a list of knodes for the function
parameters, including any free parameters.  In the future it will
probably contain default parameter values or some such nonsense.

ConstantKNode
-------------

Contains a pointer to a constant operand.  See operands.py.  Used when
a name maps to a constant.

InstanceKNode
-------------

Contains a list of class knodes.  These indicate possible types for
the object represented by this knode.  Also has a bit indicating
whether other types are possible.

"""

import types, sys
import util, values, stacksim, typeset
from jitted import JittedVersion

class KEdge:

    """ A labelled edge of the ktree.

    Attributes (accessed directly, there are no accessors):

    label      -- the name under which the target is reachable
    node       -- the knode at the tail of the edge
    dependents -- the things that relied on this edge's assumption and
                  must be updated if the edge is ever removed; starts
                  out as an empty tuple and becomes a list as soon as
                  the first dependent is recorded """

    def __init__(self, label, node):
        self.label = label
        self.node = node
        # An empty tuple is used as the "no dependents yet" placeholder;
        # add_dependent() swaps it for a real list on first use.
        self.dependents = ()

    def add_dependent(self, dep):

        """ Records ``dep`` as depending on this edge. """

        try:
            append = self.dependents.append
        except AttributeError:
            # Still the tuple placeholder (it has no append): start
            # the list now.
            self.dependents = [dep]
        else:
            append(dep)

    def clone(self):

        """ Returns a new KEdge with the same label whose node is a
        clone of our node.  The dependents are not carried over. """

        node_copy = self.node.clone()
        return KEdge(self.label, node_copy)

class KNode:

    """ Base class shared by every knode.  It stores the parent link
    and the label -> KEdge mapping of children, and supplies default
    answers for the type-inspection and value queries that subclasses
    override. """

    def __init__(self, parent):
        self.parent = parent
        # The child dictionary is only allocated once a child is added.
        self.children = None

    # Type inspection: every predicate is False here; each subclass
    # overrides the one that applies to it.

    def is_constant(self):
        return False

    def is_instance(self):
        return False

    def is_module(self):
        return False

    def is_class(self):
        return False

    def is_function(self):
        return False

    def get_type_name(self):

        """ Returns the name of this knode's class as a string. """

        return self.__class__.__name__

    def get_parent(self):

        """ Returns the parent knode (None for the root of the tree). """

        return self.parent

    def get_child(self, label):

        """ Returns the KEdge stored under ``label``, or None when no
        such child exists. """

        kids = self.children
        if kids:
            return kids.get(label)
        return None

    def get_children(self):

        """ Returns all child KEdges; an empty tuple when there are
        none. """

        kids = self.children
        if kids:
            return kids.values()
        return ()

    def add_child(self, edge):

        """ Stores ``edge`` under its label, replacing any edge that
        already uses that label. """

        assert isinstance(edge, KEdge)
        if self.children:
            self.children[edge.label] = edge
        else:
            self.children = {edge.label: edge}

    def get_root_node(self):

        """ Follows the parent links upwards and returns the topmost
        node. """

        current = self
        parent = current.parent
        while parent:
            current = parent
            parent = current.parent
        return current

    def get_module_node(self):

        """ Returns the nearest node, starting with ourselves and
        walking up the parents, for which is_module() is true; None if
        the walk runs out of parents first. """

        current = self
        while True:
            if current.is_module():
                return current
            current = current.parent
            if not current:
                return None

    def get_value(self):

        """ Returns the constant Value object that stands for this
        knode when it is loaded, or None when there is no such
        constant.  The base implementation always returns None. """

        return None

    def get_types(self):

        """ Returns the TypeSet reported by our constant value (see
        get_value), or None when no value is available. """

        value = self.get_value()
        if not value:
            return None
        return value.get_types()

class RootKNode(KNode):

    """ The root of the ktree.  Besides the usual children it owns one
    ClassKNode per type found in the ``types`` module, used to hand out
    typed constants, and a dedicated ClassKNode standing for closures. """

    def __init__(self):

        KNode.__init__(self, None)

        # Maps each built-in type object to its ClassKNode; consulted
        # by get_builtin_type() and get_typed_constant().
        self.typenodes = {}

        # Walk everything exported by the types module and register
        # each value that is itself a type, except the ignored ones.
        skipped = (types.InstanceType,)
        for candidate in types.__dict__.values():
            if type(candidate) is not types.TypeType:
                continue
            if candidate in skipped:
                continue
            self.typenodes[candidate] = ClassKNode(
                None, candidate.__name__, True)

        # Python has several flavours of closure (functions, bound
        # instance methods, ...) and no entry in the types module that
        # covers them all, so we invent a single "Closure" type node.
        # Every closure gets this type; finding the underlying function
        # and its implied arguments requires looking at the instruction
        # that created the closure.
        self.closure_type = ClassKNode(None, "Closure", True)

    def get_closure_type(self):

        """ Returns the ClassKNode used as the type of all closures. """

        return self.closure_type

    def get_builtin_type(self, type):

        """ Returns the ClassKNode registered for the built-in type
        object ``type``; raises KeyError if it was never registered. """

        return self.typenodes[type]

    def get_builtin_types(self):

        """ Returns the ClassKNodes of all registered built-in types. """

        return self.typenodes.values()

    def get_typed_constant(self, type, constant):

        """ Asks the ClassKNode registered for ``type`` for the constant
        that corresponds to ``constant``. """

        typenode = self.typenodes[type]
        return typenode.get_constant(constant)

class ModuleKNode(KNode):

    """ A knode standing for a module.  It stores nothing beyond what
    the base KNode already provides. """

    def __init__(self, root):
        KNode.__init__(self, root)

    def is_module(self):
        return True

    def get_value(self):

        """ Returns the module-typed constant that the root node hands
        out for this knode. """

        root = self.get_root_node()
        return root.get_typed_constant(types.ModuleType, self)

class ClassKNode(KNode):

    """ A knode standing for a class object or another type object,
    such as one of the built-in types. """

    def __init__(self, parent, name, builtin):
        KNode.__init__(self, parent)
        self.name = name
        self.builtin = builtin
        # Attributes found on instances only: label -> KEdge, laid out
        # just like the children dictionary.
        self.instanceonly = {}
        # Cache of constants handed out by get_constant(); allocated
        # on first use.
        self.constants = None

    def is_class(self):
        return True

    def is_linear_subtype(self, otherklass):

        """ Meant to report whether self is a subclass of otherklass
        reached purely through single inheritance.  For now only the
        identity case is recognised. """

        return self is otherklass # insufficient

    def is_subtype(self, otherklass):

        """ Meant to report whether self is a subclass of otherklass by
        any route (multiple inheritance included).  For now only the
        identity case is recognised. """

        return self is otherklass # not sufficient

    def is_builtin(self):
        return self.builtin

    def get_name(self):
        return self.name

    def lookup(self, attrname):

        """ Simulates an attribute lookup on the ktree.  Instance-only
        attributes are consulted first, then the class-wide children.
        Returns a KEdge, or None if neither knows the name. """

        found = self.get_instance_attr(attrname)
        if found:
            return found

        found = self.get_child(attrname)
        if not found or not found.node.is_function():
            return found

        # Loading a function through a class does not produce the
        # function itself but a closure carrying the self pointer, so
        # answer with a temporary knode whose type is the closure type.
        closure = InstanceKNode(self)
        hints = closure.get_types()
        hints.add_hint(typeset.EXACT,
                       self.get_root_node().get_closure_type())
        return KEdge(found.label, closure)

    def get_instance_attrs(self):

        """ Returns the KEdges of all instance-only attributes. """

        return self.instanceonly.values()

    def add_instance_attr(self, edge):

        """ Registers ``edge`` as an instance-only attribute, replacing
        any previous edge with the same label. """

        self.instanceonly[edge.label] = edge

    def get_instance_attr(self, name):

        """ Returns the instance-only KEdge called ``name``, or None. """

        return self.instanceonly.get(name)

    def get_constant(self, val):

        """ Returns the ConstantValue of this type for ``val``.  The
        results are cached in a dictionary keyed by ``val``, so asking
        again with a key the dictionary treats as the same yields the
        very same object.  Mostly useful when this ClassKNode stands
        for a built-in type. """

        cache = self.constants
        if cache:
            try:
                return cache[val]
            except KeyError:
                pass

        sym = values.ConstantValue(self, val)
        if self.constants:
            self.constants[val] = sym
        else:
            self.constants = {val: sym}
        return sym

    def get_value(self):

        """ Returns the class-typed constant that the root node hands
        out for this knode. """

        root = self.get_root_node()
        return root.get_typed_constant(types.ClassType, self)

    def __str__(self):
        return "class<%x,%s>" % (id(self), self.name)

    def __repr__(self):
        return "ClassKnode<%s>" % self.name

class FunctionKNode (KNode):

    """ FunctionKNodes represent functions.

    On top of the children inherited from KNode, each node keeps:

    codeobject     -- the code object the node was created from
    jittedversions -- the JittedVersion objects appended by build()
    parameters     -- one KEdge per positional parameter, each pointing
                      at a fresh InstanceKNode
    subfunctions   -- (codeobject, knode) pairs for nested functions
    allknown       -- flag set through set_all_known()
    returntype     -- TypeSet describing the return value (also bound
                      as ``returntypes``) """

    def __init__ (self, parent, codeobj):
        KNode.__init__ (self, parent)
        self.codeobject = codeobj
        self.jittedversions = []
        self.subfunctions = []
        self.allknown = False

        # The attribute was historically created as ``returntype``
        # while get_return_types() read ``returntypes``, which made the
        # getter raise AttributeError.  Both names are bound to the
        # same TypeSet so either spelling works.
        self.returntype = typeset.TypeSet ()
        self.returntypes = self.returntype

        # The first co_argcount entries of co_varnames are the names
        # of the positional parameters.
        argnames = codeobj.co_varnames[:codeobj.co_argcount]
        self.parameters = [KEdge (name, InstanceKNode (self))
                           for name in argnames]

    def is_function (self):
        return True

    def get_parameters (self):

        """ Returns the list of KEdge objects holding what we know
        about the types of the parameters. """

        return self.parameters

    def get_all_known (self):

        """ Returns the flag last stored by set_all_known(). """

        return self.allknown

    def set_all_known (self, allknown):

        """ Indicates whether we know about all calls to this function;
        not true for functions accessible from non-jitted code.  The
        setting is forwarded to every parameter node. """

        self.allknown = allknown
        # NOTE(review): set_all_known is not defined on InstanceKNode
        # in this file; presumably it is attached by
        # util.add_external_methods -- confirm.
        for edge in self.parameters:
            edge.node.set_all_known (allknown)

    def get_return_types (self):

        """ Returns the TypeSet representing what we know about the
        return value of this function. """

        return self.returntypes

    def get_code_object (self):
        return self.codeobject

    def get_jitted_versions (self):
        return self.jittedversions

    def external_jitted_version (self):

        """ Returns the jitted version which makes no assumptions about
        its parameters, or None if there is none.  This is the version
        which can be invoked by outsiders (non-JITted code). """

        for jitted in self.jittedversions:
            for argtype in jitted.get_param_assumptions ():
                if not argtype.node.get_types ().know_nothing ():
                    break
            else:
                # No assumption claimed to know anything.
                return jitted
        return None

    def select_jitted_version (self, comp, instr):

        """ Given a CallInstruction which is known to call this
        function, selects a JittedVersion to call.  A version is a
        candidate when the type set of every supplied argument is a
        subset of the version's assumption for that parameter.  Among
        the candidates, the one with the fewest parameters whose
        assumed type set is flagged unknown wins; ties go to the
        version found first.

        Raises util.PyntoException when no version is a candidate. """

        # Construct a list of the types for each of the actual arguments
        posargtypes = [param.value.get_types ()
                       for param in instr.get_named (instr.ARGS)]

        bestmatch = None
        bestmatchcnt = 0

        # NOTE(review): the results of log.indent are held in locals
        # although they are never read; presumably the indentation
        # lasts as long as these objects do -- confirm before removing.
        ind = comp.log.indent ('select_jitted_version instr=%s', instr)

        for jitted in self.jittedversions:

            ind2 = comp.log.indent ('jitted')

            # Walk through the assumptions and parameters in lock step.
            matchcnt = 0

            for argtype, assump in zip (posargtypes,
                                        jitted.get_param_assumptions ()):
                formaltypes = assump.node.get_types ()

                comp.log.low ('argtype=%s types=%s (node=%s)', argtype,
                              formaltypes, assump.node)

                # If any argtype is not a subset, this jitted is not
                # even a possibility, so just break out of the
                # comparison loop
                if not argtype.is_subset_of (formaltypes):
                    comp.log.low ('Actual argument not a subset jitted arg')
                    break

                # Otherwise rank how specific this version is by
                # counting its unknown formals.  The count has to come
                # from the formal assumption: the actual argument is
                # the same for every version, so counting it could
                # never tell two candidates apart.
                if formaltypes.get_unknown ():
                    matchcnt += 1
                    comp.log.low ('Formal argument unknown (matchcnt=%d)',
                                  matchcnt)
            else:
                # If we get here, we never found any arguments that
                # didn't match.
                comp.log.low ("Found possibly jitted w/ matchcnt=%d",
                              matchcnt)
                if bestmatch is None or matchcnt < bestmatchcnt:
                    comp.log.mid ('Best match w/ count=%d', matchcnt)
                    bestmatch = jitted
                    bestmatchcnt = matchcnt

        # This should not be possible because we should always generate a
        # non-optimized boring version
        if bestmatch is None:
            raise util.PyntoException ("No matching Jitted Version found")

        return bestmatch

    def get_sub_functions (self):

        """ Returns knodes for all sub functions defined by this
        function. """

        return [entry[1] for entry in self.subfunctions]

    def get_sub_function (self, codeobj):

        """ Returns the sub-function node corresponding to the given
        codeobject (compared by identity); returns None if none is
        found. """

        for co, knode in self.subfunctions:
            if co is codeobj:
                return knode
        return None

    def get_value (self):

        """ Returns the code-typed constant that the root node hands
        out for this knode. """

        # we'll have to figure out the codeobject vs function vs
        # method thing...
        return self.get_root_node ().get_typed_constant (types.CodeType, self)

    def build (self, comp, paramtypes):

        """ Called when a function knode is first created.  Builds the
        rudimentary IR for our code object, wraps it in a JittedVersion
        created with ``paramtypes`` and appends that version to
        jittedversions.  Returns nothing. """

        start, stop = stacksim.build_instructions (comp,
                                                   self.get_code_object ())
        result = JittedVersion (
            comp.log, self.get_name (), paramtypes, start, stop)
        self.jittedversions.append (result)

    def get_name (self):

        """ Returns the co_name of our code object. """

        return self.codeobject.co_name

    def __str__ (self):
        return "<FunctionKNode: %s>" % (self.get_name ())

    def __repr__ (self):
        return "<FunctionKNode: %s>" % (self.get_name ())

class ConstantKNode(KNode):

    """ A knode wrapping a constant.  It has no parent, and the wrapped
    constant is what get_value() returns. """

    def __init__(self, const):
        KNode.__init__(self, None)
        self.constant = const

    def is_constant(self):
        return True

    def get_value(self):

        """ Returns the constant given to the constructor. """

        return self.constant

    def clone(self):

        """ Returns this very node; no copy is made. """

        return self

class InstanceKNode(KNode):

    """ A knode standing for some object whose possible types are
    tracked in a TypeSet. """

    def __init__(self, parent):
        KNode.__init__(self, parent)
        # Starts out as a freshly constructed TypeSet.
        self.types = typeset.TypeSet()

    def is_instance(self):
        return True

    def get_types(self):

        """ Returns the TypeSet owned by this node. """

        return self.types

    def clone(self):

        """ Returns a new InstanceKNode with the same parent whose
        TypeSet has been filled from ours via copy_from(). """

        twin = InstanceKNode(self.get_parent())
        twin.types.copy_from(self.types)
        return twin

# ----------------------------------------------------------------------
# Additional Methods
#
# Further methods for the knode classes live in separate modules and
# are attached here, once all the classes above have been defined.
# NOTE(review): add_external_methods is defined in util and not visible
# here; presumably it copies functions from the given module onto the
# KNode subclasses found in ``classes``, selected by the string tag --
# confirm against util.

# Collect the module's global values *before* the imports below run.
# NOTE(review): this file targets Python 2 (it uses types.InstanceType),
# where values() returns a list, i.e. a snapshot that will not include
# compiler_methods / dump_methods -- confirm this is intended before
# reordering.
classes = globals().values()

# These imports are deliberately placed at the bottom of the file,
# after the class definitions and the snapshot above.
import compiler_methods
util.add_external_methods (KNode, classes, "basic", compiler_methods)
util.add_external_methods (KNode, classes, "refine", compiler_methods)
util.add_external_methods (KNode, classes, "codegen", compiler_methods)
import dump_methods
util.add_external_methods (KNode, classes, "dump", dump_methods)
