import collections
import os
from enum import Enum, auto
from functools import partial

import aux as aux

# Describes one branch of an implicit parallelization:
#   array_keys_mapping -- mapping handed to aux.ProxyDict to locate the array(s) being parallelized over
#   branches_number    -- total number of implicit branches (length of that array)
#   branch_i           -- index of the branch this info belongs to
ImplicitParallelizationInfo = collections.namedtuple(
    'ImplicitParallelizationInfo',
    'array_keys_mapping branches_number branch_i',
)

class Morphism:
    '''
    Pairs an edge with the state it leads to. Calling ``morph`` applies the edge to the data
    and returns the output state.
    '''
    def __init__(self, edge, output_state):
        self.edge = edge
        self.output_state = output_state

    def morph(self, data, dynamic_keys_mapping=None):
        '''
        Applies ``self.edge.morph`` to ``data`` and returns ``self.output_state``.

        :param data: data passed to the edge
        :param dynamic_keys_mapping: key mapping passed to the edge (see build_dynamic_keys_mapping);
            a fresh empty dict is used when it is omitted
        :return: the state this morphism leads to
        '''
        if dynamic_keys_mapping is None:
            # A fresh dict per call: a ``{}`` default would be shared by every call,
            # so any mutation by an edge would leak into later calls.
            dynamic_keys_mapping = {}
        self.edge.morph(data, dynamic_keys_mapping)
        return self.output_state

class IdleRunType(Enum):
    '''Kind of idle (non-computing) pass propagated through the graph by State.idle_run.'''
    INIT = 1     # first pass over the graph: counts incoming and looped edges of every state
    CLEANUP = 2  # pass before every later run: resets per-run activation counters

class GraphFactory:
    '''
    Factory of graphs. Work in progress: every method is still a placeholder returning None.
    '''
    def __init__(self):
        pass

    def create_state(self):
        '''Placeholder; returns None.'''
        pass

    def create_edge(self):
        '''Placeholder; returns None.'''
        # Here we should somehow pass the argument for "special" edges
        # Essentially, we change only io_mapping
        pass

    def make_graph(self):
        '''Placeholder; returns None.'''
        pass

class PluralGraphFactory:
    '''
    Graph factory parametrized by several keys mappings (presumably one per parallel copy of
    the graph -- the class is a work in progress; only ``create_state`` does anything yet).

    :param plural_keys_mappings: keys mappings, stored as is
    :param parallel_graphs_number: number of parallel graphs, stored as is
    '''
    def __init__(self, plural_keys_mappings, parallel_graphs_number):
        self.plural_keys_mappings = plural_keys_mappings
        self.parallel_graphs_number = parallel_graphs_number
        self.init_state = None

    def create_state(self, state):
        '''Registers ``state``; the first state ever registered becomes ``init_state``.'''
        if self.init_state is None:
            self.init_state = state

    def create_edge(self):
        '''Placeholder; returns None.'''
        # Here we should somehow pass the argument for "special" edges
        # Essentially, we change only io_mapping
        pass

    def make_graph(self):
        '''Placeholder; returns None.'''
        pass

class PluralState:
    '''
    A group of states that is connected element-wise to another group of states.
    '''
    def __init__(self, states):
        self.states = states

    def connect_to(self, term_states, edge):
        '''
        Connects the i-th state of this group to the i-th state of ``term_states`` through ``edge``
        (the same edge object for every pair). Pairing follows ``zip``, so surplus states of the
        longer sequence are left unconnected.
        '''
        pairs = zip(self.states, term_states)
        for source, target in pairs:
            new_morphism = Morphism(edge, target)
            source.output_morphisms.append(new_morphism)

class Graph:
    '''
    Class describing a graph-based computational method. Graph execution must start from this object.
    '''
    def __init__(self, init_state,
                 term_state=None,
                 ):
        self.init_state = init_state
        self.term_state = term_state
        if self.term_state is not None:
            # The term state is where implicitly parallelized branches are joined
            # (State._ready_to_morph waits for all of the branches there).
            self.term_state.is_term_state = True
        self._initialized = False

    def run(self, data):
        '''
        Goes through the graph and returns boolean denoting whether the graph has finished successfully.
        It runs twice -- the first run is idle (needed for initialization) and the second run is real.
        The input data will be augmented by metadata:
        1) '__CURRENT_WORKING_DIR__' -- absolute path to the current working directory as defined by the OS
        2) '__WORKING_DIR__' -- absolute path to the directory from which external binaries or resources will be launched.
        It will be set only if it is not yet set in data
        3) '__EXCEPTION__' -- error message; it is set e.g. when a state raises GraphUnexpectedTermination.
        Other exceptions are not caught and propagate to the caller.
        '''
        self.init_graph(data)
        cur_state = self.init_state
        implicit_parallelization_info = None
        while cur_state is not None:
            morph, implicit_parallelization_info = _run_state(cur_state, data, implicit_parallelization_info)
            if '__EXCEPTION__' in data:
                return False
            # the morph callable applies the selected edges and returns the next state (None at the end)
            cur_state = morph(data)
            if '__EXCEPTION__' in data:
                return False
        return True

    def init_graph(self, data):
        '''
        Performs the idle run (INIT on the first call, CLEANUP on later calls) and stores the
        working-directory metadata in ``data``.
        '''
        if not self._initialized:
            self.init_state.idle_run(IdleRunType.INIT, [self.init_state.name])
            self._initialized = True
        else:
            self.init_state.idle_run(IdleRunType.CLEANUP, [self.init_state.name])
        data['__CURRENT_WORKING_DIR__'] = os.getcwd()
        if '__WORKING_DIR__' not in data:
            data['__WORKING_DIR__'] = data['__CURRENT_WORKING_DIR__']

class State:
    '''
    Node of a computational graph.

    A state may have several incoming edges; it proceeds only when all of them, except the edges
    closing a loop back to it, have been activated. When it proceeds, it evaluates the predicates
    of its output edges, lets ``selection_policy`` choose which of them to follow and asks
    ``parallelization_policy`` to build the morph callable running the selected ones.
    After ``replace_with_graph``, ``run`` and ``idle_run`` are forwarded to the init state of
    the replacing graph.
    '''
    __slots__ = [
        'name',
        'input_edges_number',
        'looped_edges_number',
        'activated_input_edges_number',
        'output_morphisms',
        'parallelization_policy',
        'selection_policy',
        'is_term_state',
        'array_keys_mapping',
        '_branching_states_history',
        '_proxy_state',
        ]
    def __init__(self, name,
                 parallelization_policy=None,
                 selection_policy=None,
                 array_keys_mapping=None, # if array_keys_mapping is not None, we have implicit parallelization in this state
                 ):
        self.name = name
        self.parallelization_policy = SerialParallelizationPolicy() if parallelization_policy is None else parallelization_policy
        self.selection_policy = OnlyOneSelectionPolicy() if selection_policy is None else selection_policy
        self.array_keys_mapping = array_keys_mapping
        # number of incoming edges, counted during the INIT idle run
        self.input_edges_number = 0
        # number of incoming edges closing a loop back to this state, counted during the INIT idle run
        self.looped_edges_number = 0
        # number of incoming edges activated in the current run; under implicit parallelization
        # it becomes a list with one counter per branch (see _activate_input_edge)
        self.activated_input_edges_number = 0
        self.output_morphisms = []
        self.is_term_state = False
        # list of state names identifying the branch along which this state was reached in an idle run:
        # the graph's init state name plus the first state of every branch taken (see idle_run)
        self._branching_states_history = None
        # init state of the graph replacing this state (see replace_with_graph), or None
        self._proxy_state = None

    def idle_run(self, idle_run_type, branching_states_history):
        '''
        Recursively propagates an idle (non-computing) pass from this state to its successors.

        INIT: increments ``input_edges_number`` on every visit; a repeated visit whose branch history
        contains the history recorded at the first visit is counted as a looped edge. Propagation
        stops at already visited states.
        CLEANUP: resets the activation counter and records the branch history; a visit arriving along
        a looped branch clears the recorded history and stops propagation.

        :param idle_run_type: IdleRunType member
        :param branching_states_history: list of state names describing the current branch; it is
            extended with the successor's name whenever this state has several output morphisms
        '''
        if self._proxy_state is not None:
            return self._proxy_state.idle_run(idle_run_type, branching_states_history)
        if idle_run_type == IdleRunType.INIT:
            self.input_edges_number += 1
            if self.input_edges_number != 1:
                if self._is_looped_branch(branching_states_history):
                    self.looped_edges_number += 1
                return # no need to go further if we already were there
            if self._branching_states_history is None:
                self._branching_states_history = branching_states_history
        elif idle_run_type == IdleRunType.CLEANUP:
            self.activated_input_edges_number = 0
            if self._branching_states_history is not None and self._is_looped_branch(branching_states_history):
                # we came back through a loop: stop here to avoid infinite recursion
                self._branching_states_history = None
                return
            if self._branching_states_history is None:
                self._branching_states_history = branching_states_history
        else:
            # NOTE(review): not reached for IdleRunType.INIT/CLEANUP
            self.activated_input_edges_number += 1 # BUG: here we need to choose somehow whether we proceed or not
        if len(self.output_morphisms) == 1:
            self.output_morphisms[0].output_state.idle_run(idle_run_type, branching_states_history)
        else:
            # several successors: each of them starts a new branch, recorded in the history
            for morphism in self.output_morphisms:
                next_state = morphism.output_state
                next_state.idle_run(idle_run_type, branching_states_history + [next_state.name])

    def connect_to(self, term_state, edge):
        '''Adds an output edge leading from this state to ``term_state``.'''
        self.output_morphisms.append(Morphism(edge, term_state))

    def replace_with_graph(self, graph):
        '''
        Replaces this state by ``graph``: ``run`` and ``idle_run`` are forwarded to the graph's init
        state, and the graph's term state takes over this state's output morphisms (same list object).
        '''
        self._proxy_state = graph.init_state
        graph.term_state.output_morphisms = self.output_morphisms

    def run(self, data, implicit_parallelization_info=None):
        '''
        Activates one incoming edge of this state. If all the required incoming edges are active,
        it selects the output edges to follow.

        :param data: data the graph operates on (passed to the edge predicates)
        :param implicit_parallelization_info: ImplicitParallelizationInfo of the current branch, or None
        :return: tuple (morph, implicit_parallelization_info). It is (None, None) while this state still
                 waits for other incoming edges, and (morphism_to_termination, None) if it has no
                 output edges. Otherwise morph is the callable built by the parallelization policy;
                 the returned info is None for a term state.
        :raises GraphUnexpectedTermination: if the predicate values do not conform to the selection policy
        '''
        if self._proxy_state is not None:
            return self._proxy_state.run(data, implicit_parallelization_info)
        self._activate_input_edge(implicit_parallelization_info)
        if not self._ready_to_morph(implicit_parallelization_info):
            return None, None # it means that this state waits for some incoming edges (it is a point of collision of several edges)
        self._reset_activity(implicit_parallelization_info)
        if self.is_term_state:
            # all implicit branches have joined here, so leave implicit parallelization
            implicit_parallelization_info = None
        if len(self.output_morphisms) == 0:
            return morphism_to_termination, None
        dynamic_keys_mapping = build_dynamic_keys_mapping(implicit_parallelization_info)
        predicate_values = [morphism.edge.predicate(data, dynamic_keys_mapping) for morphism in self.output_morphisms]
        selected_edge_indices = self.selection_policy.select(predicate_values)
        if not selected_edge_indices:
            raise GraphUnexpectedTermination(
                'State {}: Predicate values {} do not conform selection policy'.format(self.name, predicate_values))
        selected_morphisms = [self.output_morphisms[i] for i in selected_edge_indices]
        return self.parallelization_policy.make_morphism(selected_morphisms,
                                                         array_keys_mapping=self.array_keys_mapping,
                                                         implicit_parallelization_info=implicit_parallelization_info,), \
               implicit_parallelization_info

    def _activate_input_edge(self, implicit_parallelization_info=None):
        '''
        Increments the activation counter. Inside an implicit branch (and not in a term state),
        the counter is a per-branch list, created on first use.
        '''
        if implicit_parallelization_info is None or self.is_term_state:
            self.activated_input_edges_number += 1
        else:
            if isinstance(self.activated_input_edges_number, int):
                self.activated_input_edges_number = [0 for i in range(implicit_parallelization_info.branches_number)]
            self.activated_input_edges_number[implicit_parallelization_info.branch_i] += 1

    def _ready_to_morph(self, implicit_parallelization_info=None):
        '''
        Tells whether enough incoming edges are active. Normally input_edges_number minus
        looped_edges_number activations are required. A term state reached from implicit branches
        instead requires one activation per branch.
        '''
        required_activated_input_edges_number = self.input_edges_number - self.looped_edges_number
        if implicit_parallelization_info is not None:
            if self.is_term_state:
                required_activated_input_edges_number = implicit_parallelization_info.branches_number
                return self.activated_input_edges_number == required_activated_input_edges_number
            return self.activated_input_edges_number[implicit_parallelization_info.branch_i] == required_activated_input_edges_number
        else:
            return self.activated_input_edges_number == required_activated_input_edges_number

    def _reset_activity(self, implicit_parallelization_info=None):
        '''
        Resets the activation counter after the state has become ready. For a state with a loop,
        only one activation is removed, so that a single further activation (e.g. arriving through
        the loop) makes it ready again. Otherwise the counter is zeroed.
        '''
        self._branching_states_history = None
        if self._ready_to_morph(implicit_parallelization_info) and self._has_loop():
            if implicit_parallelization_info is None or self.is_term_state:
                self.activated_input_edges_number -= 1
            else:
                self.activated_input_edges_number[implicit_parallelization_info.branch_i] -= 1
        else:
            if implicit_parallelization_info is None or self.is_term_state:
                self.activated_input_edges_number = 0
            else:
                self.activated_input_edges_number[implicit_parallelization_info.branch_i] = 0

    def _is_looped_branch(self, branching_states_history):
        '''True if every name of the recorded branch history occurs in ``branching_states_history``.'''
        return set(self._branching_states_history).issubset(branching_states_history)

    def _has_loop(self):
        '''True if at least one incoming edge closes a loop back to this state.'''
        return self.looped_edges_number != 0

def morphism_to_termination(data):
    '''
    Morph callable used by a state without output edges.

    It ignores ``data`` and yields no next state (None), which ends the main loop of Graph.run.
    '''
    next_state = None
    return next_state

class SerialParallelizationPolicy:
    '''
    Parallelization policy which runs all the branches spawned by a state one after another,
    until the branches join in a single state.
    '''
    def __init__(self):
        pass

    def make_morphism(self, morphisms, array_keys_mapping=None, implicit_parallelization_info=None):
        '''
        Builds a morph callable ``_morph(data)`` which serially runs the given morphisms.

        If ``array_keys_mapping`` is None, every morphism starts one branch which keeps the caller's
        ``implicit_parallelization_info``. Otherwise, exactly one morphism is expected. It is run once
        per element of the array found in ``data`` via ``array_keys_mapping`` (implicit
        parallelization), and every such branch gets its own ImplicitParallelizationInfo.

        The branches are advanced state by state (via _run_state) until only one pending morph
        remains. For implicit parallelization, the branches must also have joined. That last morph
        is then applied and the state it leads to is returned.

        :param morphisms: list of Morphism objects selected by the state's selection policy
        :param array_keys_mapping: mapping locating the array to parallelize over, or None
        :param implicit_parallelization_info: ImplicitParallelizationInfo of the enclosing branch, or None
        :return: callable taking ``data`` and returning the next state. It returns None if a branch
                 reaches termination or ``'__EXCEPTION__'`` appears in ``data``.
        :raises BadGraphStructure: (when the callable is invoked) if implicit parallelization is
                requested for several morphisms, or if none of the branches can proceed
        '''
        def _spawn_branches(data):
            '''Returns the initial lists of branch morphs and of their implicit parallelization infos.'''
            if array_keys_mapping is None:
                dynamic_keys_mapping = build_dynamic_keys_mapping(implicit_parallelization_info)
                morphs = [partial(morphism.morph, dynamic_keys_mapping=dynamic_keys_mapping) for morphism in morphisms]
                impl_para_infos = [implicit_parallelization_info for _ in morphisms]
                return morphs, impl_para_infos
            if len(morphisms) != 1:
                raise BadGraphStructure('Impossible to create implicit parallelization in the state with {} output edges'.format(len(morphisms)))
            proxy_data = aux.ProxyDict(data, keys_mappings=array_keys_mapping)
            anykey = next(iter(array_keys_mapping.keys()))
            implicit_branches_number = len(proxy_data[anykey])
            morphs = []
            impl_para_infos = []
            for branch_i in range(implicit_branches_number):
                branch_info = ImplicitParallelizationInfo(array_keys_mapping, implicit_branches_number, branch_i)
                morphs.append(partial(morphisms[0].morph, dynamic_keys_mapping=build_dynamic_keys_mapping(branch_info)))
                impl_para_infos.append(branch_info)
            return morphs, impl_para_infos

        def _morph(data):
            next_morphs, next_impl_para_infos = _spawn_branches(data)
            while len(next_morphs) != 1 or _requires_joint_of_implicit_parallelization(array_keys_mapping, next_impl_para_infos):
                if not next_morphs:
                    # Every branch stopped at a state waiting for incoming edges which are never
                    # activated, or there were no branches at all (empty array): nothing can proceed.
                    raise BadGraphStructure(
                        'None of the parallel branches can proceed: they either wait for incoming edges '
                        'which are never activated or there are no branches at all '
                        '(empty array for implicit parallelization)')
                cur_morphs, cur_impl_para_infos = next_morphs, next_impl_para_infos
                next_morphs, next_impl_para_infos = [], []
                for morph, impl_para_info in zip(cur_morphs, cur_impl_para_infos):
                    next_state = morph(data)
                    if next_state is None:
                        return None
                    next_morph, next_impl_para_info = _run_state(next_state, data, impl_para_info)
                    if '__EXCEPTION__' in data:
                        return None
                    # a None morph means next_state still waits for other incoming edges
                    if next_morph is not None:
                        next_morphs.append(next_morph)
                        next_impl_para_infos.append(next_impl_para_info)
            return next_morphs[0](data)
        return _morph


class OnlyOneSelectionPolicy:
    '''Selection policy accepting exactly one true predicate.'''
    def __init__(self):
        pass

    def select(self, predicate_values):
        '''
        Returns a one-element list with the index of the single predicate equal to True,
        or None if zero or several predicates are True.
        '''
        selected = _get_trues(predicate_values)
        return selected if len(selected) == 1 else None

class AllSelectionPolicy:
    '''Selection policy requiring every predicate to be true.'''
    def __init__(self):
        pass

    def select(self, predicate_values):
        '''
        Returns the indices of all predicates if every one of them equals True, otherwise None.
        An empty list of predicates yields an empty list.
        '''
        selected = _get_trues(predicate_values)
        return selected if len(selected) == len(predicate_values) else None

class BadGraphStructure(Exception):
    '''
    Raised when the graph cannot be executed because of its structure, e.g. when implicit
    parallelization is requested in a state with several selected output edges.
    '''

class GraphUnexpectedTermination(Exception):
    '''
    Raised by a state whose execution cannot continue, e.g. when the predicate values of its
    output edges do not conform to its selection policy. _run_state stores its message in
    data['__EXCEPTION__'].
    '''

def _requires_joint_of_implicit_parallelization(array_keys_mapping, impl_para_infos):
    if array_keys_mapping is None:
        return False
    for obj in impl_para_infos:
        if obj is not None:
            return True
    return False

def _get_trues(boolean_list):
    return [i for i, val in enumerate(boolean_list) if val == True]

#def _run_state(state, data, implicit_parallelization_info=None):
#    try:
#        next_morphism = state.run(data, implicit_parallelization_info)
#    except GraphUnexpectedTermination as e:
#        data['__EXCEPTION__'] = str(e)
#        return None
#    return next_morphism

def _run_state(state, data, implicit_parallelization_info=None):
    try:
        next_morphism, next_impl_para_info = state.run(data, implicit_parallelization_info)
    except GraphUnexpectedTermination as e:
        data['__EXCEPTION__'] = str(e)
        return None, None
    return next_morphism, next_impl_para_info


def build_dynamic_keys_mapping(implicit_parallelization_info=None):
    '''
    Builds the key mapping which makes edges see the current implicit branch's element.

    Returns an empty dict outside implicit parallelization. Otherwise, every key of
    ``array_keys_mapping`` is mapped to ``aux.ArrayItemGetter(keys_path, branch_i)``.
    '''
    if implicit_parallelization_info is None:
        return {}
    return {
        key: aux.ArrayItemGetter(keys_path, implicit_parallelization_info.branch_i)
        for key, keys_path in implicit_parallelization_info.array_keys_mapping.items()
    }