import re, sys
import astast as ast
#import scfparse
#from scfparse import Grammar, ParseFailed, pretty_print_error

#scfparse.debug_print(globals())

class NoNonterminal(Exception):
    """Intended to signal a ``|`` production line with no non-terminal
    defined before it; ``grparse`` reports it as "no non-terminal name
    found"."""

class InvalidLine(Exception):
    """Signals an indented grammar line that does not start with ``|``,
    i.e. one that neither defines nor extends a non-terminal."""

class MissingQuote(Exception):
    """Signals a quoted item whose closing quote character is missing.

    u_c -- the quote character that was opened but never closed; it is
           kept on the instance as ``u_c`` for error reporting.
    """

    def __init__(self, u_c):
        # Forward the character to Exception so that ``args``/``str()`` are
        # populated and the instance can be pickled or copied.
        super(MissingQuote, self).__init__(u_c)
        self.u_c = u_c

class UnexpectedEndOfString(Exception):
    """Signals that the production parser needed another item but the
    input text was already exhausted."""

def val(x):
    """Build a one-argument function that ignores its argument and
    always returns *x*."""
    def constant(_):
        return x
    return constant

# Characters that end an identifier token in ProdParser.any(): the ast
# module's punctuation and reserved characters, plus whitespace.
inv_token_chars = ast.punctuation+ast.reserved_chars+u" \r\n\t"

def stripcomments(file):
    """ Reads every line of file (newlines kept) and returns them as a
    list, with each line whose very first character is # replaced by a
    bare newline.  Only whole-line comments are recognised, which
    leaves # free for use inside type judgements and elsewhere; blanking
    rather than dropping the line keeps line numbers intact."""
    return ["\n" if text[0] == '#' else text for text in file.readlines()]

class ProdParser(object):
    """Recursive-descent parser turning production text into ast items.

    The unconsumed input is held in ``self.u_string``; every method
    consumes from the front of it.  The methods nest, loosest binding
    first: ``sub`` (``_`` subscript), ``sup`` (``^`` superscript), ``ap``
    (trailing primes), ``rep`` (``{...}`` repetition) and ``any`` (one
    primitive item).
    """

    def __init__(self, u_string):
        """Store the text to parse and drop any leading whitespace."""
        self.u_string = u_string
        self.skip_space()

    def skip_while(self, start_idx, pred):
        """Consume a prefix of the input and return it.

        The first ``start_idx`` characters are taken unconditionally;
        after that, characters are taken for as long as ``pred(char)`` is
        true.  The returned text includes the unconditional part.
        """
        idx = start_idx
        while idx < len(self.u_string) and pred(self.u_string[idx]):
            idx += 1
        u_skipped = self.u_string[:idx]
        self.u_string = self.u_string[idx:]
        return u_skipped

    def skip_space(self, start_idx=0):
        """Discard whitespace at the front of the input."""
        self.skip_while(start_idx, lambda c: c.isspace())

    def try_(self, u_text):
        """Consume ``u_text`` if the input starts with it.

        Returns True when the text was consumed, False (input untouched)
        otherwise.
        """
        if self.u_string.startswith(u_text):
            self.u_string = self.u_string[len(u_text):]
            return True
        return False

    def grouped(self, func):
        """Parse either a parenthesised group or a single ``func()`` result.

        A ``(`` starts a group whose items are collected up to the
        matching ``)`` and wrapped in an ``ast.Production``.
        """
        if self.try_(u"("):
            items = []
            while not self.try_(u")"):
                items.append(self.item())
            return ast.Production(items)
        return func()

    def prod(self):
        """Parse the whole remaining input into an ``ast.Production``."""
        items = []
        while self.u_string:
            item = self.sub()
            items.append(item)
        return ast.Production(items)

    def item(self):
        """Parse one complete item (currently the same as ``sub``)."""
        return self.sub()

    def sub(self):
        """Parse an item with an optional ``_`` subscript."""
        sup = self.sup()
        if not self.try_(u"_"):
            return sup
        idx = self.grouped(self.sup)
        return ast.ItemSubscript((sup, idx))

    def sup(self):
        """Parse an item with an optional ``^`` superscript."""
        sup = self.ap()
        if not self.try_(u"^"):
            return sup
        idx = self.grouped(self.ap)
        return ast.ItemSupscript((sup, idx))

    def ap(self):
        """Parse an item followed by zero or more ``'`` prime marks."""
        res = self.rep()
        u_primes = self.skip_while(0, lambda c: c == u"'")
        if u_primes:
            res = ast.ItemPrime((res, u_primes))
        return res

    def rep(self):
        """Parse a ``{...}`` repetition, or fall through to ``any``."""
        if self.try_(u"{"):
            items = []
            while not self.try_(u"}"):
                items.append(self.item())
            return ast.RepeatedItem(items)

        return self.any()

    def any(self):
        """Parse one primitive item from the front of the input.

        Tried in order: a backslash-escaped character, a run of
        whitespace, a punctuation character, a ``"keyword"``, a
        `` `literal` ``, and finally an identifier.

        Raises UnexpectedEndOfString when no input is left (including
        right after a backslash) and MissingQuote when a quoted item is
        not closed.
        """
        if not self.u_string:
            raise UnexpectedEndOfString()

        if self.try_(u"\\"):
            # A backslash must escape something; a trailing one used to
            # escape as an IndexError instead of a reportable parse error.
            if not self.u_string:
                raise UnexpectedEndOfString()
            u_char = self.u_string[0]
            self.u_string = self.u_string[1:]
            return ast.ItemPunctuation(u_char)

        if self.u_string[0].isspace():
            # Any run of whitespace collapses into a single space item.
            self.skip_space()
            return ast.ItemSpace()

        if self.u_string[0] in ast.punctuation:
            u_char = self.u_string[0]
            self.u_string = self.u_string[1:]
            return ast.ItemPunctuation(u_char)

        if self.u_string[0] == u'"':
            u_text, u_sep, u_remainder = self.u_string[1:].partition(u'"')
            if not u_sep:
                raise MissingQuote(u'"')
            self.u_string = u_remainder
            return ast.ItemKeyword(u_text)

        if self.u_string[0] == u'`':
            u_text, u_sep, u_remainder = self.u_string[1:].partition(u'`')
            if not u_sep:
                raise MissingQuote(u'`')
            self.u_string = u_remainder
            return ast.ItemLiteral(u_text)

        # The first character is always taken (start index 1) so that the
        # parser makes progress even on an otherwise unhandled character.
        u_id = self.skip_while(1, lambda c: c not in inv_token_chars)
        return ast.ItemIdentifer(u_id)

def parse_prod(u_string):
    """Parse one production, splitting off an optional trailing label.

    Everything before the first run of four consecutive backslashes is
    parsed with ProdParser; when that separator is present, the text
    after it is attached to the production through ``set_label``.
    Returns the object produced by ``ProdParser.prod()``.
    """
    # Four literal backslashes.  Written with escapes instead of the
    # ur'' prefix, which is a syntax error on Python 3; the value is the
    # same on Python 2.
    u_items, u_sep, u_label = u_string.partition(u"\\\\\\\\")
    prod = ProdParser(u_items).prod()
    if u_sep:
        prod.set_label(u_label)
    return prod

def grparse(filenm, file):
    """Parse a grammar description into an ``ast.Grammar``.

    filenm -- name used as the prefix of error messages
    file   -- a 'file like object' providing ``readlines()``

    Returns the ``ast.Grammar`` built from the collected options and
    non-terminals.  On a malformed line a ``filenm:line: message``
    diagnostic is written to stderr and None is returned.
    """

    # A grammar file looks like:
    #
    # # Comments
    # Option (option arg)
    #
    # # IMPORTANT: No spaces in list of non-terminals!  A hack, yes.  Sue me.
    # NT{,NT} := production
    #         |  production \\\\ line break, comments
    #         |  production

    options = {}
    nonterms = []
    lastnt = None

    try:
        for line_num, line in enumerate(stripcomments(file), 1):
            words = line.split()

            if not line[0].isspace():  # either an option or a NT

                # Guard the length: an option may have no arguments at
                # all, in which case there is no words[1] to inspect.
                if len(words) > 1 and words[1] == u":=":  # non-terminal
                    nt_name = words[0].split(u",")
                    prod_str = u" ".join(words[2:])
                    prod_obj = parse_prod(prod_str)
                    lastnt = ast.Nonterminal(nt_name, [prod_obj])
                    nonterms.append(lastnt)
                else:  # option
                    options[words[0]] = words[1:]

                continue

            if not words:  # empty line
                continue

            # An indented line must continue the previous non-terminal.
            if words[0] != u'|':
                raise InvalidLine()

            if lastnt is None:
                raise NoNonterminal()

            prod_str = u" ".join(words[1:])
            prod_obj = parse_prod(prod_str)
            lastnt.prods.append(prod_obj)

        return ast.Grammar(options, nonterms)

    except UnexpectedEndOfString:
        sys.stderr.write("%s:%d: unexpected end of line\n" % (
                filenm, line_num))
        return None

    except MissingQuote as m:
        sys.stderr.write("%s:%d: missing end quote %r\n" % (
                filenm, line_num, m.u_c))
        return None

    except NoNonterminal:
        sys.stderr.write("%s:%d: no non-terminal name found\n" % (
                filenm, line_num))
        return None

    except InvalidLine:
        sys.stderr.write("%s:%d: invalid line, should define or extend a non-terminal\n" % (
                filenm, line_num))
        return None
                             
def tsparse(filenm, file):
    """Parse a type-system description into an ``ast.TypeSystem``.

    filenm -- name used as the prefix of error messages
    file   -- a 'file like object' providing ``readlines()``

    Returns the populated ``ast.TypeSystem``.  On a structural error a
    ``filenm:line: message`` diagnostic is written to stderr and None is
    returned.  Exceptions raised by ``parse_prod`` while parsing a
    premise or conclusion are not handled here and propagate.
    """
    lines = stripcomments(file)

    # A type system file looks like:
    #
    # Options
    #
    # [ judgement ]
    #
    # Name:
    #    Premise
    #    Premise
    #    Premise
    #    --------
    #    Conclusion
    #    Conclusion
    #
    # where judgement, premise, and conclusion (for now) are represented
    # by the "production" nonterminal from above.

    class NoJudgement(Exception):
        pass

    class NoConclusion(Exception):
        pass

    typesys = ast.TypeSystem([])
    options = {}
    judgement = None
    typerule = None
    dash = False  # True once the current rule's dashed separator was seen

    try:
        for line_num, line in enumerate(lines, 1):

            if line[0] == u'[':

                # Once we find the first judgement, no more options permitted
                if options:
                    typesys.set_options(options)
                    options = None

                # rindex raises ValueError when the bracket is missing;
                # reported below as "Missing terminator".
                close_brace = line.rindex(u']')
                heading_text = line[1:close_brace].strip()
                judgement = ast.TypeJudgement(heading_text, [])
                typesys.judgements.append(judgement)
                continue

            if not line[0].isspace():

                if judgement is None:
                    # No judgement yet: accumulate options!
                    words = line.split()
                    options[words[0]] = words[1:]
                    continue

                # Otherwise this is a rule heading of the form "Name:".
                colon = line.rindex(":")
                if typerule is not None and not typerule.conclusions:
                    raise NoConclusion()
                typerule = ast.TypeRule(line[:colon], [], [])
                judgement.rules.append(typerule)
                dash = False
                continue

            line = line.strip()
            if not line:
                continue

            if line.startswith('-'*5):
                if dash:
                    sys.stderr.write("%s:%d: Already saw dashed line\n" % (
                        filenm, line_num))
                    return None
                dash = True
                continue

            line_prod = parse_prod(line)

            if typerule is None:
                # With a stray dashed line before any rule, this line is
                # a conclusion; it used to crash on typerule being None.
                sys.stderr.write("%s:%d: %s without type rule\n" % (
                    filenm, line_num, "Conclusion" if dash else "Premise"))
                return None

            if dash:
                typerule.conclusions.append(line_prod)
            else:
                typerule.premises.append(line_prod)

        if typerule is not None and not typerule.conclusions:
            raise NoConclusion()

        return typesys

    except ValueError:
        sys.stderr.write("%s:%d: Missing terminator\n" % (
            filenm, line_num))
        return None

    except NoConclusion:
        sys.stderr.write("%s:%d: Previous rule has no conclusion\n" % (
            filenm, line_num))
        return None

    except NoJudgement:
        sys.stderr.write("%s:%d: No judgement declared\n" % (
            filenm, line_num))
        return None
