import sys, optparse, difflib, traceback, os
from cStringIO import StringIO
from axle.main import AxleInterpreter, AxleError

def expected_output(filenm):
    """Return the expected output declared at the end of test file 'filenm'.

    The expected output is the trailing run of comment lines that follows a
    '# EXPECTED OUTPUT' marker line at the bottom of the file.  Scanning runs
    upward from the last line and stops at the marker or at the first line
    that does not start with '# '.  From every collected line the two-char
    '# ' prefix is stripped and a newline is ensured, so the result is the
    newline-terminated concatenation of those lines in file order (an empty
    string when nothing is collected).
    """
    with open(filenm) as f:
        lines = f.readlines()
    collected = []
    for line in reversed(lines):
        # Compare without the newline so a marker on the very last line of
        # the file (no trailing newline) is still treated as the marker
        # rather than being collected as an expected-output line.
        if line.rstrip("\n") == "# EXPECTED OUTPUT":
            break
        if not line.startswith("# "):
            break
        if not line.endswith("\n"):
            line += "\n"
        collected.append(line[2:])
    return "".join(reversed(collected))

def _write_exception_report(errout, filenm, actout):
    """Write a failure report for a test that raised.

    Must be called from inside an ``except`` clause: the traceback written is
    the exception currently being handled.
    """
    errout.write("ERROR REPORT FOR %s\n" % filenm)
    errout.write("\n")
    errout.write("Encountered exception:\n")
    traceback.print_exc(None, errout)
    errout.write("\n")
    errout.write("Actual output (so far):\n")
    errout.write(actout)
    errout.write("\n")


def _write_mismatch_report(errout, filenm, expout, actout):
    """Write a failure report for a test whose output differs from expected."""
    errout.write("ERROR REPORT FOR %s\n" % filenm)
    errout.write("\n")
    errout.write("Expected output:\n")
    errout.write(expout)
    errout.write("\n")
    errout.write("Actual output:\n")
    errout.write(actout)
    errout.write("\n")
    errout.write("Diff:\n")
    # The split lines carry no newline, so use lineterm="" -- with the default
    # ("\n") the ---/+++/@@ header lines would end up followed by a spurious
    # blank line once everything is joined with "\n" below.
    diffs = difflib.unified_diff(
            expout.split('\n'), actout.split('\n'),
            "expected", "actual", lineterm="")
    errout.write("\n".join(diffs))


def run_test(filenm):
    """Run the Axle test script 'filenm' and return True if it passed.

    The script is run through AxleInterpreter (with --debug and utf8 default
    encoding) and its captured stdout is compared with the output returned by
    expected_output(filenm).  Side files:

    * '<filenm>.dbg.txt' receives the interpreter's debug output; it is
      always left in place.
    * '<filenm>.err' receives a failure report (traceback and partial output,
      or expected/actual output plus a unified diff).  It is deleted when the
      test passes.

    KeyboardInterrupt propagates so a whole run can be aborted; any other
    exception raised while running the test is recorded as a failure.
    """
    errfile = filenm + ".err"
    dbgfile = filenm + ".dbg.txt"

    # 'with' guarantees both files are closed on every path, including
    # exceptions that escape this function.
    with open(errfile, 'w') as errout:
        with open(dbgfile, 'w') as debugout:
            stdout = StringIO()
            interp = AxleInterpreter(stdout, debugout=debugout)
            try:
                interp.run(["--debug", "--default-encoding", "utf8", filenm])
                expout = expected_output(filenm)
            except KeyboardInterrupt:
                raise
            except:
                # Deliberately broad: SystemExit (e.g. raised by argument
                # parsing inside the interpreter) and every other error are
                # reported as a failure of this test, not of the whole run.
                _write_exception_report(errout, filenm, stdout.getvalue())
                return False
            actout = stdout.getvalue()
            if actout != expout:
                _write_mismatch_report(errout, filenm, expout, actout)
                return False
    # Passed.  The report file was created up front, so it always exists
    # here and the removal cannot fail for lack of a file.
    os.remove(errfile)
    return True

def main(all_args):
    """Run every test file named in 'all_args' and print a summary.

    'all_args' is parsed with optparse (no options are defined, so only the
    built-in --help is recognised); each remaining argument is a test file
    passed to run_test.

    Returns 0 when every test passed and 1 otherwise, so the result can be
    used directly as the process exit status.
    """
    parser = optparse.OptionParser()
    _, args = parser.parse_args(all_args)

    failed_tests = [filenm for filenm in args if not run_test(filenm)]
    if not failed_tests:
        print("All %d test(s) passed." % len(args))
        return 0
    print("%d test(s) FAILED out of %d total test(s):" % (
        len(failed_tests), len(args)))
    for filenm in failed_tests:
        print("\t%s" % filenm)
    return 1

if __name__ == "__main__":
    # Propagate the pass/fail status so shells and CI can detect failures.
    sys.exit(main(sys.argv[1:]))
