"""
ast_builder.py
"""

import sys

# ----------------------------------------------------------------------
# Exceptions

class AstBuilderException(Exception):
    """Base class for every error raised while building AST definitions."""

class InvalidClassMember(AstBuilderException):
    """A user class body contained a member of an unrecognized kind.

    * mem: the offending member name
    """

    def __init__(self, mem):
        # Forward to Exception so e.args is populated; without it repr()
        # is empty and unpickling fails (it re-calls cls(*args)).
        super(InvalidClassMember, self).__init__(mem)
        self.mem = mem

    def __str__(self):
        return "Class member '%s' of unrecognized type." % (self.mem,)

class NoInitializerSpecified(AstBuilderException):
    """A NonAst field was declared without any initializer."""

    def __str__(self):
        return ("Non-ast item has no initializer specified." +
                "  Use initial_value= or initial_expr=.")

class BothInitializersSpecified(AstBuilderException):
    """A NonAst field was given both initial_value= and initial_expr=."""

    def __str__(self):
        return ("Non-ast item has both initializers specified." +
                "  Use only one of initial_value= and initial_expr=.")

# ----------------------------------------------------------------------
# Representation of user's classes

class UserCodeObject(object):
    """Common base of the objects built from user definitions (enums,
    AST nodes, references and non-AST fields derive from it)."""

class EnumNode(UserCodeObject):
    """A user-declared enumeration.

    * name: name of the emitted enum
    * prefix: string written in front of every option when emitted
    * options: sequence of option names
    """

    def __init__(self, name, prefix, options):
        self.name, self.prefix, self.options = name, prefix, options

class AstNode(UserCodeObject):
    """One user-declared AST node class and its place in the hierarchy."""

    def __init__(self, name, generics, subclasses, nonasts, refs):
        """
        * name: name of AstNode
        * generics: a list of any Generic objects
        * subclasses: other AstNodes that are subclasses of this one
        * nonasts: list of fields of the class that are not references to
          managed types
        * refs: references to ast nodes, enumerations, or other types that
          we manage

        Each node in `subclasses` is adopted: its parclass is set to this
        node and its reference numbering (and its descendants') is shifted
        past this node's own references.
        """
        self.name = name
        self.generics = generics
        self.subclasses = subclasses
        self.nonasts = nonasts
        self.references = refs
        self.parclass = None  # filled in when a parent node adopts this one
        # Index of this node's first own reference in the flattened list of
        # references; grows as ancestors are attached above it.
        self.reference_start = 0

        own_ref_total = len(refs)
        for child in subclasses:
            child.parclass = self
            child.increment_reference_start(own_ref_total)

    def get_reference_count(self):
        """Number of references on this node, inherited ones included."""
        return len(self.references) + self.reference_start

    reference_count = property(get_reference_count)

    def increment_reference_start(self, amnt):
        """Shift the reference numbering of this node and all descendants."""
        pending = [self]
        while pending:
            current = pending.pop()
            current.reference_start += amnt
            pending.extend(current.subclasses)

def ast_meta(name, bases, dict):
    """Build an AstNode from a user class body (metaclass-style hook).

    * name: the class name, used as the node name
    * bases: ignored
    * dict: the class namespace; each entry is classified by its value:
      Reference -> reference field, NonAst -> plain field,
      AstNode -> nested subclass, Generic -> template parameter,
      key "__custom__" -> custom code (stored as str on the node's
      ``custom_code`` attribute). Other dunder entries supplied by the
      interpreter (__module__, __qualname__, __doc__, __metaclass__) are
      skipped.

    Raises InvalidClassMember for any other entry.
    """
    references = []
    nonasts = []
    subclasses = []
    generics = []
    custom_code = None

    # Definition order is kept (class namespaces are ordered on Python 3),
    # so reference indices follow the order the user wrote them in.
    for k, v in dict.items():
        if isinstance(v, Reference):
            references.append((k, v))
        elif isinstance(v, NonAst):
            nonasts.append((k, v))
        elif isinstance(v, AstNode):
            # AstNode.__init__ needs the node itself to adopt it.
            subclasses.append(v)
        elif isinstance(v, Generic):
            generics.append(v)
        elif k == "__custom__":
            custom_code = str(v)
        elif k.startswith("__") and k.endswith("__"):
            continue
        else:
            raise InvalidClassMember(k)

    node = AstNode(name, generics, subclasses, nonasts, references)
    node.custom_code = custom_code
    return node

def enum_meta(name, bases, dict):
    """Build an EnumNode from a user class body (metaclass-style hook).

    * name: the enum name
    * bases: ignored
    * dict: class namespace; must contain "options" (the option names) and
      may contain "prefix" (defaults to "")

    Raises KeyError if "options" is missing.
    """
    pre = dict.get("prefix", "")
    options = dict["options"]  # the key, not a variable
    return EnumNode(name, pre, options)

class Reference(UserCodeObject):
    """A field that refers to another managed type by name.

    Calling the reference records a specialization argument and returns
    the same object, so calls can be chained: ``r("A")("B")``.
    """

    def __init__(self, name):
        self.name = name
        # Empty tuple until the first specialization, then a list.
        self.specialized_by = ()

    def __call__(self, specialize):
        if self.specialized_by:
            self.specialized_by.append(specialize)
        else:
            self.specialized_by = [specialize]
        return self

class ReferenceBuilder(object):
    """Attribute factory: ``builder.Foo`` returns a fresh Reference("Foo").

    Dunder names raise AttributeError instead of returning a Reference.
    Otherwise protocol probes made by hasattr(), copy or pickle (for
    example ``__deepcopy__``) would receive bogus Reference objects.
    """

    def __getattr__(self, attrname):
        if attrname.startswith("__") and attrname.endswith("__"):
            raise AttributeError(attrname)
        return Reference(attrname)

class Generic(object):
    """A template type parameter of an AST node; its ``name`` is written
    as ``class <name>`` in the emitted template header."""

    def __init__(self, name):
        self.name = name

class NonAst(UserCodeObject):
    """A plain (non-AST) field of an AST node.

    * type_: type of the field, written verbatim in its declaration
    * initial_value: value assigned to the field in the constructor
    * initial_expr: statement written verbatim in the constructor instead

    Exactly one initializer must be supplied. ``None`` means "not given",
    so falsy values such as 0, False or "" are accepted as initializers.

    Raises NoInitializerSpecified if neither is given and
    BothInitializersSpecified if both are.
    """

    def __init__(self, type_, initial_value=None, initial_expr=None):
        self.type = type_

        has_value = initial_value is not None
        has_expr = initial_expr is not None
        if not has_value and not has_expr:
            raise NoInitializerSpecified()
        if has_value and has_expr:
            raise BothInitializersSpecified()

        self.initial_value = initial_value
        self.initial_expr = initial_expr

# ----------------------------------------------------------------------
# C++ output mode
#

# C++ members appended to the body of the AST base class.  It is filled in
# with %-formatting using the keys class_enum_name, class_iter_name and
# ast_node_name.
# NOTE(review): generated accessors cast _ast_gen_children[i] to a pointer
# and store pointers into it, which suggests the children should be held
# as %(ast_node_name)s ** rather than %(ast_node_name)s * -- confirm.
kBaseClassExtra = \
"""
    public:
        %(class_enum_name)s class_id() {
            return _ast_class_id;
        }
        
        %(class_iter_name)s<%(ast_node_name)s> children() {
            return %(class_iter_name)s<%(ast_node_name)s>(_ast_num_children,
                                                          _ast_gen_children);
        }
        
    protected:
        void init(%(class_enum_name)s class_id,
                  %(ast_node_name)s *ast_children,
                  int num_children)
        {
            _ast_class_id = class_id;
            _ast_gen_children = ast_children;
            _ast_num_children = num_children;
        }
        %(class_enum_name)s _ast_class_id;
        %(ast_node_name)s *_ast_gen_children;
        int _ast_num_children;
"""

class CppDump:
    """Emit C++ declarations for a tree of AstNode objects plus EnumNodes.

    Output goes to ``stream`` (any object with a ``write(str)`` method).
    Indentation is tracked automatically. A line whose text ends with '{'
    indents the lines after it by four spaces. A line whose text starts
    with '}' is itself dedented by four spaces.
    """

    def __init__(self, stream, astbaseclass, enums):
        """
        * stream: output object with a write(str) method
        * astbaseclass: root AstNode; it and every node reachable through
          .subclasses are emitted
        * enums: iterable of EnumNode objects to emit
        """
        self.stream = stream
        self.indent = 0
        self.newline = True
        self.enums = enums
        self.astnodes = []

        # Pre-order tree walk: a parent always precedes its subclasses, as
        # C++ requires for "struct Child : public Parent".
        def add_ast_nodes(node):
            self.astnodes.append(node)
            for n in node.subclasses:
                add_ast_nodes(n)
        add_ast_nodes(astbaseclass)

        # Remember the base class
        self.ast_base_class = astbaseclass

        # Various configuration constants
        self.class_enum_name = "t_ast_node_types"
        self.max_ast_name = "k_ast_max"
        self.class_iter_name = "t_ast_iter"
        # Substituted for %(ast_node_name)s in kBaseClassExtra.
        self.ast_node_name = astbaseclass.name

    def dump(self):
        """Write the class-id enum, the user enums, then all AST structs."""
        self.dump_class_constants()
        self.dump_enums()
        self.dump_class_decls()

    def _write(self, string):
        """Write one line fragment, honouring the current indent level.

        ``string`` may contain '\\n' only as its last character. The indent
        is written only at the start of a non-blank line. A fragment whose
        text starts with '}' is dedented before it is written, so the
        closing brace lines up with its opener. A fragment whose text ends
        with '{' indents everything written after it.
        """
        assert '\n' not in string[:-1]
        if not string:
            return
        # Strip so braces are recognized even in pre-indented text such as
        # kBaseClassExtra.
        text = string.strip()

        if text.startswith('}'):
            self.indent = max(0, self.indent - 4)
        if self.newline and text:
            self.stream.write(' ' * self.indent)
        self.stream.write(string)
        self.newline = string.endswith('\n')
        if text.endswith('{'):
            self.indent += 4

    def write(self, string):
        """Write arbitrary text, feeding it to _write one line at a time."""
        while string:
            line, sep, string = string.partition('\n')
            self._write(line + sep)

    def dump_class_constants(self):
        """Emit the class-id enum: one constant per concrete (leaf) class,
        followed by the max marker."""
        self.write('enum %s {\n' % self.class_enum_name)
        for astcl in self.astnodes:
            if self._is_abstract(astcl):
                continue
            self.write('%s,\n' % self._get_class_constant(astcl.name))
        self.write('%s\n' % self.max_ast_name)
        self.write('};\n\n')

    def dump_enums(self):
        """Emit each user enum, writing every option with its prefix."""
        for en in self.enums:
            self.write('enum %s {\n' % en.name)
            for op in en.options:
                self.write('%s%s,\n' % (en.prefix, op))
            self.write('};\n\n')

    def dump_class_decls(self):
        """Forward-declare every AST struct, then emit each full declaration.

        The forward declarations let accessors name node types whose full
        declaration appears later in the output.
        """
        for node in self.astnodes:
            self._write_template_header(node)
            self.write('struct %s;\n' % node.name)
        self.write('\n')
        for node in self.astnodes:
            self.dump_AstNode(node)

    def dump_AstNode(self, node):
        """Emit the C++ struct declaration for one AST node."""
        self._write_template_header(node)
        self.write('struct %s' % node.name)
        if node.parclass:
            self.write(' : public %s' % node.parclass.name)
        self.write('\n{\n')

        # Emit the constructor (todo, allow custom code in here?):
        self.write('%s()\n{\n' % node.name)

        ## Concrete (leaf) classes register their class id and children
        ## array with the base class.
        if not self._is_abstract(node):
            self.write('init(%s, ' % self._get_class_constant(node.name))
            if self._has_children_array(node):
                self.write('_ast_my_children, %d);\n' % node.reference_count)
            else:
                self.write('NULL, 0);\n')

        ## Emit initialization for any non-ast fields.  None means "no
        ## value", so falsy values such as 0 are still emitted.
        for name, nonast in node.nonasts:
            if nonast.initial_value is not None:
                self.write('%s = %s;\n' % (name, nonast.initial_value))
            else:
                self.write('%s;\n' % nonast.initial_expr)

        ## References are not assigned one by one. Each reference name is an
        ## accessor member function. In a non-leaf constructor
        ## _ast_gen_children has not been set by init() yet. The memset
        ## below clears the references of concrete classes.
        if self._has_children_array(node):
            self.write(
                'memset(_ast_my_children, 0, sizeof _ast_my_children);\n')

        ## End the constructor
        self.write('}\n\n')

        # Emit any custom code attached to the node as ``custom_code``.
        # getattr because nodes are not required to carry the attribute.
        custom = getattr(node, 'custom_code', None)
        if custom:
            self.write(custom if custom.endswith('\n') else custom + '\n')

        # Emit any non-AST field declarations
        for name, nonast in node.nonasts:
            self.write('%s %s;\n' % (nonast.type, name))
        self.write('\n')

        # Emit a getter/setter pair for each reference.  Indices continue
        # after the references inherited from ancestor classes.
        for offset, (name, ref) in enumerate(node.references):
            constant = 'k_%s' % name
            reftype = self._get_reference_type(ref)
            self.write('static const int %s = %d;\n'
                       % (constant, node.reference_start + offset))
            self.write('%s %s() {\n' % (reftype, name))
            self.write('return (%s)_ast_gen_children[%s];\n'
                       % (reftype, constant))
            self.write('}\n')
            self.write('void %s(%s value) {\n' % (name, reftype))
            self.write('_ast_gen_children[%s] = value;\n' % constant)
            self.write('}\n')

        # If this is the base class, insert special base class stuff
        if node is self.ast_base_class:
            self.write(kBaseClassExtra % {
                "class_enum_name": self.class_enum_name,
                "class_iter_name": self.class_iter_name,
                "ast_node_name": self.ast_node_name})

        # Declare the children array, if applicable.  It is only referenced
        # in the constructor; the base class reaches it through init().
        # NOTE(review): the accessors above treat its elements as pointers,
        # so this probably wants to be an array of base-class pointers --
        # confirm together with the init() signature in kBaseClassExtra.
        if self._has_children_array(node):
            self.write('private:\n')
            self.write('%s _ast_my_children[%d];\n' % (
                self.ast_base_class.name, node.reference_count))

        self.write('};\n\n')

    def _write_template_header(self, node):
        """Emit "template <class A, class B>" for a generic node."""
        if node.generics:
            self.write('template <%s>\n' % ', '.join(
                'class ' + g.name for g in node.generics))

    def _is_abstract(self, node):
        """A node with subclasses is abstract. Only leaves get a class-id
        constant and call init()."""
        return bool(node.subclasses)

    def _has_children_array(self, node):
        """Concrete nodes with at least one reference own a children array."""
        return not self._is_abstract(node) and node.reference_count > 0

    def _get_reference_type(self, ref):
        """C++ pointer type for a reference, with template arguments when the
        reference is specialized.  Arguments may be strings or objects with
        a ``name`` attribute (e.g. Generic)."""
        if ref.specialized_by:
            args = ",".join(str(getattr(a, "name", a))
                            for a in ref.specialized_by)
            return "%s<%s>*" % (ref.name, args)
        return "%s*" % ref.name

    def _get_class_constant(self, name):
        """Class-id constant name: 't_foo' -> 'k_foo', 'Foo' -> 'kFoo'."""
        if name.startswith('t_'):
            return 'k_' + name[2:]
        return 'k' + name

# ----------------------------------------------------------------------
# Public Interface

__all__ = ("ast", "enum", "ref", "non_ast", "nonast", "generic", "dump_cpp")

# Factory for reference fields: ``ref.Expr`` creates Reference("Expr").
ref = ReferenceBuilder()

def _user_members(namespace):
    """Copy a class namespace without interpreter-supplied dunder entries
    (__module__, __qualname__, __doc__, ...); "__custom__" is kept because
    it is a user member."""
    return dict((k, v) for k, v in namespace.items()
                if k == "__custom__"
                or not (k.startswith("__") and k.endswith("__")))

class _AstBaseMeta(type):
    """Metaclass of the ``ast`` marker base.

    Creating ``ast`` itself yields an ordinary class. Any class deriving
    from it is passed to ast_meta and replaced by the result.
    """

    def __new__(mcs, name, bases, namespace):
        if any(isinstance(base, _AstBaseMeta) for base in bases):
            return ast_meta(name, bases, _user_members(namespace))
        return type.__new__(mcs, name, bases, namespace)

class _EnumBaseMeta(type):
    """Metaclass of the ``enum`` marker base.

    Creating ``enum`` itself yields an ordinary class. Any class deriving
    from it is passed to enum_meta and replaced by the result.
    """

    def __new__(mcs, name, bases, namespace):
        if any(isinstance(base, _EnumBaseMeta) for base in bases):
            return enum_meta(name, bases, _user_members(namespace))
        return type.__new__(mcs, name, bases, namespace)

# The marker bases are created by calling the metaclasses directly. This
# behaves the same on Python 2 (where a __metaclass__ attribute would also
# work) and on Python 3 (which ignores __metaclass__).
ast = _AstBaseMeta("ast", (object,), {})
enum = _EnumBaseMeta("enum", (object,), {})

# __all__ advertises non_ast; keep the historical nonast name too.
non_ast = nonast = NonAst

generic = Generic

def dump_cpp(stream, *defns):
    """Write C++ declarations for the given definitions to ``stream``.

    * stream: object with a write(str) method; None means sys.stdout
    * defns: AstNode and EnumNode objects (anything else is ignored).
      Exactly one distinct AstNode must be a root (parclass is None).
      It and every node reachable through .subclasses are emitted.

    Raises AstBuilderException unless exactly one root AstNode is given.
    """
    if stream is None:
        stream = sys.stdout
    asts = [ty for ty in defns if isinstance(ty, AstNode)]
    enums = [ty for ty in defns if isinstance(ty, EnumNode)]

    # Identity-based de-duplication: the same root may be passed twice.
    roots = []
    for node in asts:
        if node.parclass is None and not any(node is r for r in roots):
            roots.append(node)
    if len(roots) != 1:
        raise AstBuilderException(
            "dump_cpp needs exactly one root AST node, got %d" % len(roots))

    CppDump(stream, roots[0], enums).dump()
