import unittest
import sys
import path
import imp

from cStringIO import StringIO

from pynto import stacksim, dump, compiler, log, ppc, unit

# When nonzero, AllTest.testMe generates outputs but skips comparing them
# with the reference files.  Set to 1 by set_debug_mode().
debug_mode = 0
# Passed to Compiler.set_trace_options() in AllTest.setUp; --trace-mode sets 2.
trace_mode = 0
# Pass names passed to Compiler.set_dump_options(); extended by --dump-pass.
dump_pass = []
# Level passed alongside dump_pass to Compiler.set_dump_options(); --dump-level.
debug_level = 0
# If non-empty, only test modules whose base name is listed here are run (-j).
only_tests = []

# Every discovered test module is run once with each of the following
# backends.  Each name must match a backend module imported at the top of
# this file, because AllTest.setUp looks it up via globals().
backends = ( 'ppc', )

class AllTest (unittest.TestCase):
    """Compile one test module with one backend and diff the results.

    Each instance wraps a single test module (``pypath``, a ``path.path``)
    and a backend name (``backend``, the name of a backend module imported
    at the top of this file, e.g. ``'ppc'``).  ``testMe`` runs the unit-test
    dumper on the module and compares every output it lists with the
    matching reference file.
    """

    def __init__ (self, pypath, backend):
        unittest.TestCase.__init__ (self, "testMe")
        self.pypath = pypath
        # Output currently being compared; reported by shortDescription().
        self.lastpass = "(all)"
        self.backend = backend
        return

    def setUp (self):
        # Load the module at pypath.  imp.load_module() does not close the
        # file object it is handed, so close it ourselves once loaded.
        suffix = [s for s in imp.get_suffixes () if s[0] == ".py"][0]
        modfile = self.pypath.open ()
        try:
            self.modobj = imp.load_module ("pynto.tests." + self.pypath.namebase,
                                           modfile,
                                           self.pypath,
                                           suffix)
        finally:
            modfile.close ()

        # Create a compiler and set its debug options from the module-level
        # settings (which the command-line flags may have changed).
        self.comp = compiler.Compiler ()
        self.comp.set_dump_options (dump_pass, debug_level)
        self.comp.set_trace_options (trace_mode)

        # Create an appropriate backend and give it to the compiler.  The
        # backend name is looked up among this module's globals, so it must
        # name a backend module imported at file level.
        md = globals()[self.backend].Definition(self.comp)
        self.comp.set_machine_desc (md)

        # Finally, create the unit tester.  testbase is the test module's
        # path with its extension stripped.
        testbase = self.pypath.parent / self.pypath.namebase
        self.unit = unit.UnitTestDumper (self.comp,
                                         testbase,
                                         self.backend)
        return

    def testMe (self):
        """Run the compiler on the module and compare outputs to references.

        In debug mode the outputs are still generated and read, but they
        are not compared against the references.
        """
        # Remove stale outputs so a file left over from an earlier run is
        # not mistaken for fresh output.  A missing file is not an error.
        for outpath in self.unit.outputs ():
            try:
                path.path (outpath).remove ()
            except OSError:
                pass
            pass

        # Launch the compiler
        self.unit.test (self.modobj)

        for outpath, refpath in zip (map (path.path, self.unit.outputs ()),
                                     map (path.path, self.unit.refs ())):
            self.lastpass = outpath
            outtext = outpath.text ()
            # A missing reference file yields a placeholder, so the
            # comparison below (presumably) fails instead of raising.
            try: reftext = refpath.text ()
            except IOError: reftext = "<blah>"
            if not debug_mode:
                self.assertEqual (outtext,
                                  reftext,
                                  "%s not same as %s" % (outpath, refpath))
                pass
            pass

        return

    def id (self):
        # NOTE(review): the backend is not part of the id, so ids collide if
        # `backends` ever lists more than one backend.
        return "All:%s" % (self.pypath.name)

    def shortDescription (self):
        return "%s: Generating '%s'" % (self.pypath.name, self.lastpass)

    pass

def add_tests (suite):
    """Add one AllTest per (test module, backend) pair to `suite`.

    Test modules are the ``*.py`` files in the ``tests`` directory next to
    this file.  When ``only_tests`` is non-empty, only modules whose base
    name appears in it are added.  Modules are visited in sorted order so
    the suite runs deterministically regardless of directory-listing order.
    """
    testdir = path.path (__file__).parent / "tests"
    for pyfile in sorted (testdir.glob ('*.py')):
        if only_tests and pyfile.namebase not in only_tests:
            continue
        for backend in backends:
            suite.addTest (AllTest (pyfile, backend))
            pass
        pass
    pass

def suite ():
    """Build and return a TestSuite populated by add_tests()."""
    result = unittest.TestSuite ()
    add_tests (result)
    return result

def set_debug_mode ():
    """Switch the harness into debug mode.

    Sets ``dump.canonical`` to 0 and sets the module-level ``debug_mode``
    flag, which makes AllTest.testMe skip the reference comparison.
    """
    global debug_mode
    dump.canonical = 0
    debug_mode = 1
        
if __name__ == "__main__":

    def _pop_option (flag):
        """Remove the first `flag` and the value after it from sys.argv.

        Returns the value.  Exits with an error message (instead of an
        IndexError traceback) when `flag` is the last argument.
        """
        idx = sys.argv.index (flag)
        if idx + 1 >= len (sys.argv):
            sys.exit ("%s requires an argument" % flag)
        value = sys.argv[idx+1]
        del sys.argv[idx:idx+2]
        return value

    # Strip this harness's own options from sys.argv before unittest.main()
    # parses the rest.  These assignments run at module scope, so they update
    # the module-level settings read by AllTest.setUp and AllTest.testMe.

    if "--debug-mode" in sys.argv:
        sys.argv.remove ("--debug-mode")
        set_debug_mode ()
        pass

    if "--trace-mode" in sys.argv:
        sys.argv.remove ("--trace-mode")
        trace_mode = 2
        pass

    # --dump-pass may be repeated; each value is a comma-separated list of
    # pass names.  Dumping implies debug mode.
    while "--dump-pass" in sys.argv:
        pname = _pop_option ("--dump-pass")
        set_debug_mode ()
        dump_pass += pname.split (',')
        pass

    if "--dump-level" in sys.argv:
        plevel = _pop_option ("--dump-level")
        try:
            debug_level = int (plevel)
        except ValueError:
            sys.exit ("--dump-level expects an integer, got %r" % plevel)
        pass

    # -j may be repeated to restrict the run to the named test modules.
    while "-j" in sys.argv:
        only_tests.append (_pop_option ("-j"))
        pass

    unittest.main(defaultTest='suite')

