import unittest
from copy import deepcopy
import subprocess
import os
import random

from comsdk.graph import *
from comsdk.edge import *
from comsdk.communication import *

def dummy_edge(data):
    '''
    No-op morphism used to connect states without touching the data.

    :param data: graph data (left unchanged).
    :return: None
    '''
    return None

def increment_a_edge(data):
    '''
    Morphism that adds one to ``data['a']`` in place.

    :param data: graph data holding a numeric ``'a'``.
    '''
    data['a'] += 1

def increment_a_array_edge(data):
    '''
    Morphism that adds one to every element of ``data['a']`` in place.

    The sequence object itself is kept (elements are reassigned by index), so other
    references to the same list observe the update.

    :param data: graph data holding a mutable sequence ``'a'`` of numbers.
    '''
    values = data['a']
    for i, value in enumerate(values):
        values[i] = value + 1

def increment_b_edge(data):
    '''
    Morphism that adds one to ``data['b']`` in place.

    :param data: graph data holding a numeric ``'b'``.
    '''
    data['b'] += 1

def decrement_a_edge(data):
    '''
    Morphism that subtracts one from ``data['a']`` in place.

    :param data: graph data holding a numeric ``'a'``.
    '''
    data['a'] -= 1

def write_a_edge(data):
    '''
    Morphism that dumps ``data['a']`` as text into ``tests/square_test_dir/input/a.dat``.

    The file path is resolved against ``data['__CURRENT_WORKING_DIR__']``. The resulting
    path is stored in ``data['a_file']`` so that later edges can pass it on.
    '''
    target = os.path.join(data['__CURRENT_WORKING_DIR__'], 'tests/square_test_dir/input/a.dat')
    with open(target, 'w') as out:
        out.write(str(data['a']))
    data['a_file'] = target

def load_b_edge(data):
    '''
    Morphism that reads an integer from ``data['b_file']`` into ``data['b']``.

    ``data['b_file']`` is joined onto ``data['__WORKING_DIR__']`` with ``os.path.join``.
    An absolute ``b_file`` therefore replaces the working directory.
    '''
    source = os.path.join(data['__WORKING_DIR__'], data['b_file'])
    with open(source, 'r') as inp:
        text = inp.read()
    data['b'] = int(text)

def nonzero_predicate(data):
    '''
    Predicate that holds when ``data['a']`` is nonzero.

    :return: True if ``data['a'] != 0`` else False.
    '''
    # bool() keeps the return type a plain bool without the redundant
    # ``True if ... else False`` construct.
    return bool(data['a'] != 0)

def positiveness_predicate(data):
    '''
    Predicate that holds when ``data['a']`` is strictly positive.

    :return: True if ``data['a'] > 0`` else False.
    '''
    return bool(data['a'] > 0)

def nonpositiveness_predicate(data):
    '''
    Predicate that holds when ``data['a']`` is zero or negative.

    :return: True if ``data['a'] <= 0`` else False.
    '''
    return bool(data['a'] <= 0)

def print_exception(exc_data, data):
    '''
    Print the exception payload followed by the graph data at the time of failure.

    :param exc_data: whatever describes the exception.
    :param data: current graph data.
    '''
    messages = (
        'exception data: {}'.format(exc_data),
        'current state of data: {}'.format(data),
    )
    for message in messages:
        print(message)

def print_stdout(data, stdout_lines):
    '''
    Stdout processor that discards the program's output.

    :param data: graph data (unused).
    :param stdout_lines: lines printed by the program (unused).
    :return: an empty dict, i.e. nothing is merged back into the data.
    '''
    return dict()

def check_task_finished(data, stdout_lines):
    '''
    Parse ``qstat`` output and report whether job ``data['job_ID']`` has left the queue.

    The first two lines (table header and separator) are skipped. Every remaining line
    whose first column is a numeric job ID is compared with ``data['job_ID']``. A match
    means the job is still listed, i.e. not finished.

    Example:
    job-ID  prior   name       user         state submit/start at     queue                          slots ja-task-ID
    -----------------------------------------------------------------------------------------------------------------
    663565 0.00053 RT700-tran scegr        r     09/19/2018 23:51:22 24core-128G.q@dc2s2b1a.arc3.le    24
    663566 0.00053 RT800-tran scegr        r     09/19/2018 23:51:22 24core-128G.q@dc3s5b1a.arc3.le    24
    663567 0.00053 RT900-tran scegr        r     09/20/2018 00:00:22 24core-128G.q@dc4s2b1b.arc3.le    24
    663569 0.00053 RT1000-tra scegr        r     09/20/2018 00:05:07 24core-128G.q@dc1s1b3d.arc3.le    24

    :param data: graph data; must contain an int ``'job_ID'`` when qstat lists any job.
    :param stdout_lines: lines printed by ``qstat``.
    :return: ``{'job_finished': bool}``
    '''
    for line in stdout_lines[2:]:
        items = line.split()
        # Blank lines (e.g. a trailing newline) or lines not starting with a job ID
        # previously raised IndexError/ValueError; they carry no job, so skip them.
        if not items or not items[0].isdigit():
            continue
        if int(items[0]) == data['job_ID']:
            return {'job_finished': False}
    return {'job_finished': True}

def set_job_id(data, stdout_lines):
    '''
    Extract the Grid Engine job ID from the first line of ``qsub`` output.

    Expects a first line such as
    ``Your job 664989 ("fe_170.310.sh") has been submitted``,
    where the ID is the third whitespace-separated token.

    :return: ``{'job_ID': int}``
    '''
    tokens = stdout_lines[0].split()
    job_id = int(tokens[2])
    return {'job_ID': job_id}

def _create_data_from_dict(d):
    '''
    Build graph data from a plain dict without mutating the dict.

    Returns a deep copy of ``d`` in which ``'__CURRENT_WORKING_DIR__'`` is set to the
    current working directory. ``'__WORKING_DIR__'`` defaults to that same directory
    unless ``d`` already provides one.
    '''
    data = deepcopy(d)
    data['__CURRENT_WORKING_DIR__'] = os.getcwd()
    data.setdefault('__WORKING_DIR__', data['__CURRENT_WORKING_DIR__'])
    return data

class GraphGoodCheck(unittest.TestCase):
    '''
    End-to-end tests for comsdk graphs.

    Covers:
    - serial, parallel and cycled graphs, and their composition via dummy edges and subgraphs;
    - implicit parallelization over array keys;
    - graphs driving the external ``square`` program locally, over SSH and through
      Grid Engine (qsub/qstat).

    NOTE: ``setUpClass`` compiles ``tests/square/square.cpp`` and opens an SSH connection
    to the remote host. Even the purely local tests therefore depend on that setup succeeding.
    '''
    initial_conditions = range(-10, 10)

    @classmethod
    def setUpClass(cls):
        # Build the local ``square`` executable used by the external-program tests.
        command_line = 'cd tests/square; g++ square.cpp -o square'
        subprocess.call(command_line, shell=True)

        local_host = Host()
        local_host.add_program('square', os.path.join(os.getcwd(), 'tests', 'square'))
        cls.local_comm = LocalCommunication(local_host)
        cls.ssh_host = 'arc3.leeds.ac.uk'
        cls.ssh_cores = 24
        cls.ssh_user = 'mmap'
        # NOTE(review): a plaintext password is committed to the repository. It can be
        # overridden via COMSDK_TEST_SSH_PASSWORD. The literal fallback should be rotated
        # and removed once the environment provides the variable.
        cls.ssh_pswd = os.environ.get('COMSDK_TEST_SSH_PASSWORD', '1bdwbzsc')
        cls.ssh_path_to_tests = '/home/home01/mmap/tests'
        # Reuse the attributes above rather than repeating the literals.
        remote_host = RemoteHost(ssh_host=cls.ssh_host,
                                 cores=cls.ssh_cores,
                                 )
        remote_host.add_program('square', '{}/square'.format(cls.ssh_path_to_tests))
        remote_host.add_program('qsub')
        remote_host.add_program('qstat')
        cls.ssh_comm = SshCommunication(remote_host,
                                        username=cls.ssh_user,
                                        password=cls.ssh_pswd,
                                        )
        cls.ssh_comm.mkdir('{}/square_test_dir'.format(cls.ssh_path_to_tests))

    @classmethod
    def tearDownClass(cls):
        # NOTE(review): 'aux' is not imported in this file; presumably it is re-exported
        # by one of the comsdk star imports -- confirm.
        aux.remove_if_exists('tests/square_test_dir/input/a.dat')
        aux.remove_if_exists('tests/square_test_dir/output/b.dat')
        cls.ssh_comm.rm('{}/square_test_dir'.format(cls.ssh_path_to_tests))

    def test_trivial_serial_graph(self):
        initial_datas = [{'a': ic} for ic in self.initial_conditions]
        invalid_initial_datas = [{'a': ic} for ic in (-1, 0)]
        initial_state, term_state, correct_outputs = self._get_trivial_serial_graph(initial_datas)
        self._run_graph(initial_state, term_state, initial_datas, invalid_initial_datas, correct_outputs)

    def test_trivial_parallel_graph(self):
        initial_datas = [{'a': ic, 'b': ic} for ic in self.initial_conditions]
        invalid_initial_datas = [{'a': ic, 'b': ic} for ic in (-2, -1, 0)]
        initial_state, term_state, correct_outputs = self._get_trivial_parallel_graph(initial_datas)
        self._run_graph(initial_state, term_state, initial_datas, invalid_initial_datas, correct_outputs)

    def test_trivial_cycled_graph(self):
        initial_datas = [{'a': ic} for ic in self.initial_conditions]
        initial_state, term_state, correct_outputs = self._get_trivial_cycled_graph(initial_datas)
        self._run_graph(initial_state, term_state, initial_datas, (), correct_outputs)

    def test_complex_graph_made_from_trivial_ones_using_dummy_edges(self):
        '''
        serial graph + parallel graph + cycled graph, chained with dummy edges
        '''
        initial_datas = [{'a': ic, 'b': ic} for ic in self.initial_conditions]
        invalid_initial_datas = [{'a': ic, 'b': ic} for ic in (-4, -3, -2, -1, 0)]
        s_1, s_2, correct_outputs = self._get_trivial_serial_graph(initial_datas)
        s_3, s_4, correct_outputs = self._get_trivial_parallel_graph(correct_outputs)
        s_5, s_6, correct_outputs = self._get_trivial_cycled_graph(correct_outputs)
        s_2.connect_to(s_3, edge=Edge(dummy_predicate, dummy_edge))
        s_4.connect_to(s_5, edge=Edge(dummy_predicate, dummy_edge))
        self._run_graph(s_1, s_6, initial_datas, invalid_initial_datas, correct_outputs)

    def test_trivial_serial_graph_with_subgraph(self):
        initial_datas = [{'a': ic} for ic in self.initial_conditions]
        initial_state, term_state, correct_outputs = self._get_trivial_serial_graph_with_subgraph(initial_datas)
        self._run_graph(initial_state, term_state, initial_datas, (), correct_outputs)

    def test_trivial_parallel_graph_with_subgraph(self):
        initial_datas = [{'a': ic, 'b': ic} for ic in self.initial_conditions]
        initial_state, term_state, correct_outputs = self._get_trivial_parallel_graph_with_subgraph(initial_datas)
        self._run_graph(initial_state, term_state, initial_datas, (), correct_outputs)

    def test_complex_graph_made_from_trivial_ones_using_subgraphs(self):
        '''
        serial graph + parallel graph + cycled graph, nested via subgraphs
        '''
        initial_datas = [{'a': ic, 'b': ic} for ic in self.initial_conditions]
        invalid_initial_datas = [{'a': ic, 'b': ic} for ic in (-4, -3, -2, -1, 0)]
        s_1, s_2, correct_outputs = self._get_trivial_serial_graph(initial_datas)
        s_3, s_4, correct_outputs = self._get_trivial_parallel_graph(correct_outputs)
        s_5, s_6, correct_outputs = self._get_trivial_cycled_graph(correct_outputs)
        s_2.replace_with_graph(Graph(s_3, s_4))
        s_4.replace_with_graph(Graph(s_5, s_6))
        self._run_graph(s_1, s_6, initial_datas, invalid_initial_datas, correct_outputs)

    def test_trivial_graph_with_implicit_parallelization(self):
        '''
        s_1 -> s_2 -> s_3, with dummy edges
        s_2 = s_11 -> s_12 -> s_13, with +1 edges for a
        three implicitly parallel branches appear instead of s_2
        '''
        initial_datas = [{'a': [ic**i for i in range(1, 4)]} for ic in self.initial_conditions]
        initial_state, term_state, correct_outputs = self._get_trivial_graph_with_implicit_parallelization(initial_datas)
        self._run_graph(initial_state, term_state, initial_datas, (), correct_outputs)

    def test_cycled_graph_with_implicit_parallelization(self):
        random_neg_ics = [[random.randrange(-20, -3) for _ in range(3)] for _ in range(10)]
        initial_datas = [{'a': random_neg_ic} for random_neg_ic in random_neg_ics]
        initial_state, term_state, correct_outputs = self._get_cycled_graph_with_implicit_parallelization(initial_datas)
        self._run_graph(initial_state, term_state, initial_datas, (), correct_outputs)

    def test_trivial_graph_with_external_local_program(self):
        initial_datas = [{'a': ic, '__WORKING_DIR__': os.path.join(os.getcwd(), 'tests', 'square_test_dir', 'output')} for ic in self.initial_conditions]
        initial_state, term_state, correct_outputs = self._get_trivial_graph_with_external_local_program(initial_datas)
        self._run_graph(initial_state, term_state, initial_datas, (), correct_outputs)

    def test_trivial_graph_with_external_remote_program(self):
        initial_datas = [{'a': ic,
                          '__WORKING_DIR__': os.path.join(os.getcwd(), 'tests', 'square_test_dir', 'output'),
                          '__REMOTE_WORKING_DIR__': '{}/square_test_dir'.format(self.ssh_path_to_tests)}
                         for ic in self.initial_conditions]
        initial_state, term_state, correct_outputs = self._get_trivial_graph_with_external_remote_program(initial_datas)
        self._run_graph(initial_state, term_state, initial_datas, (), correct_outputs)

    def test_trivial_graph_with_external_remote_program_using_grid_engine(self):
        initial_datas = [{'a': ic,
                          'user': self.ssh_user,
                          'cores_required': 12,
                          'time_required': '12:00:00',
                          'qsub_script_name': 'square.sh',
                          '__WORKING_DIR__': os.path.join(os.getcwd(), 'tests', 'square_test_dir', 'output'),
                          '__REMOTE_WORKING_DIR__': '{}/square_test_dir'.format(self.ssh_path_to_tests)}
                         for ic in self.initial_conditions[0:2]]
        initial_state, term_state, correct_outputs = self._get_trivial_graph_with_external_remote_program_using_grid_engine(initial_datas)
        self._run_graph(initial_state, term_state, initial_datas, (), correct_outputs)

    def _get_trivial_serial_graph(self, initial_conditions):
        '''
        s_1 -> s_2 -> s_3,
        p_12 = p_23 := a not 0
        f_12 = f_23 := a + 1
        '''
        s_1 = State('serial_s_1')
        s_2 = State('serial_s_2')
        s_3 = State('serial_s_3')
        s_1.connect_to(s_2, edge=Edge(nonzero_predicate, increment_a_edge))
        s_2.connect_to(s_3, edge=Edge(nonzero_predicate, increment_a_edge))
        correct_outputs = []
        for ic in initial_conditions:
            output = _create_data_from_dict(ic)
            output['a'] += 2
            correct_outputs.append(output)
        return s_1, s_3, correct_outputs

    def _get_trivial_parallel_graph(self, initial_conditions):
        '''
        s_1 -> s_2 -> s_3 --------------> s_6
            -> s_4 -> s_4_1 -> s_5 ----->
                   -> s_4_2 ------------>
        all predicates := a not 0
        edges along s_1 -> s_2 -> s_3 -> s_6 := a + 1
        edges along the s_4 branches := b + 1
        expected result: a + 3, b + 3
        '''

        s_1 = State('nonparallel_s_1')
        s_2 = State('parallel_s_2')
        s_3 = State('parallel_s_3')
        s_4 = State('parallel_s_4')
        s_4_1 = State('parallel_s_4_1')
        s_4_2 = State('parallel_s_4_2')
        s_5 = State('parallel_s_5')
        s_6 = State('nonparallel_s_6')
        s_1.connect_to(s_2, edge=Edge(nonzero_predicate, increment_a_edge))
        s_2.connect_to(s_3, edge=Edge(nonzero_predicate, increment_a_edge))
        s_3.connect_to(s_6, edge=Edge(nonzero_predicate, increment_a_edge))
        s_1.connect_to(s_4, edge=Edge(nonzero_predicate, increment_b_edge))
        s_4.connect_to(s_4_1, edge=Edge(nonzero_predicate, increment_b_edge))
        s_4.connect_to(s_4_2, edge=Edge(nonzero_predicate, increment_b_edge))
        s_4_1.connect_to(s_5, edge=Edge(nonzero_predicate, increment_b_edge))
        s_4_2.connect_to(s_6, edge=Edge(nonzero_predicate, increment_b_edge))
        s_5.connect_to(s_6, edge=Edge(nonzero_predicate, increment_b_edge))
        correct_outputs = []
        for ic in initial_conditions:
            output = _create_data_from_dict(ic)
            output['a'] += 3
            output['b'] += 3
            correct_outputs.append(output)
        return s_1, s_6, correct_outputs

    def _get_trivial_cycled_graph(self, initial_conditions):
        '''
        s_1 -> s_2 -> s_3
            <-
        p_12 := True (dummy)
        p_23 := a > 0
        p_21 := a <= 0
        f_12 = f_23 = f_21 := a + 1
        '''

        s_1 = State('cycled_s_1')
        s_2 = State('cycled_s_2')
        s_3 = State('cycled_s_3')
        s_1.connect_to(s_2, edge=Edge(dummy_predicate, increment_a_edge))
        s_2.connect_to(s_3, edge=Edge(positiveness_predicate, increment_a_edge))
        s_2.connect_to(s_1, edge=Edge(nonpositiveness_predicate, increment_a_edge))
        correct_outputs = []
        for ic in initial_conditions:
            output = _create_data_from_dict(ic)
            if output['a'] >= 0:
                output['a'] += 2
            else:
                output['a'] = output['a'] % 2 + 2
            correct_outputs.append(output)
        return s_1, s_3, correct_outputs

    def _get_trivial_graph_with_external_local_program(self, initial_conditions):
        '''
        s_1 -> s_2 -> s_3 -> s_4,
        p_12 = p_34 = dummy
        f_12 = write(a) into a.dat
        f_23 = run the external ``square`` program on a_file (expected to write a**2 into b.dat)
        f_34 = read b.dat from the working dir into b
        '''
        square_edge = ExecutableProgramEdge('square', self.local_comm,
                                            output_dict={'b_file': 'b.dat'},
                                            trailing_args_keys=('a_file',),
                                            )
        s_1 = State('external_s_1')
        s_2 = State('external_s_2')
        s_3 = State('external_s_3')
        s_4 = State('external_s_4')
        s_1.connect_to(s_2, edge=Edge(dummy_predicate, write_a_edge))
        s_2.connect_to(s_3, edge=square_edge)
        s_3.connect_to(s_4, edge=Edge(dummy_predicate, load_b_edge))
        correct_outputs = []
        for ic in initial_conditions:
            output = _create_data_from_dict(ic)
            output['a_file'] = os.path.join(os.getcwd(), 'tests/square_test_dir/input/a.dat')
            output['b'] = output['a']**2
            output['b_file'] = 'b.dat'
            correct_outputs.append(output)
        return s_1, s_4, correct_outputs

    def _get_trivial_graph_with_external_remote_program(self, initial_conditions):
        '''
        s_1 -> s_2 -> s_3 -> s_4 -> s_5 -> s_6,
        all predicates are dummy
        f_12 = write(a) into a.dat
        f_23 = upload a.dat into the working dir on remote
        f_34 = run ``square`` on remote (a**2)
        f_45 = download b.dat from the working dir on remote to the working dir on local
        f_56 = read the downloaded b.dat from the working dir on local into b
        '''
        upload_edge = UploadOnRemoteEdge(self.ssh_comm,
                                         local_paths_keys=('a_file',),
                                         )
        square_edge = ExecutableProgramEdge('square', self.ssh_comm,
                                            output_dict={'b_file': 'b.dat'},
                                            trailing_args_keys=('a_file',),
                                            remote=True,
                                            )
        download_edge = DownloadFromRemoteEdge(self.ssh_comm,
                                               remote_paths_keys=('b_file',),
                                               )
        s_1 = State('remote_s_1')
        s_2 = State('remote_s_2')
        s_3 = State('remote_s_3')
        s_4 = State('remote_s_4')
        s_5 = State('remote_s_5')
        s_6 = State('remote_s_6')
        s_1.connect_to(s_2, edge=Edge(dummy_predicate, write_a_edge))
        s_2.connect_to(s_3, edge=upload_edge)
        s_3.connect_to(s_4, edge=square_edge)
        s_4.connect_to(s_5, edge=download_edge)
        s_5.connect_to(s_6, edge=Edge(dummy_predicate, load_b_edge))
        correct_outputs = []
        for ic in initial_conditions:
            output = _create_data_from_dict(ic)
            output['a_file'] = os.path.join(ic['__REMOTE_WORKING_DIR__'], 'a.dat')
            output['b'] = output['a']**2
            output['b_file'] = os.path.join(ic['__WORKING_DIR__'], 'b.dat')
            correct_outputs.append(output)
        return s_1, s_6, correct_outputs

    def _get_trivial_graph_with_external_remote_program_using_grid_engine(self, initial_conditions):
        '''
        s_1 -> s_2 -> s_3 -> s_4 -> s_5 -> s_6 -> s_7 -> s_8 -> s_9,
                                           <->
        all predicates, except p_66 and p_67, are dummy
        p_66 = job unfinished
        p_67 = job finished
        f_12 = write(a) into a.dat
        f_23 = upload a.dat into the working dir on remote
        f_34 = make up qsub script launching square
        f_45 = upload qsub script
        f_56 = send job (a**2) via qsub
        f_66 = check job finished via qstat
        f_67 = download b.dat from the working dir on remote to the working dir on local
        f_78 = read the downloaded b.dat from the working dir on local into b
        f_89 = filter out a_file, b_file, job_ID, qsub_script
        '''
        make_up_qsub_script_edge = QsubScriptEdge('square', self.local_comm, self.ssh_comm,
                                                  trailing_args_keys=('a_file',),
                                                  )
        upload_a_edge = UploadOnRemoteEdge(self.ssh_comm,
                                           local_paths_keys=('a_file',),
                                           )
        upload_qsub_script_edge = UploadOnRemoteEdge(self.ssh_comm,
                                                     local_paths_keys=('qsub_script',),
                                                     )
        qsub_edge = ExecutableProgramEdge('qsub', self.ssh_comm,
                                          trailing_args_keys=('qsub_script',),
                                          output_dict={'job_finished': False, 'b_file': 'b.dat'},
                                          stdout_processor=set_job_id,
                                          remote=True,
                                          )
        qstat_edge = ExecutableProgramEdge('qstat', self.ssh_comm,
                                           predicate=job_unfinished_predicate,
                                           io_mapping=InOutMapping(keys_mapping={'u': 'user', 'job_ID': 'job_ID'}),
                                           keyword_names=('u',),
                                           remote=True,
                                           stdout_processor=check_task_finished,
                                           )
        download_edge = DownloadFromRemoteEdge(self.ssh_comm,
                                               predicate=job_finished_predicate,
                                               remote_paths_keys=('b_file',),
                                               )
        s_1 = State('remote_s_1')
        s_2 = State('remote_s_2')
        s_3 = State('remote_s_3')
        s_4 = State('remote_s_4')
        s_5 = State('remote_s_5')
        s_6 = State('remote_s_6')
        s_7 = State('remote_s_7')
        s_8 = State('remote_s_8')
        s_9 = State('remote_s_9')
        s_1.connect_to(s_2, edge=Edge(dummy_predicate, write_a_edge))
        s_2.connect_to(s_3, edge=upload_a_edge)
        s_3.connect_to(s_4, edge=make_up_qsub_script_edge)
        s_4.connect_to(s_5, edge=upload_qsub_script_edge)
        s_5.connect_to(s_6, edge=qsub_edge)
        s_6.connect_to(s_6, edge=qstat_edge)
        s_6.connect_to(s_7, edge=download_edge)
        s_7.connect_to(s_8, edge=Edge(dummy_predicate, load_b_edge))

        def filter_data(data):
            # Drop the bookkeeping keys so the result can be compared with correct_outputs.
            del data['a_file']
            del data['b_file']
            del data['job_ID']
            del data['qsub_script']

        s_8.connect_to(s_9, edge=Edge(dummy_predicate, filter_data))
        correct_outputs = []
        for ic in initial_conditions:
            output = _create_data_from_dict(ic)
            output['b'] = output['a']**2
            output['job_finished'] = True
            correct_outputs.append(output)
        return s_1, s_9, correct_outputs

    def _get_trivial_serial_graph_with_subgraph(self, initial_conditions):
        '''
        s_1 -> s_2 -> s_3,
        where s_2 is replaced by sub_s_1 -> sub_s_2
        all predicates are dummy
        all morphisms := a + 1
        expected result: a + 3
        '''

        pred = Func(func=dummy_predicate)
        morph = Func(func=increment_a_edge)
        s_1 = State('s_1')
        s_2 = State('s_2')
        s_3 = State('s_3')
        s_1.connect_to(s_2, edge=Edge(pred, morph))
        s_2.connect_to(s_3, edge=Edge(pred, morph))
        sub_s_1 = State('sub_s_1')
        sub_s_2 = State('sub_s_2')
        sub_s_1.connect_to(sub_s_2, edge=Edge(pred, morph))
        s_2.replace_with_graph(Graph(sub_s_1, sub_s_2))
        correct_outputs = []
        for ic in initial_conditions:
            output = _create_data_from_dict(ic)
            output['a'] += 3
            correct_outputs.append(output)
        return s_1, s_3, correct_outputs

    def _get_trivial_parallel_graph_with_subgraph(self, initial_conditions):
        '''
        s_1 -> s_2 -> s_4
            -> s_3 ->
        where s_2 and s_3 are each replaced by s_5 -> s_6 -> s_7
        s_1 uses AllSelectionPolicy
        all predicates are dummy
        f_12 = f_24 := a + 1
        f_13 = f_34 := b + 1
        f_56 = f_67 := a + 1 (in both subgraphs)
        expected result: a + 6, b + 2
        '''

        s_1 = State('s_1', selection_policy=AllSelectionPolicy())
        s_2 = State('s_2')
        s_3 = State('s_3')
        s_4 = State('s_4')
        s_1.connect_to(s_2, edge=Edge(dummy_predicate, increment_a_edge))
        s_1.connect_to(s_3, edge=Edge(dummy_predicate, increment_b_edge))
        s_2.connect_to(s_4, edge=Edge(dummy_predicate, increment_a_edge))
        s_3.connect_to(s_4, edge=Edge(dummy_predicate, increment_b_edge))
        sub1_s_5 = State('s_2_sub_s_5')
        sub1_s_6 = State('s_2_sub_s_6')
        sub1_s_7 = State('s_2_sub_s_7')
        sub2_s_5 = State('s_3_sub_s_5')
        sub2_s_6 = State('s_3_sub_s_6')
        sub2_s_7 = State('s_3_sub_s_7')
        sub1_s_5.connect_to(sub1_s_6, edge=Edge(dummy_predicate, increment_a_edge))
        sub1_s_6.connect_to(sub1_s_7, edge=Edge(dummy_predicate, increment_a_edge))
        sub2_s_5.connect_to(sub2_s_6, edge=Edge(dummy_predicate, increment_a_edge))
        sub2_s_6.connect_to(sub2_s_7, edge=Edge(dummy_predicate, increment_a_edge))
        s_2.replace_with_graph(Graph(sub1_s_5, sub1_s_7))
        s_3.replace_with_graph(Graph(sub2_s_5, sub2_s_7))
        correct_outputs = []
        for ic in initial_conditions:
            output = _create_data_from_dict(ic)
            output['a'] += 6
            output['b'] += 2
            correct_outputs.append(output)
        return s_1, s_4, correct_outputs

    def _get_trivial_graph_with_implicit_parallelization(self, initial_conditions):
        '''
        s_1 -> s_2 -> s_3, with dummy edges,
        where s_2 is replaced by sub_s_1 -> sub_s_2 -> sub_s_3 with a + 1 edges.
        sub_s_1 maps the array key 'a', so one implicitly parallel branch is expected
        per element of a. Expected result: every element of a is increased by 2.
        '''

        sub_s_1 = State('sub_s_1', array_keys_mapping={'a': 'a'})
        sub_s_2 = State('sub_s_2')
        sub_s_3 = State('sub_s_3')
        subgraph = Graph(sub_s_1, sub_s_3)
        s_1 = State('s_1')
        s_2 = State('s_2')
        s_3 = State('s_3')
        s_1.connect_to(s_2, edge=Edge(dummy_predicate, dummy_edge))
        s_2.connect_to(s_3, edge=Edge(dummy_predicate, dummy_edge))
        sub_s_1.connect_to(sub_s_2, edge=Edge(dummy_predicate, increment_a_edge))
        sub_s_2.connect_to(sub_s_3, edge=Edge(dummy_predicate, increment_a_edge))
        s_2.replace_with_graph(subgraph)
        correct_outputs = []
        for ic in initial_conditions:
            output = _create_data_from_dict(ic)
            output['a'] = [value + 2 for value in output['a']]
            correct_outputs.append(output)
        return s_1, s_3, correct_outputs

    def _get_cycled_graph_with_implicit_parallelization(self, initial_conditions):
        '''
        s_1 -> s_2, with f_12 := every element of a + 1,
        where s_1 is replaced by s_sub_1 -> s_sub_2 -> s_sub_3
                                            <->
        s_sub_1 maps the array key 'a', so each element of a runs through the subgraph.
        p_sub_12 := dummy
        p_sub_22 := a <= 0
        p_sub_23 := a > 0
        all subgraph morphisms := a + 1
        Expected result (all initial elements negative): every element of a equals 3.
        '''

        s_sub_1 = State('s_sub_1', array_keys_mapping={'a': 'a'})
        s_sub_2 = State('s_sub_2')
        s_sub_3 = State('s_sub_3')
        s_1 = State('s_1')
        s_2 = State('s_2')
        s_sub_1.connect_to(s_sub_2, edge=Edge(dummy_predicate, increment_a_edge))
        s_sub_2.connect_to(s_sub_2, edge=Edge(nonpositiveness_predicate, increment_a_edge))
        s_sub_2.connect_to(s_sub_3, edge=Edge(positiveness_predicate, increment_a_edge))
        subgraph = Graph(s_sub_1, s_sub_3)
        s_1.connect_to(s_2, edge=Edge(dummy_predicate, increment_a_array_edge))
        s_1.replace_with_graph(subgraph)
        correct_outputs = []
        for ic in initial_conditions:
            output = _create_data_from_dict(ic)
            output['a'] = [3 for _ in output['a']]
            correct_outputs.append(output)
        return s_1, s_2, correct_outputs

    def _run_graph(self, initial_state, term_state, initial_datas, invalid_initial_datas, correct_outputs):
        '''
        Run the graph ``initial_state -> term_state`` on each initial data and check the outcome.

        Entries of ``initial_datas`` that also appear in ``invalid_initial_datas`` must make
        ``graph.run`` return False and leave an ``'__EXCEPTION__'`` key in the data. All
        other entries must make it return True and produce the matching ``correct_outputs``
        entry.
        '''
        # zip() would silently drop unmatched cases; make a mismatch fail loudly instead.
        self.assertEqual(len(initial_datas), len(correct_outputs))
        graph = Graph(initial_state, term_state)
        for initial_data, correct_output in zip(initial_datas, correct_outputs):
            print('Doing ic = {}'.format(initial_data))
            # graph.run updates the data in place, so keep initial_data pristine.
            data = deepcopy(initial_data)
            okay = graph.run(data)
            if initial_data in invalid_initial_datas:
                self.assertIn('__EXCEPTION__', data)
                self.assertEqual(okay, False)
            else:
                self.assertEqual(okay, True)
                self.assertEqual(data, correct_output)

if __name__ == '__main__':
    unittest.main()