# graph.py: graph-based computational methods (states, transfers, parallelization and selection policies).
import collections
import importlib as imp
import logging
import os
from enum import Enum, auto
from functools import partial

import comsdk.aux as aux

# Describes one branch of an implicit parallelization:
#   array_keys_mapping -- keys mapping of the state which launched the parallelization,
#   branches_number    -- total number of branches,
#   branch_i           -- index of the branch being executed.
ImplicitParallelizationInfo = collections.namedtuple('ImplicitParallelizationInfo', ['array_keys_mapping', 'branches_number', 'branch_i'])

class Func:
    '''
    Callable wrapper which remembers where the wrapped function comes from.

    The callable is stored in ``func`` and is chosen as follows:
    1) ``func`` if it is given;
    2) otherwise, if ``dummy`` is true, the identity function ``lambda data: data``;
    3) otherwise, attribute ``name`` of module ``module`` imported via importlib.
    '''
    __slots__ = (
        'module',
        'func',
        'name',
    )

    def __init__(self, module="", name="", dummy=True, func=None):
        self.module = module
        self.name = name
        if func is not None:
            self.func = func
        elif dummy:
            self.func = lambda data: data
        else:
            self.func = getattr(imp.import_module(module), name)

    def __str__(self):
        return "{}.{}()".format(self.module, self.name)

class Selector(Func):
    '''
    Func whose callable receives the list of predicate values of a state's transfers and
    returns the indices of the transfers to follow (State.run indexes its transfers with them).

    A dummy selector selects all ``ntransf`` transfers regardless of the predicate values.
    Otherwise the callable is imported as ``module.name``.
    '''
    __slots__ = ('ntransf',)

    def __init__(self, ntransf, module="", name="", dummy=True):
        self.ntransf = ntransf
        # Return indices, not a boolean mask: State.run evaluates
        # `self.transfers[i] for i in selected`, so a list of True would pick transfers[1]
        # ntransf times (and raise IndexError for a single transfer).
        super().__init__(module, name,
                         func=(lambda predicate_values: list(range(ntransf))) if dummy else None,
                         dummy=False)

class Transfer:
    '''
    Outgoing connection of a state: an edge together with the state it leads to.
    '''
    def __init__(self, edge, output_state, index=None):
        self.edge = edge
        self.output_state = output_state
        self.index = index

    def transfer(self, data, dynamic_keys_mapping=None):
        '''
        Applies the edge to ``data`` via ``edge.morph`` and returns the output state.

        The return value of ``edge.morph`` is discarded, so the edge is expected to modify
        ``data`` in place. ``dynamic_keys_mapping`` defaults to a fresh empty dict on every
        call (a shared ``{}`` default could leak modifications made by ``edge.morph``
        between calls).
        '''
        if dynamic_keys_mapping is None:
            dynamic_keys_mapping = {}
        self.edge.morph(data, dynamic_keys_mapping)
        return self.output_state

class IdleRunType(Enum):
    '''
    Kind of data-free traversal performed by State.idle_run.
    '''
    INIT = 1     # first traversal of a graph: count incoming and looped edges of each state
    CLEANUP = 2  # later traversals: reset activation counters before a new real run

class PluralState:
    '''
    Group of states which can be connected pairwise to another group of states by a common edge.
    '''
    def __init__(self, states):
        self.states = states

    def connect_to(self, term_states, edge):
        '''
        Connects the i-th state of this group to ``term_states[i]`` with ``edge``, for every i.

        Raises BadGraphStructure if the two groups have different sizes. Previously ``zip``
        silently dropped the extra states, which leaves the graph partially wired.
        '''
        init_states = list(self.states)
        term_states = list(term_states)
        if len(init_states) != len(term_states):
            raise BadGraphStructure('Cannot connect {} states to {} states pairwise'.format(
                len(init_states), len(term_states)))
        for init_state, term_state in zip(init_states, term_states):
            init_state.connect_to(term_state, edge)

class Graph:
    '''
    Class describing a graph-based computational method. Graph execution must start from this object.
    '''
    def __init__(self, init_state,
                 term_state=None,
                 ):
        self.init_state = init_state
        self.term_state = term_state
        if self.term_state is not None:
            self.term_state.is_term_state = True
        self._initialized = False

    def run(self, data):
        '''
        Goes through the graph and returns a boolean denoting whether the graph has finished successfully.
        Before the real run, an idle run is performed (INIT on the first call, CLEANUP afterwards).
        The input data will be augmented by metadata:
        1) '__CURRENT_WORKING_DIR__' -- absolute path to the current working directory as defined by the OS
        2) '__WORKING_DIR__' -- absolute path to the directory from which external binaries or resources will be launched.
        It will be set only if it is not yet set in data
        3) '__EXCEPTION__' if a state raises GraphUnexpectedTermination, or if a state in the
        main loop is still waiting for incoming edges and therefore can never proceed
        '''
        self.init_graph(data)
        cur_state = self.init_state
        implicit_parallelization_info = None
        while cur_state is not None:
            transfer_f, implicit_parallelization_info = _run_state(cur_state, data, implicit_parallelization_info)
            if '__EXCEPTION__' in data:
                return False
            if transfer_f is None:
                # The state waits for further incoming edges, but nothing else in this serial
                # main loop can activate them; calling None would raise TypeError.
                data['__EXCEPTION__'] = 'State {} is waiting for incoming edges which cannot be activated'.format(
                    cur_state.name)
                return False
            cur_state = transfer_f(data)
            if '__EXCEPTION__' in data:
                return False
        return True

    def init_graph(self, data=None):
        '''
        Performs the idle run of the graph and stores working-directory metadata into ``data``.

        The first call does an IdleRunType.INIT run and later calls do an IdleRunType.CLEANUP run.
        ``data`` defaults to a fresh dict; a shared ``{}`` default would accumulate metadata
        across calls.
        '''
        if data is None:
            data = {}
        if not self._initialized:
            self.init_state.idle_run(IdleRunType.INIT, [self.init_state.name])
            self._initialized = True
        else:
            self.init_state.idle_run(IdleRunType.CLEANUP, [self.init_state.name])
        data['__CURRENT_WORKING_DIR__'] = os.getcwd()
        if '__WORKING_DIR__' not in data:
            data['__WORKING_DIR__'] = data['__CURRENT_WORKING_DIR__']

class State:
    '''
    Node of the computational graph.

    During the idle INIT run, a state counts its incoming (and looped) edges. During the real run,
    it waits until enough incoming edges have been activated. It then evaluates the predicates of
    its outgoing transfers, lets its selector choose the transfers to follow, and delegates the
    actual transfer to its parallelization policy. If ``array_keys_mapping`` is not None, implicit
    parallelization is used for the outgoing transfer of this state.
    '''
    __slots__ = [
        'name',
        'input_edges_number',  # output_edges_number == len(transfers)
        'looped_edges_number',
        'activated_input_edges_number',
        'transfers',
        'parallelization_policy',
        'selector',
        'is_term_state',
        'array_keys_mapping',
        '_branching_states_history',
        '_proxy_state',
        ]

    def __init__(self, name,
                 parallelization_policy=None,
                 selector=None,
                 array_keys_mapping=None,  # if array_keys_mapping is not None, we have implicit parallelization in this state
                 ):
        self.name = name
        self.parallelization_policy = SerialParallelizationPolicy() if parallelization_policy is None else parallelization_policy
        # `lambda x: True if selector is None else selector` parsed as a lambda wrapping the whole
        # conditional, so the passed selector was never stored and `self.selector.func` in run()
        # raised AttributeError. The default selector picks every transfer whose predicate is true.
        self.selector = Func(func=_get_trues) if selector is None else selector
        self.array_keys_mapping = array_keys_mapping
        self.input_edges_number = 0
        self.looped_edges_number = 0
        self.activated_input_edges_number = 0
        self.transfers = []
        self.is_term_state = False
        self._branching_states_history = None
        self._proxy_state = None

    def idle_run(self, idle_run_type, branching_states_history):
        '''
        Data-free traversal of the graph starting from this state.

        IdleRunType.INIT counts incoming edges (and looped ones) and remembers the branching history
        of the first visit. IdleRunType.CLEANUP resets the activation counters before a new real run.
        ``branching_states_history`` is the list of names of states reached through branching on
        the way here (Graph starts it with the initial state's name). It is extended with the next
        state's name whenever this state has more than one transfer.
        '''
        if self._proxy_state is not None:
            return self._proxy_state.idle_run(idle_run_type, branching_states_history)
        if idle_run_type == IdleRunType.INIT:
            self.input_edges_number += 1
            if self.input_edges_number != 1:
                if self._is_looped_branch(branching_states_history):
                    self.looped_edges_number += 1
                return  # no need to go further if we already were there
            if self._branching_states_history is None:
                self._branching_states_history = branching_states_history
        elif idle_run_type == IdleRunType.CLEANUP:
            self.activated_input_edges_number = 0
            if self._branching_states_history is not None and self._is_looped_branch(branching_states_history):
                self._branching_states_history = None
                return
            if self._branching_states_history is None:
                self._branching_states_history = branching_states_history
        else:
            # Unreachable for the current IdleRunType members (INIT, CLEANUP).
            self.activated_input_edges_number += 1  # BUG: here we need to choose somehow whether we proceed or not
        if len(self.transfers) == 1:
            self.transfers[0].output_state.idle_run(idle_run_type, branching_states_history)
        else:
            for transfer in self.transfers:
                next_state = transfer.output_state
                next_state.idle_run(idle_run_type, branching_states_history + [next_state.name])

    def connect_to(self, term_state, edge):
        '''
        Adds a transfer from this state to ``term_state`` along ``edge``.
        '''
        self.transfers.append(Transfer(edge, term_state))

    def replace_with_graph(self, graph):
        '''
        Makes this state a proxy for ``graph``.

        Runs and idle runs are forwarded to the graph's initial state, and the graph's terminal
        state takes over this state's transfers (the same list object is shared).
        '''
        self._proxy_state = graph.init_state
        graph.term_state.transfers = self.transfers

    def run(self, data, implicit_parallelization_info=None):
        '''
        Activates one incoming edge of this state and, once the state is ready, selects the
        transfers to follow.

        Returns a pair (transfer function, implicit parallelization info). The transfer function
        takes ``data`` and returns the next state (or None). Both items are None while the state
        still waits for other incoming edges.

        Raises GraphUnexpectedTermination if the selector returns no transfer indices.
        '''
        logger = logging.getLogger(__name__)
        logger.debug('STATE %s, just entered, implicit_parallelization_info: %s',
                     self.name, implicit_parallelization_info)
        if self._proxy_state is not None:
            return self._proxy_state.run(data, implicit_parallelization_info)
        self._activate_input_edge(implicit_parallelization_info)
        logger.debug('STATE %s, required input: %s, active: %s, looped: %s',
                     self.name, self.input_edges_number, self.activated_input_edges_number,
                     self.looped_edges_number)
        if not self._ready_to_transfer(implicit_parallelization_info):
            return None, None  # this state waits for some incoming edges (it is a point of collision of several edges)
        self._reset_activity(implicit_parallelization_info)
        if self.is_term_state:
            # The terminal state joins implicitly parallelized branches.
            implicit_parallelization_info = None
        if len(self.transfers) == 0:
            return transfer_to_termination, None
        dynamic_keys_mapping = build_dynamic_keys_mapping(implicit_parallelization_info)
        predicate_values = [transfer.edge.predicate(data, dynamic_keys_mapping) for transfer in self.transfers]
        selected_edge_indices = self.selector.func(predicate_values)
        if not selected_edge_indices:
            raise GraphUnexpectedTermination(
                'State {}: Predicate values {} do not conform selection policy'.format(self.name, predicate_values))
        selected_transfers = [self.transfers[i] for i in selected_edge_indices]
        transfer_func = self.parallelization_policy.make_transfer_func(
            selected_transfers,
            array_keys_mapping=self.array_keys_mapping,
            implicit_parallelization_info=implicit_parallelization_info,
        )
        return transfer_func, implicit_parallelization_info

    def _activate_input_edge(self, implicit_parallelization_info=None):
        '''
        Increments the counter of activated incoming edges.

        Under implicit parallelization (outside the terminal state), the counter is turned into a
        per-branch list and only the current branch's entry is incremented.
        '''
        if implicit_parallelization_info is None or self.is_term_state:
            self.activated_input_edges_number += 1
        else:
            if isinstance(self.activated_input_edges_number, int):
                self.activated_input_edges_number = [0 for i in range(implicit_parallelization_info.branches_number)]
            self.activated_input_edges_number[implicit_parallelization_info.branch_i] += 1

    def _ready_to_transfer(self, implicit_parallelization_info=None):
        '''
        Tells whether enough incoming edges have been activated.

        A regular state (or branch) needs all of its non-looped incoming edges. Under implicit
        parallelization, the terminal state needs one arrival per branch.
        '''
        required_activated_input_edges_number = self.input_edges_number - self.looped_edges_number
        if implicit_parallelization_info is not None:
            if self.is_term_state:
                required_activated_input_edges_number = implicit_parallelization_info.branches_number
                return self.activated_input_edges_number == required_activated_input_edges_number
            return self.activated_input_edges_number[implicit_parallelization_info.branch_i] == required_activated_input_edges_number
        else:
            return self.activated_input_edges_number == required_activated_input_edges_number

    def _reset_activity(self, implicit_parallelization_info=None):
        '''
        Resets the activation counter after the state has fired and forgets the branching history.

        When called from run(), the state is already known to be ready, so the first condition
        reduces to _has_loop(). For a state with looped incoming edges, the counter is only
        decremented by one, presumably so that a single arrival through the loop makes it fire
        again -- TODO confirm.
        '''
        self._branching_states_history = None
        if self._ready_to_transfer(implicit_parallelization_info) and self._has_loop():
            if implicit_parallelization_info is None or self.is_term_state:
                self.activated_input_edges_number -= 1
            else:
                self.activated_input_edges_number[implicit_parallelization_info.branch_i] -= 1
        else:
            if implicit_parallelization_info is None or self.is_term_state:
                self.activated_input_edges_number = 0
            else:
                self.activated_input_edges_number[implicit_parallelization_info.branch_i] = 0

    def _is_looped_branch(self, branching_states_history):
        '''
        True if every name in the history recorded at the first visit also occurs in
        ``branching_states_history``. Used as the criterion that the state was reached again
        through a loop. Raises TypeError if no history has been recorded (None).
        '''
        return set(self._branching_states_history).issubset(branching_states_history)

    def _has_loop(self):
        '''
        True if at least one incoming edge of this state was detected as looped.
        '''
        return self.looped_edges_number != 0

def transfer_to_termination(data):
    '''
    Transfer function of a state without outgoing transfers.

    ``data`` is ignored and None is returned, which ends the main loop of Graph.run.
    '''
    return None

class SerialParallelizationPolicy:
    '''
    Parallelization policy which executes all branches one after another in the calling process.
    '''
    def __init__(self):
        pass

    def make_transfer_func(self, morphisms, array_keys_mapping=None, implicit_parallelization_info=None):
        '''
        Builds the transfer function of a state which follows ``morphisms`` (its selected Transfer objects).

        The returned function takes ``data`` and executes all branches sequentially. It transfers
        and runs each branch state by state until a single branch is left. With implicit
        parallelization, it also continues until no branch carries implicit parallelization info
        any more, i.e. until the branches have been joined. It then performs the last transfer and
        returns the resulting state. It returns None if a transfer leads nowhere or if an error was
        recorded in data['__EXCEPTION__'].

        If ``array_keys_mapping`` is given, the single transfer is replicated into one branch per
        element of the array found under the first key of the mapping (presumably all mapped
        arrays have the same length -- this is not checked here).

        Raises BadGraphStructure if implicit parallelization is requested with more than one
        transfer. Raises GraphUnexpectedTermination if no branch can proceed.
        '''
        def _morph(data):
            if array_keys_mapping is None:
                dynamic_keys_mapping = build_dynamic_keys_mapping(implicit_parallelization_info)
                # Transfer objects expose `transfer` (not `morph`); it returns the output state.
                next_morphs = [partial(morphism.transfer, dynamic_keys_mapping=dynamic_keys_mapping) for morphism in morphisms]
                next_impl_para_infos = [implicit_parallelization_info for _ in morphisms]
            else:
                if len(morphisms) != 1:
                    raise BadGraphStructure('Impossible to create implicit parallelization in the state with {} output edges'.format(len(morphisms)))
                proxy_data = aux.ProxyDict(data, keys_mappings=array_keys_mapping)
                anykey = next(iter(array_keys_mapping.keys()))
                implicit_branches_number = len(proxy_data[anykey])
                next_morphs = []
                next_impl_para_infos = []
                for branch_i in range(implicit_branches_number):
                    branch_info = ImplicitParallelizationInfo(array_keys_mapping, implicit_branches_number, branch_i)
                    next_morphs.append(partial(morphisms[0].transfer,
                                               dynamic_keys_mapping=build_dynamic_keys_mapping(branch_info)))
                    next_impl_para_infos.append(branch_info)
            while len(next_morphs) != 1 or _requires_joint_of_implicit_parallelization(array_keys_mapping, next_impl_para_infos):
                if not next_morphs:
                    # Every branch is waiting for incoming edges (or no branch was created at all).
                    raise GraphUnexpectedTermination(
                        'No branch can proceed: all branches are waiting for incoming edges or no branches were created')
                cur_morphs, cur_impl_para_infos = next_morphs, next_impl_para_infos
                next_morphs, next_impl_para_infos = [], []
                for morph, impl_para_info in zip(cur_morphs, cur_impl_para_infos):
                    next_state = morph(data)
                    if next_state is None:
                        return None
                    next_morph, next_impl_para_info = _run_state(next_state, data, impl_para_info)
                    if '__EXCEPTION__' in data:
                        return None
                    # A None transfer function means the state waits for other branches.
                    if next_morph is not None:
                        next_morphs.append(next_morph)
                        next_impl_para_infos.append(next_impl_para_info)
            return next_morphs[0](data)
        return _morph

class OnlyOneSelectionPolicy:
    '''
    Selection policy accepting exactly one true predicate.

    select() returns the list holding the index of that predicate, or None when zero or
    several predicates are true.
    '''
    def select(self, predicate_values):
        selected = _get_trues(predicate_values)
        return selected if len(selected) == 1 else None

class AllSelectionPolicy:
    '''
    Selection policy requiring every predicate to be true.

    select() returns the indices of all predicates, or None if at least one is not true.
    '''
    def select(self, predicate_values):
        selected = _get_trues(predicate_values)
        all_true = len(selected) == len(predicate_values)
        return selected if all_true else None

class BadGraphStructure(Exception):
    '''
    Raised when the graph is wired in a way the execution machinery cannot handle,
    e.g. implicit parallelization requested in a state with several selected transfers.
    '''

class GraphUnexpectedTermination(Exception):
    '''
    Raised when graph execution cannot continue, e.g. when no transfer of a state is selected.
    _run_state stores its message in data['__EXCEPTION__'].
    '''

350 351 352 353 354 355 356 357
def _requires_joint_of_implicit_parallelization(array_keys_mapping, impl_para_infos):
    if array_keys_mapping is None:
        return False
    for obj in impl_para_infos:
        if obj is not None:
            return True
    return False

358 359 360
def _get_trues(boolean_list):
    return [i for i, val in enumerate(boolean_list) if val == True]

361 362 363 364 365 366 367 368 369
#def _run_state(state, data, implicit_parallelization_info=None):
#    try:
#        next_morphism = state.run(data, implicit_parallelization_info)
#    except GraphUnexpectedTermination as e:
#        data['__EXCEPTION__'] = str(e)
#        return None
#    return next_morphism

def _run_state(state, data, implicit_parallelization_info=None):
370
    try:
371
        next_morphism, next_impl_para_info = state.run(data, implicit_parallelization_info)
372 373
    except GraphUnexpectedTermination as e:
        data['__EXCEPTION__'] = str(e)
374 375 376 377 378 379 380 381 382 383 384
        return None, None
    return next_morphism, next_impl_para_info


def build_dynamic_keys_mapping(implicit_parallelization_info=None):
    '''
    Builds the per-branch keys mapping used by edges under implicit parallelization.

    Returns an empty dict without parallelization info. Otherwise, every key of the info's
    ``array_keys_mapping`` is mapped to an aux.ArrayItemGetter bound to the current branch index.
    '''
    if implicit_parallelization_info is None:
        return {}
    return {
        key: aux.ArrayItemGetter(keys_path, implicit_parallelization_info.branch_i)
        for key, keys_path in implicit_parallelization_info.array_keys_mapping.items()
    }