#!/bin/env python
import sys,glob,unittest,os,re
# Opening and closing markers of the metadata comment (/*TEST ... TEST*/)
# that compile_check looks for inside each test source file.
start = re.compile(r'/\*TEST')
end = re.compile(r'TEST\*/')

class SkipException(Exception):
    """Raised to signal that a test source is marked to be skipped."""

class Tester(unittest.TestCase):
    """Compile and run C test sources and check them against their header.

    A test source is expected to contain a ``/*TEST ... TEST*/`` comment whose
    body is a series of Python assignments.  The names read back from it are
    SKIP, CCFLAGS, INPUT, OUTPUT and STATUS.
    """

    def compare(self, x, y):
        """Assert that two values match, numerically when both look numeric.

        Identical values always pass (this also covers tokens such as
        ``nan``).  Values that both convert to ``float`` are compared with
        ``assertAlmostEqual``; anything else must be exactly equal.
        """
        if x == y:
            return
        try:
            x_num, y_num = float(x), float(y)
        except (TypeError, ValueError):
            # Not both float-convertible: keep the historical behaviour of an
            # approximate comparison when the types support it, otherwise
            # require equality.
            try:
                self.assertAlmostEqual(x, y)
            except TypeError:
                self.assertEqual(x, y)
            return
        self.assertAlmostEqual(x_num, y_num)

    def similar(self, A, B):
        """Assert that two texts match token by token; return True on success.

        Both texts are split on whitespace, so differences in spacing and line
        breaks are ignored.  Each pair of tokens is checked with ``compare``.
        """
        Atokens = A.split()
        Btokens = B.split()
        self.assertTrue(len(Atokens) == len(Btokens), "output doesn't match")
        # An explicit loop, because map() is lazy on Python 3 and would never
        # run the comparisons.
        for a_token, b_token in zip(Atokens, Btokens):
            self.compare(a_token, b_token)
        return True

    @staticmethod
    def _read_text(filename):
        """Return the full contents of *filename*, closing the file."""
        with open(filename) as handle:
            return handle.read()

    @staticmethod
    def _write_text(filename, text):
        """Write *text* to *filename*, closing (and so flushing) the file."""
        with open(filename, 'w') as handle:
            handle.write(text)

    def _bootstrap_header(self, path, source):
        """Generate a header for a source that has none, then exit the process.

        The source is compiled and run once; its output becomes the expected
        OUTPUT.  If compiling or running fails, the header is written with
        SKIP=1 and an empty OUTPUT.  The annotated copy is saved as
        ``path + '_'`` and the command to install it is printed.  This method
        never returns: it ends with ``os._exit(0)``.
        """
        self._write_text('temp.c', source)
        print('')
        print(path)
        out = ''
        # Only run ./temp when this compilation succeeded; otherwise a stale
        # binary from an earlier file could be run and recorded.
        failed = os.system('./mpicc -o temp temp.c') != 0
        if not failed:
            pipe = os.popen('./temp')
            try:
                out = pipe.read()
                # close() returns None on success, the exit status otherwise.
                failed = bool(pipe.close())
            except KeyboardInterrupt:
                failed = True
        if failed:
            skip = 'SKIP=1\n'
            out = ''
        else:
            skip = 'SKIP=0\n'
        fix = '''
/*TEST
%s
PATH=%r
CCFLAGS=""
INPUT=""
OUTPUT=%r
STATUS=0
TEST*/''' % (skip, path, out)
        print(out)
        print(fix)
        print('mv %s_ %s; ./regressiontest' % (path, path))
        self._write_text(path + '_', fix + '\n\n' + source)
        # os._exit() skips normal interpreter cleanup, so buffered output
        # must be flushed by hand or it is lost when stdout is not a tty.
        sys.stdout.flush()
        os._exit(0)

    def _parse_header(self, path, source, match):
        """Return the settings assigned inside the /*TEST ... TEST*/ block.

        *match* is the match object for the opening marker.  The block body is
        executed as Python with the defaults below pre-seeded, and the
        resulting name -> value mapping is returned.  The test fails if the
        closing marker is missing or the block cannot be executed.
        """
        start_i = match.end()
        # Look for the terminator only after the opening marker.
        match = end.search(source, start_i)
        self.assertFalse(match is None, "Cannot find TEST*/ in " + path)
        block = source[start_i:match.start()]
        settings = {'CCFLAGS': "", 'INPUT': "", 'OUTPUT': "", 'STATUS': "",
                    'SKIP': 0}
        try:
            # NOTE: this executes text taken from the test source; only run
            # trusted test files.  An explicit namespace is used because
            # assignments made by a bare exec() are not visible to the
            # enclosing function on Python 3.
            exec(block, globals(), settings)
        except Exception as err:
            print(block)
            self.fail("Cannot evaluate /*TEST block in %s: %s" % (path, err))
        return settings

    def compile_check(self, path):
        """Compile and run the C source at *path* and verify its behaviour.

        If the source has no ``/*TEST`` header, one is generated and the
        process exits (see ``_bootstrap_header``).  Otherwise the header is
        read, the source is compiled with ``./mpicc`` (using CCFLAGS), run
        with INPUT on stdin, and its exit status and output are checked
        against STATUS and OUTPUT.  A true STATUS means a nonzero exit status
        is expected.

        Raises SkipException(path) when the header sets SKIP to a true value.
        """
        source = self._read_text(path)
        match = start.search(source)
        if match is None:
            self._bootstrap_header(path, source)

        settings = self._parse_header(path, source, match)
        if settings['SKIP']:
            raise SkipException(path)

        exe = os.path.splitext(path)[0] + '.exe'
        if not os.path.dirname(exe):
            # A bare file name would be looked up on $PATH by the shell.
            exe = os.path.join(os.curdir, exe)
        self._write_text('temp.c', source)
        cstat = os.system('./mpicc %s -o %s temp.c' % (settings['CCFLAGS'], exe))
        self.assertTrue(cstat == 0, "Compilation failed")
        self._write_text('temp.in', settings['INPUT'])
        try:
            os.unlink('temp.out')
        except OSError:
            # No output left over from a previous run; nothing to remove.
            pass
        rstat = os.system('%s < temp.in > temp.out' % exe)
        if settings['STATUS']:
            self.assertTrue(rstat != 0, "Got the wrong status (%s) from run" % rstat)
        else:
            self.assertTrue(rstat == 0, "Got the wrong status (%s) from run" % rstat)
        out = self._read_text('temp.out')
        self.assertTrue(self.similar(out, settings['OUTPUT']), "Output didn't match")

        return

if __name__ == '__main__':
    # Command-line arguments are test source paths, not unittest options, so
    # they are removed from sys.argv before unittest.main() parses it.
    paths = sys.argv[1:]
    del sys.argv[1:]
    if not paths:
        paths = glob.glob('tests/*.c')

    class Test(Tester):
        """Holds one generated test_<name> method per test source."""

    def _make_test(path):
        """Return a test method that runs compile_check on *path*."""
        def test(self):
            try:
                self.compile_check(path)
            except SkipException as skipped:
                print('SKIPPING test %s' % (skipped,))
        return test

    # Methods are attached with setattr rather than by exec'ing generated
    # source, so a file name that is not a valid identifier cannot break the
    # whole run.
    used_names = set()
    for path in paths:
        base = os.path.basename(path)
        name = os.path.splitext(base)[0]
        method_name = 'test_' + re.sub(r'\W', '_', name)
        # Keep sources with the same base name from replacing each other.
        while method_name in used_names:
            method_name += '_'
        used_names.add(method_name)
        test_method = _make_test(path)
        test_method.__name__ = method_name
        setattr(Test, method_name, test_method)
    unittest.main()
