"""

This file does not run unit tests: rather, it defines a class that helps
unit tests intercept the IR at various stages of compilation.

When it is created it is supplied with a basename, such as "tests/001",
and an mdname, such as "ppc".  For each intercepted stage, we write out
the IR at that point to a file named "<basename>.<pass>.<mdname>.out"
(e.g. "tests/001.basic.ppc.out").  The matching reference files use the
same names with a ".ref" extension.

"""

from cStringIO import StringIO
import sys, log, chelpers

from pynto import dump

class UnitTestDumper:
    """Compiler monitor that writes the IR at each stage to per-pass files.

    The constructor installs the instance as the compiler's monitor via
    comp.set_monitor(); the hook methods below (basic, refine,
    low_level_*, codegen) correspond to the intercepted stages.  Every
    hook writes to "<basename>.<pass>.<mdname>.out" (see _fname).
    """

    def __init__ (self, comp, basename, mdname):
        """Create a dumper and install it as comp's monitor.

        comp     -- the compiler; this class uses its set_monitor(), jit(),
                    log, dump_pass and dump_level.
        basename -- prefix for every output/reference file name.
        mdname   -- name inserted after the pass name in file names
                    (e.g. "ppc").
        """
        # File opened by low_level_start_emitting() and closed by
        # low_level_end_emitting(); None while no function is being
        # emitted.  Initialised up front so low_level_end_emitting() is
        # safe even if emitting never started.
        self.emit_file = None

        self.basename = basename
        self.namers = dump.Namers ()
        self.comp = comp
        comp.set_monitor (self)
        self.mdname = mdname

        # Every pass that can produce an output file.
        self.passes = ('basic',
                       'refine',
                       'lowlevel',
                       'codegen',
                       'output')
        pass

    # --------------------------------------------------------------
    # Top level routine invoked with a module to be jitted and
    # tested.  Generates the various output files.

    def test (self, modobj):
        """JIT-compile modobj, then call each public callable it defines.

        Callables are invoked with no arguments, in alphabetical order of
        name; names beginning with '_' are skipped.  For each one, the
        "output" file records either the repr() of its result or the
        repr() of the exception it raised.  The log level is reset to
        log.NOTHING on exit, even if an unexpected error escapes.
        """
        self.comp.jit (modobj)

        outfile = self._open_out_file ("output", 'w')
        ind = None
        try:
            if "output" in self.comp.dump_pass:
                self.comp.log.set_level (self.comp.dump_level)
                pass

            ind = self.comp.log.indent ("Test executions")

            for name in sorted (dir (modobj)):
                if name.startswith ('_'):
                    continue
                subobj = getattr (modobj, name)
                if not callable (subobj):
                    continue

                self.comp.log.high ('Invoking %s (subobj=%s 0x%x)',
                                    name, subobj, id (subobj))
                outfile.write ('Invoking %s:\n' % name)
                try:
                    res = subobj ()
                    outfile.write ('  Result: %s\n' % repr (res))
                except KeyboardInterrupt:
                    # Let the user abort the test run.
                    raise
                except:
                    # Deliberately broad: anything the jitted code raises
                    # (even objects not derived from Exception) is part of
                    # the recorded test output.
                    err = sys.exc_info ()[1]
                    outfile.write ('  Exception: %s\n' % repr (err))
                    pass
                pass
        finally:
            # Always release the file and restore logging.  Dropping
            # `ind` releases the indent object obtained above (presumably
            # ending the indented log section -- see log.indent).
            outfile.close ()
            del ind
            self.comp.log.set_level (log.NOTHING)
        pass

    # ------------------------------------------------------------
    # Utilities

    def _fname (self, passnm, ext):
        """Return "<basename>.<passnm>.<mdname>.<ext>"."""
        return "%s.%s.%s.%s" % (self.basename, passnm, self.mdname, ext)

    def _open_out_file (self, name, how):
        """Open the ".out" file of pass `name` with open() mode `how`."""
        assert name in self.passes
        return open (self._fname (name, "out"), how)

    def outputs (self):
        """Return a list of the ".out" file names, one per pass."""
        return [self._fname (p, "out") for p in self.passes]

    def refs (self):
        """Return a list of the ".ref" file names, one per pass."""
        return [self._fname (p, "ref") for p in self.passes]

    # basic, refine and codegen each write dump.dump_ktree()'s rendering
    # of self.comp to their own out file.

    def _dump_ktree_pass (self, passnm):
        """Overwrite passnm's out file with a dump_ktree dump of self.comp."""
        outfile = self._open_out_file (passnm, 'w')
        try:
            dump.dump_ktree (outfile, self.namers, 0, self.comp)
        finally:
            outfile.close ()
        pass

    def basic (self):
        self._dump_ktree_pass ("basic")
        pass

    def refine (self):
        self._dump_ktree_pass ("refine")
        pass

    # The low-level hooks all share one "lowlevel" file: low_level_start
    # truncates it and every other hook appends a separated section.

    def low_level_start (self):
        """Truncate the lowlevel file; later low-level hooks append to it."""
        outfile = self._open_out_file ("lowlevel", 'w')
        outfile.close ()
        return

    def _open_lowlevel_section (self):
        """Open the lowlevel file for appending and write a separator line.

        Returns the open file object; the caller must close it.
        """
        outfile = self._open_out_file ("lowlevel", 'a')
        outfile.write ("\n" + "-" * 70 + "\n")
        return outfile

    def low_level_emitter (self, funcnode, emitter):
        """Append a dump.dump_emitter section for funcnode/emitter."""
        outfile = self._open_lowlevel_section ()
        try:
            dump.dump_emitter (outfile, 0, funcnode, emitter, self.namers)
        finally:
            outfile.close ()
        pass

    def low_level_cfg (self, funcnode, cfg):
        """Append a dump.dump_ll_cfg section for funcnode/cfg."""
        outfile = self._open_lowlevel_section ()
        try:
            dump.dump_ll_cfg (outfile, 0, self.comp, funcnode, cfg,
                              self.namers)
        finally:
            outfile.close ()
        pass

    def low_level_start_emitting (self, funcnode):
        """Begin a section that low_level_emit() appends operations to.

        The file stays open in self.emit_file until
        low_level_end_emitting() closes it.
        """
        outfile = self._open_lowlevel_section ()
        outfile.write ("\n\nInstructions Emitted For Function %s:\n\n"
                       % self.namers.knode (funcnode))
        self.emit_file = outfile
        pass

    def low_level_emit (self, oper):
        """Dump one emitted operation into the section opened by
        low_level_start_emitting().

        The dumped copy (oper.clone) has each dest/source replaced by that
        operand's `.assignment` attribute.
        """
        dests = [x.assignment for x in oper.dests]
        sources = [x.assignment for x in oper.sources]
        opercl = oper.clone (dests, sources)
        dump.dump_opers (self.emit_file, 0, [opercl], self.namers)
        pass

    def low_level_end_emitting (self, funcnode):
        """Close the file opened by low_level_start_emitting(), if any."""
        if self.emit_file:
            self.emit_file.close ()
        self.emit_file = None
        pass

    def low_level_bytes (self, funcnode, finfo):
        """Append str(finfo.buffer) for funcnode, unless dump.canonical."""
        # In canonical mode we can't dump the actual bytes without first
        # normalizing them, and since we don't yet know how to normalize
        # them, this hook does nothing in that mode.
        if dump.canonical:
            return
        outfile = self._open_lowlevel_section ()
        try:
            outfile.write ("\n\nBytes Generated For Function %s:\n\n"
                           % self.namers.knode (funcnode))
            outfile.write (str (finfo.buffer))
        finally:
            outfile.close ()
        pass

    def codegen (self):
        self._dump_ktree_pass ("codegen")
        pass

    pass
