# -*- coding: utf-8 -*-

import sys, re

# Characters that ItemPunctuation.dump() prefixes with a backslash.
reserved_chars = u"|{}\\\"_'^`"

# Maps a single unicode math symbol to the LaTeX command that renders it.
# The keys are also appended to the `punctuation` string below.
math_sym = {
    u'≠': r'\neq',
    u'≤': r'\leq',
    u'≥': r'\geq',
    u'∪': r'\cup',
    u'∩': r'\cap',
    u'∀': r'\forall',
    u'∃': r'\exists',
    u'∄': r'\nexists',
    u'∅': r'\emptyset',
    u'⊦': r'\vdash',
    u'⊧': r'\models',
    u'⊑': r'\sqsubseteq',
    u'⊏': r'\sqsubset',
    u'⊆': r'\subseteq',
    u'⊂': r'\subset',
    u'⊒': r'\sqsupseteq',
    u'⊐': r'\sqsupset',
    u'⊇': r'\supseteq',
    u'⊃': r'\supset',
    u'∈': r'\in',
    u'∉': r'\notin',
    u'→': r'\rightarrow',
    u'←': r'\leftarrow',
    u'⇒': r'\Rightarrow',
    # Was '\Lightarrow', which is not a LaTeX command.
    u'⇐': r'\Leftarrow',
    u'⊕': r'\oplus',
    u'⊖': r'\ominus',
    u'⊥': r'\bot',
    u'⊤': r'\top',
    u'×': r'\times',

    # ASCII characters that are rendered as math operators.
    u'\\': r'\setminus',
    u'|': r'\mid',
}

# Maps Greek letters to their LaTeX math-mode spelling.  Letters that
# standard LaTeX does not define a command for (because they look like a
# Latin letter) are mapped to that Latin letter.
alphabeta = {
    u'α': r'\alpha',
    u'β': r'\beta',
    u'γ': r'\gamma',
    u'δ': r'\delta',
    u'ε': r'\epsilon',
    u'ζ': r'\zeta',
    u'η': r'\eta',
    u'θ': r'\theta',
    u'ι': r'\iota',
    u'κ': r'\kappa',
    u'λ': r'\lambda',
    u'μ': r'\mu',
    u'ν': r'\nu',
    u'ξ': r'\xi',
    # Standard LaTeX has no \omicron; a plain 'o' is used instead.
    u'ο': r'o',
    u'π': r'\pi',
    # Was '\ro', which is not a LaTeX command.
    u'ρ': r'\rho',
    u'σ': r'\sigma',
    u'τ': r'\tau',
    u'υ': r'\upsilon',
    u'φ': r'\phi',
    # Was '\xi', which is the letter ξ, not χ.
    u'χ': r'\chi',
    u'ψ': r'\psi',
    u'ω': r'\omega',

    u'Α': r'A',
    u'Β': r'B',
    u'Γ': r'\Gamma',
    u'Δ': r'\Delta',
    u'Ε': r'E',
    u'Ζ': r'Z',
    u'Η': r'H',
    u'Θ': r'\Theta',
    u'Ι': r'I',
    u'Κ': r'K',
    u'Λ': r'\Lambda',
    u'Μ': r'M',
    u'Ν': r'N',
    u'Ξ': r'\Xi',
    u'Ο': r'O',
    u'Π': r'\Pi',
    u'Ρ': r'P',
    u'Σ': r'\Sigma',
    u'Τ': r'T',
    u'Υ': r'\Upsilon',
    u'Φ': r'\Phi',
    u'Χ': r'X',
    u'Ψ': r'\Psi',
    u'Ω': r'\Omega',
}

# Punctuation characters: a fixed set followed by every key of math_sym
# (so some characters occur more than once in the string).
punctuation = u"".join(
    [u"*[]().,@;+-/%:/∃∄∀⊆⊂⊦#≠=∈∉⊑⊏<>"] + list(math_sym.keys()))

def latex(unicode, math_mode=False, unicode_only=False):
    """
    Convert a unicode string into a latex encoding of the same
    string, using math symbols where needed.

    Arguments:
      unicode      -- the text to convert.  The exact string u"..." is
                      converted to \\ldots as a whole.
      math_mode    -- if true, mapped symbols are emitted bare and followed
                      by a space; otherwise each one is wrapped in $...$.
      unicode_only -- if true, ASCII characters are passed through
                      untouched (neither escaped nor mapped).

    Returns a native string that contains only ASCII characters.

    Raises UnicodeEncodeError for a non-ASCII character that has no entry
    in math_sym or alphabeta.
    """
    # Characters that must be backslash-escaped.  '%' is included because
    # an unescaped '%' starts a LaTeX comment and swallows the rest of
    # the output line.
    esc_sym = u"{}_#&%"

    if unicode == u"...":
        return r"\ldots"

    def to_ascii(uc):
        # Encoding enforces that only ASCII reaches the output.  On
        # interpreters where encode() yields a separate bytes type,
        # convert back so that the pieces can be joined as text.
        ac = uc.encode('ascii')
        if not isinstance(ac, str):
            ac = ac.decode('ascii')
        return ac

    def map_uc(uc):
        if unicode_only and ord(uc) < 128:
            return to_ascii(uc)
        # None of the escaped characters is a key of the symbol tables,
        # so testing them before the table lookup does not change which
        # branch any character takes.
        if uc in esc_sym:
            return "\\" + to_ascii(uc)
        ms = math_sym.get(uc)
        if ms is None:
            ms = alphabeta.get(uc)
        if ms is not None:
            if not math_mode:
                return "$" + ms + "$"
            return ms + " "
        return to_ascii(uc)

    return "".join([map_uc(uc) for uc in unicode])

class UserError(Exception):
    """Base class for errors caused by the user's input."""

class DoubleNonTerminal(UserError):
    """Reports that a non-terminal name has been defined more than once."""

    def __init__(self, nt):
        message = u"Non-terminal '%s' is defined twice!" % nt
        super(DoubleNonTerminal, self).__init__(message)

def set_options(self, options, allowed_options):
    """
    Copy every entry of `options` onto `self` as an attribute.

    Raises UserError at the first option whose name is not listed in
    `allowed_options`; options visited before it have already been set.
    """
    for key, value in options.items():
        if key not in allowed_options:
            legal = ",".join(allowed_options)
            raise UserError(
                "Undefined option: '%s' Legal options: '%s'" % (key, legal))
        setattr(self, key, value)

class Grammar:
    """
    A list of non-terminal definitions plus the 'terminals' and 'substs'
    options that are handed to them during postprocessing.
    """

    def __init__(self, options, nts):
        self.nts = nts
        # Defaults; either may be replaced by an entry in `options`.
        self.terminals = []
        self.substs = []
        set_options(self, options, ['terminals', 'substs'])

    def dump(self):
        """Print the dump of every non-terminal, encoded as UTF-8."""
        dumped = [nonterm.dump() for nonterm in self.nts]
        print(u"\n".join(dumped).encode('utf8'))

    def all_nts(self):
        """Return the names of all non-terminals as one flat list."""
        return [name for nonterm in self.nts for name in nonterm.names]

    def postprocess(self):
        """
        Check that no non-terminal name occurs twice, then postprocess
        every non-terminal.  Raises DoubleNonTerminal for the first name
        in the list that is not unique.
        """
        names = self.all_nts()

        for name in names:
            if names.count(name) != 1:
                raise DoubleNonTerminal(name)

        for nonterm in self.nts:
            nonterm.postprocess(
                names, self.terminals, u" ".join(self.substs))

    def latex_rows(self):
        """Return the lines of a 4-column tabular holding the grammar."""
        rows = ["\\begin{tabular}{llll}"]
        for nonterm in self.nts:
            nonterm.add_latex_rows(rows)
        rows.append("\\end{tabular}")
        return rows

class TypeSystem:
    """
    A list of judgements plus the 'extra_nts', 'terminals' and 'substs'
    options, which are combined with those of a grammar in postprocess().
    """

    def __init__(self, judgements):
        self.judgements = judgements
        self.extra_nts = []
        self.terminals = []
        self.substs = []

    def set_options(self, options):
        """Apply `options`; only the three names below are accepted."""
        allowed = ["extra_nts", "terminals", "substs"]
        set_options(self, options, allowed)

    def dump(self):
        """Print the dump of every judgement, encoded as UTF-8."""
        dumped = [judgement.dump() for judgement in self.judgements]
        print(u"\n".join(dumped).encode('utf8'))

    def postprocess(self, gr):
        """Postprocess every judgement against grammar `gr`."""
        known_nts = gr.all_nts() + self.extra_nts
        for judgement in self.judgements:
            judgement.postprocess(
                known_nts,
                gr.terminals + self.terminals,
                u" ".join(gr.substs + self.substs))

    def latex_rows(self):
        """Return the lines of a one-column tabular with all judgements."""
        # Two row breaks separate consecutive judgements.
        separator = [r"\\", r"\\"]

        rows = [r"\begin{tabular}{l}"]
        for position, judgement in enumerate(self.judgements):
            if position > 0:
                rows.extend(separator)
            judgement.add_latex_rows(rows)
        rows.append(r"\end{tabular}")

        return rows

class TypeJudgement:
    """A group of type rules with an optional heading."""

    def __init__(self, heading, rules):
        self.heading = heading
        self.rules = rules

    def dump(self):
        """
        Return the heading in square brackets followed by the dump of
        every rule.

        The constructor never sets a `prototype` attribute, so reading it
        unconditionally raised AttributeError.  It is still honoured when
        some caller has attached one; otherwise the heading is used.
        """
        prototype = getattr(self, "prototype", None)
        if prototype is not None:
            title = prototype.dump()
        else:
            title = self.heading or u""
        return u"[%s]\n%s\n" % (
            title,
            "\n\n".join([r.dump() for r in self.rules]))

    def postprocess(self, *args):
        """Forward `args` to the postprocess() of every rule."""
        for rule in self.rules:
            rule.postprocess(*args)

    def add_latex_rows(self, res):
        """Append the LaTeX lines for this judgement to the list `res`."""
        if self.heading:
            res += [r"\hline",
                    r"\emph{%s}" % (latex(self.heading),),
                    r"\\", r"\\"]
        res += [r"\begin{mathpar}"]
        for idx, r in enumerate(self.rules):
            # Rules inside the mathpar are separated by \and.
            if idx > 0: res.append(r"\and")
            r.add_latex_rows(res)
        res += [r"\end{mathpar}"]

class TypeRule:
    """A named inference rule: premises above the line, conclusions below."""

    def __init__(self, name, premises, conclusions):
        self.name = name
        self.premises = premises
        self.conclusions = conclusions

    def dump(self):
        """Return the rule as indented text with a dashed separator line."""
        indent = "    "
        above = "\n".join([indent + p.dump() for p in self.premises])
        below = "\n".join([indent + c.dump() for c in self.conclusions])
        separator = indent + "-" * 30
        return u"%s:\n%s\n%s\n%s\n" % (self.name, above, separator, below)

    def postprocess(self, *args):
        """Forward `args` to every premise and every conclusion."""
        for part in (self.premises + self.conclusions):
            part.postprocess(*args)

    def add_latex_rows(self, res):
        """Append an \\inferrule with both groups to the list `res`."""
        res.append(r"\inferrule[%s]" % (latex(self.name, math_mode=True),))
        for group in (self.premises, self.conclusions):
            self.render_prods(res, group)

    def render_prods(self, res, list):
        """
        Append one braced group to `res`: "{}" for an empty `list`,
        otherwise one indented line per entry between "{" and "}".
        """
        if not list:
            res.append("{}")
            return

        res.append(r"{")
        for p in list:
            # Entries flagged with `newline` get a doubled line break.
            if p.newline:
                term = r" \\\\"
            else:
                term = r" \\"
            res.append("    " + p.latex_str() + term)
        res.append("}")

class Nonterminal:
    """One or more non-terminal names sharing a list of productions."""

    def __init__(self, names, prods):
        self.names = names
        self.prods = prods

    def dump(self):
        """
        Return "names := first" followed by one "|  prod" line for every
        further production.

        A non-terminal without productions yields just "names := "
        instead of raising IndexError; add_latex_rows() accepts an empty
        production list as well.
        """
        names = u",".join(self.names)
        if not self.prods:
            return u"%s := \n" % (names,)
        res = u"%s := %s\n" % (names, self.prods[0].dump())
        # Indent continuation lines by the width of the names.
        ind = u" " * len(names)
        for prod in self.prods[1:]:
            res += u"%s |  %s\n" % (ind, prod.dump())
        return res

    def postprocess(self, *args):
        """Forward `args` to the postprocess() of every production."""
        for p in self.prods:
            p.postprocess(*args)

    def add_latex_rows(self, res):
        """
        Append the table lines for this non-terminal to the list `res`.
        Nothing is appended when there are no productions.
        """
        prods = self.prods
        if prods:
            names = u",".join(self.names)
            res.append("$%s$ & := & " % (latex(names, math_mode=True),))

            for prod in prods[:-1]:
                text = "$"+prod.latex_str()+"$"
                if prod.newline:
                    # Finish the row with the label and open the next
                    # row with the alternative bar.
                    text += r" & $%s$ \\ & $\mid$ &" % (
                        prod.label_latex_str())
                else:
                    # Stay in the same cell; separate with a bar.
                    text += r" $\mid$ "
                res.append(text)

            # The last production always ends its row.
            last_prod = prods[-1]
            res.append(r"$%s$ & $%s$\\" % (
                last_prod.latex_str(), last_prod.label_latex_str()))

class Production:
    """A sequence of items with an optional label."""

    def __init__(self, items):
        self.items = items
        self.newline = None
        self.label = u""

    def set_label(self, label):
        """Store `label` and flag the production with `newline`."""
        self.label = label
        self.newline = True

    def dump(self):
        """Return the comma-joined item dumps, plus the label if flagged."""
        items_str = u",".join([item.dump() for item in self.items])
        if not self.newline:
            return items_str
        return "%s (Label: %s)" % (items_str, self.label)

    def postprocess(self, *args):
        """Forward `args` to the postprocess() of every item."""
        for item in self.items:
            item.postprocess(*args)

    def latex_str(self):
        """Return the concatenated LaTeX of all items."""
        pieces = [item.latex_str() for item in self.items]
        return "".join(pieces)

    def label_latex_str(self):
        """Return the label converted by latex()."""
        return latex(self.label)

class ItemSpace:
    """An item that stands for a single space."""

    def __init__(self):
        pass

    def dump(self):
        """Return a one-space string."""
        return u" "

    def postprocess(self, *args):
        """Accept and ignore any arguments."""
        return None

    def latex_str(self):
        """Return a text-mode space."""
        return r"\textrm{ }"

class ItemPunctuation:
    """An item holding punctuation text."""

    def __init__(self, text):
        self.text = text

    def dump(self):
        """Return the text, backslash-prefixed if found in reserved_chars."""
        if self.text not in reserved_chars:
            return self.text
        return u"\\%s" % (self.text,)

    def postprocess(self, *args):
        """Accept and ignore any arguments."""
        return None

    def latex_str(self):
        """Return the text converted by latex() in math mode."""
        return latex(self.text, math_mode=True)

class ItemSubscript:
    """An item with a second item attached as subscript."""

    def __init__(self, pair):
        # `pair` is a two-element sequence (base item, subscript item).
        # It is unpacked here instead of in the signature because
        # tuple parameters were removed from the language (PEP 3113);
        # callers pass the same single argument as before.
        item1, item2 = pair
        self.items = [item1, item2]

    def dump(self):
        """Return "base_subscript" built from the item dumps."""
        return u'%s_%s' % (self.items[0].dump(), self.items[1].dump())

    def postprocess(self, *args):
        """Forward `args` to the postprocess() of both items."""
        for i in self.items: i.postprocess(*args)

    def latex_str(self):
        """Return "base_{subscript}" built from the items' LaTeX."""
        litems = [i.latex_str() for i in self.items]
        return r"%s_{%s}" % (litems[0], litems[1])

class ItemSupscript:
    """An item with a second item attached as superscript."""

    def __init__(self, pair):
        # `pair` is a two-element sequence (base item, superscript item).
        # It is unpacked here instead of in the signature because
        # tuple parameters were removed from the language (PEP 3113);
        # callers pass the same single argument as before.
        item1, item2 = pair
        self.items = [item1, item2]

    def dump(self):
        """Return "base^superscript" built from the item dumps."""
        return u'%s^%s' % (self.items[0].dump(), self.items[1].dump())

    def postprocess(self, *args):
        """Forward `args` to the postprocess() of both items."""
        for i in self.items: i.postprocess(*args)

    def latex_str(self):
        """Return "base^{superscript}" built from the items' LaTeX."""
        litems = [i.latex_str() for i in self.items]
        return r"%s^{%s}" % (litems[0], litems[1])

class ItemPrime:
    """An item followed by a string of prime marks."""

    def __init__(self, pair):
        # `pair` is a two-element sequence (item, primes).  It is
        # unpacked here instead of in the signature because tuple
        # parameters were removed from the language (PEP 3113); callers
        # pass the same single argument as before.
        item, primes = pair
        self.item = item
        self.primes = primes

    def dump(self):
        """Return the item's dump followed by the primes."""
        # The attribute is `item`; the former `self.items` does not
        # exist on this class and raised AttributeError.
        return u"%s%s" % (self.item.dump(), self.primes)

    def postprocess(self, *args):
        """Forward `args` to the postprocess() of the wrapped item."""
        self.item.postprocess(*args)

    def latex_str(self):
        """Return the item's LaTeX followed by the primes."""
        return r"%s%s" % (self.item.latex_str(), self.primes)

class ItemKeyword:
    """An item holding keyword text, rendered in typewriter font."""

    def __init__(self, text):
        self.text = text

    def dump(self):
        """Return the text enclosed in double quotes."""
        return '"%s"' % (self.text,)

    def postprocess(self, *args):
        """Accept and ignore any arguments."""
        return None

    def latex_str(self):
        """Return the latex()-converted text wrapped in \\texttt."""
        body = latex(self.text, math_mode=False)
        return r"\texttt{%s}" % body
    
class ItemLiteral:
    """An item whose text is emitted into the LaTeX output unconverted."""

    def __init__(self, text):
        self.text = text

    def dump(self):
        """Return the text enclosed in backquotes."""
        return '`%s`' % (self.text,)

    def postprocess(self, all_nts, terminals, u_substs):
        """
        Replace the text if `u_substs` contains a substitution for it,
        written as `text`=`replacement`.  The first match is used; the
        other two arguments are not used.
        """
        # Plain u"" literals: the pattern pieces contain no backslashes,
        # so the raw prefix adds nothing, and the combined `ur` prefix is
        # not accepted by Python 3.
        u_regexp = u"`" + re.escape(self.text) + u"`=`([^`]*)`"
        mo = re.search(u_regexp, u_substs)
        if mo:
            self.text = mo.group(1)

    def latex_str(self):
        """
        Return the text followed by a space.

        Raises UnicodeEncodeError if the text is not pure ASCII.
        """
        ascii_text = self.text.encode('ascii')
        # Where encode() yields a separate bytes type, convert back so
        # the result can be concatenated with the other text pieces.
        if not isinstance(ascii_text, str):
            ascii_text = ascii_text.decode('ascii')
        return ascii_text + " "

class ItemIdentifer:
    """An identifier; postprocess() classifies it as (non-)terminal."""

    def __init__(self, text):
        self.text = text
        self.is_nt = False
        # Initialised here as well so that latex_str() can be called
        # before postprocess(); previously that raised AttributeError.
        self.is_t = False
        self.subscripted_from = None

    def dump(self):
        """Return the identifier text."""
        return self.text

    def postprocess(self, all_nts, terminals, u_substs):
        """
        Record whether the text occurs in `all_nts` and in `terminals`.
        `u_substs` is not used.
        """
        self.is_nt = (self.text in all_nts)
        self.is_t = (self.text in terminals)

    def latex_str(self):
        """
        Return the latex()-converted text: in math mode for ordinary
        identifiers, wrapped in \\texttt for terminals.
        """
        if not self.is_t:
            return latex(self.text, math_mode=True)
        # print terminals in tt font:
        return r"\texttt{%s}" % (latex(self.text, math_mode=False),)
    
class RepeatedItem:
    """A group of items that is rendered with a line over it."""

    def __init__(self, items):
        self.items = items

    def dump(self):
        """Return the comma-joined item dumps enclosed in braces."""
        inner = u",".join([item.dump() for item in self.items])
        return u"{%s}" % (inner,)

    def postprocess(self, *args):
        """Forward `args` to the postprocess() of every item."""
        for item in self.items:
            item.postprocess(*args)

    def latex_str(self):
        """Return the concatenated item LaTeX wrapped in \\overline."""
        pieces = [item.latex_str() for item in self.items]
        return r"\overline{%s}" % ("".join(pieces),)

