graph.py 18.6 KB
Newer Older
1
import collections
2 3 4
import os
from enum import Enum, auto
from functools import partial
5
import importlib as imp
6

7
import comsdk.misc as aux
8 9 10

# Describes one branch of an implicitly parallelized part of the graph:
#   array_keys_mapping -- the array keys mapping of the state that spawned the
#                         branches (consumed by build_dynamic_keys_mapping)
#   branches_number    -- how many branches were spawned
#   branch_i           -- index of this branch, in range(branches_number)
ImplicitParallelizationInfo = collections.namedtuple(
    'ImplicitParallelizationInfo',
    'array_keys_mapping branches_number branch_i',
)


class Func:
    '''
    Wraps a callable used by the graph (predicate, morphism, selector, ...).

    The callable is resolved in this order:
    1) ``func`` if it is given explicitly;
    2) the identity function ``lambda data: data`` if ``dummy`` is True or
       ``module``/``name`` is empty or None;
    3) otherwise ``getattr(importlib.import_module(module), name)``.
    '''
    __slots__ = (
        'module',
        'func',
        'comment',
        'name',
    )

    def __init__(self, module="", name="", dummy=False, func=None, comment=''):
        '''
        :param module: dotted name of the module to load the function from
        :param name: name of the function inside ``module``
        :param dummy: if True, use the identity function instead of loading one
        :param func: explicit callable; takes precedence over everything else
        :param comment: free-form comment; NUL characters are replaced by spaces
        :raises Exception: if the function cannot be loaded; the original
            error is chained as ``__cause__``
        '''
        self.module = module
        self.name = name
        self.comment = comment.replace("\0", " ") if comment is not None else ""
        if module is None or module == "" or name is None or name == "":
            dummy = True
        if func is not None:
            self.func = func
        elif dummy:
            self.func = lambda data: data
        else:
            print("LOADING function {} from {} module".format(name, module))
            try:
                self.func = getattr(imp.import_module(module), name)
            except Exception as err:
                # Keep the real reason (ImportError, AttributeError, ...) visible.
                raise Exception("Could not load function {} from {} module".format(name, module)) from err

    def __str__(self):
        # Without a module/name pair, fall back to the callable's own name.
        # None counts as missing, exactly as in __init__.
        if self.module is None or self.module == "" or self.name is None or self.name == "":
            return self.func.__name__
        return "{}_{}".format(self.module, self.name)


class Selector(Func):
    '''
    Chooses which outgoing transfers of a State may be taken.

    ``func(data)`` is expected to return a sequence with one boolean per
    outgoing transfer. A dummy selector (``dummy=True`` or no module/name)
    selects all ``ntransf`` transfers.
    '''
    def __init__(self, ntransf, module="", name="", dummy=False):
        # Missing module *or* name means dummy, consistent with Func. Otherwise
        # Func would substitute its identity function, which returns the data
        # itself rather than a list of booleans.
        if module is None or module == "" or name is None or name == "":
            dummy = True
        self.dummy = dummy
        super().__init__(module, name, func=(lambda data: [True] * ntransf) if dummy else None)

    def __str__(self):
        # Dummy selectors have no printable identity.
        if self.module is None or self.module == "" or self.name is None or self.name == "":
            return ''
        return "{}_{}".format(self.module, self.name)


class Transfer:
    '''
    Outgoing connection of a State: applies ``edge`` to the data and leads to
    ``output_state``.
    '''
    def __init__(self, edge, output_state, order=0):
        self.edge = edge
        self.output_state = output_state
        self.order = order

    def transfer(self, data, dynamic_keys_mapping=None):
        '''
        Apply the edge's morphism to ``data`` and return the next state.

        :param data: graph data passed to ``edge.morph``
        :param dynamic_keys_mapping: key mapping forwarded to ``edge.morph``;
            a fresh empty dict is used when omitted
        :return: ``self.output_state``
        '''
        # A new dict per call: a shared mutable default would leak any
        # modification made by edge.morph into every later call.
        if dynamic_keys_mapping is None:
            dynamic_keys_mapping = {}
        self.edge.morph(data, dynamic_keys_mapping)
        return self.output_state


class IdleRunType(Enum):
    '''Kind of idle (dry) traversal performed by State.idle_run.'''
    # First run (Graph.init_graph): counts incoming and looped edges of states.
    INIT = 1
    # Subsequent runs: resets the activation counters left by a previous run.
    CLEANUP = 2


class PluralState:
    '''
    A group of states that can be wired pairwise to another group of states.
    '''
    def __init__(self, states):
        self.states = states

    def connect_to(self, term_states, edge):
        '''
        Connect the i-th state of this group to the i-th of ``term_states``
        through the same ``edge``. Pairs are formed with ``zip``, so surplus
        states on either side are left unconnected.
        '''
        pairs = zip(self.states, term_states)
        for source, target in pairs:
            source.transfers.append(Transfer(edge, target))


class Graph:
    '''
    Class describing a graph-based computational method. Graph execution must start from this object.
    '''
    def __init__(self, init_state,
                 term_state=None,
                 ):
        '''
        :param init_state: the State execution starts from
        :param term_state: optional terminal State; it gets ``is_term_state = True``
        '''
        self.init_state = init_state
        self.term_state = term_state
        if self.term_state is not None:
            self.term_state.is_term_state = True
        self._initialized = False

    def run(self, data):
        '''
        Goes through the graph and returns a boolean denoting whether the graph has finished successfully.
        An idle run (see init_graph) precedes the real run: it initializes the graph on the
        first call and cleans it up on subsequent calls.
        The input data will be augmented by metadata:
        1) '__CURRENT_WORKING_DIR__' -- absolute path to the current working directory as defined by the OS
        2) '__WORKING_DIR__' -- path to the directory from which external binaries or resources will be launched.
        It will be set (to '__CURRENT_WORKING_DIR__') only if it is not yet set in data
        3) '__EXCEPTION__' if any error occurs; run then returns False
        '''
        self.init_graph(data)
        cur_state = self.init_state
        implicit_parallelization_info = None
        while cur_state is not None:
            transfer_f, implicit_parallelization_info = _run_state(cur_state, data, implicit_parallelization_info)
            if '__EXCEPTION__' in data:
                return False
            cur_state = transfer_f(data)
            if '__EXCEPTION__' in data:
                return False
        return True

    def init_graph(self, data=None):
        '''
        Perform the idle run of the graph and add working-directory metadata to ``data``.

        The first call traverses with IdleRunType.INIT, later calls with IdleRunType.CLEANUP.

        :param data: dict augmented in place; a new dict is used when omitted
        '''
        # A fresh dict per call: the former mutable default kept a stale
        # '__WORKING_DIR__' from the first call forever.
        if data is None:
            data = {}
        idle_run_type = IdleRunType.CLEANUP if self._initialized else IdleRunType.INIT
        self.init_state.idle_run(idle_run_type, [self.init_state.name])
        self._initialized = True
        data['__CURRENT_WORKING_DIR__'] = os.getcwd()
        if '__WORKING_DIR__' not in data:
            data['__WORKING_DIR__'] = data['__CURRENT_WORKING_DIR__']


class State:
    '''
    A node of the computational graph.

    The idle INIT run counts the state's incoming edges, and the looped ones
    among them. During the real run the state waits until the required number
    of incoming edges has been activated. It then picks outgoing transfers via
    ``selector`` and the edges' predicates, and lets ``parallelization_policy``
    build the transfer function.
    '''
    __slots__ = [
        'name',
        'input_edges_number',  # output_edges_number == len(transfers)
        'looped_edges_number',
        'activated_input_edges_number',
        'transfers',
        'parallelization_policy',
        'selector',
        'is_term_state',
        'array_keys_mapping',
        '_branching_states_history',
        '_proxy_state',
        'possible_branches',
        'comment',
    ]

    def __init__(self, name,
                 parallelization_policy=None,
                 selector=None,
                 array_keys_mapping=None,  # if array_keys_mapping is not None, we have implicit parallelization in this state
                 ):
        self.name = name
        self.parallelization_policy = SerialParallelizationPolicy() if parallelization_policy is None else parallelization_policy
        self.selector = Selector(1) if selector is None else selector
        self.array_keys_mapping = array_keys_mapping
        self.input_edges_number = 0
        self.looped_edges_number = 0
        # An int normally; becomes a per-branch list under implicit parallelization
        # (see _activate_input_edge).
        self.activated_input_edges_number = 0
        self.transfers = []
        self.possible_branches = []
        self.is_term_state = False
        self._branching_states_history = None
        self._proxy_state = None
        self.comment = None

    def idle_run(self, idle_run_type, branching_states_history):
        '''
        Recursively traverse the graph without executing any edge.

        INIT: count incoming edges, and looped ones, of every reachable state;
        an already visited state is not traversed further.
        CLEANUP: reset activation counters left from a previous run.

        :param idle_run_type: an IdleRunType member
        :param branching_states_history: names of states at which the current
            path branched, starting with the initial state
        '''
        # Keep outgoing transfers ordered by their edge's ``order``.
        self.transfers.sort(key=lambda transfer: transfer.edge.order)
        if self._proxy_state is not None:
            return self._proxy_state.idle_run(idle_run_type, branching_states_history)
        if idle_run_type == IdleRunType.INIT:
            self.input_edges_number += 1
            if self.input_edges_number != 1:
                if self._is_looped_branch(branching_states_history):
                    self.looped_edges_number += 1
                return  # no need to go further if we already were there
            if self._branching_states_history is None:
                self._branching_states_history = branching_states_history
        elif idle_run_type == IdleRunType.CLEANUP:
            self.activated_input_edges_number = 0
            if self._branching_states_history is not None and self._is_looped_branch(branching_states_history):
                self._branching_states_history = None
                return
            if self._branching_states_history is None:
                self._branching_states_history = branching_states_history
        else:
            # NOTE(review): flagged as a BUG by the original author -- it should
            # somehow decide whether the traversal proceeds from here.
            self.activated_input_edges_number += 1
        if len(self.transfers) == 1:
            self.transfers[0].output_state.idle_run(idle_run_type, branching_states_history)
        else:
            # Each outgoing branch extends the history with its first state.
            for transfer in self.transfers:
                next_state = transfer.output_state
                next_state.idle_run(idle_run_type, branching_states_history + [next_state.name])

    def connect_to(self, term_state, edge=None, comment=None):
        '''
        Add an outgoing transfer from this state to ``term_state`` through ``edge``.

        The selector is replaced by a dummy Selector selecting all transfers.

        :param comment: stored in ``self.comment`` only when it is a non-empty
            string; otherwise the previously stored comment is kept
        '''
        if comment is not None and comment != "":
            self.comment = comment
        self.transfers.append(Transfer(edge, term_state))
        self.selector = Selector(len(self.transfers))

    def replace_with_graph(self, graph):
        '''
        Substitute this state by ``graph``. Runs are delegated to the graph's
        initial state. The graph's terminal state takes over this state's
        outgoing transfers and selector.
        '''
        self._proxy_state = graph.init_state
        graph.term_state.transfers = self.transfers
        graph.term_state.selector = self.selector

    def run(self, data, implicit_parallelization_info=None):
        '''
        Enter the state during the real run.

        :return: pair ``(transfer_func, implicit_parallelization_info)``.
            ``(None, None)`` means the state still waits for other incoming
            edges. With no outgoing transfers, ``transfer_func`` is
            ``transfer_to_termination``.
        :raises GraphUnexpectedTermination: if the selector selects nothing or
            no selected transfer's predicate holds
        '''
        print('STATE {}\n\tjust entered, implicit_parallelization_info: {}'.format(self.name, implicit_parallelization_info))
        if self._proxy_state is not None:
            return self._proxy_state.run(data, implicit_parallelization_info)
        self._activate_input_edge(implicit_parallelization_info)
        print('\trequired input: {}, active: {}, looped: {}'.format(self.input_edges_number, self.activated_input_edges_number, self.looped_edges_number))
        if not self._ready_to_transfer(implicit_parallelization_info):
            return None, None  # it means that this state waits for some incoming edges (it is a point of collision of several edges)
        self._reset_activity(implicit_parallelization_info)
        if self.is_term_state:
            # All implicit branches were collected here (see _ready_to_transfer).
            implicit_parallelization_info = None
        if len(self.transfers) == 0:
            return transfer_to_termination, None
        dynamic_keys_mapping = build_dynamic_keys_mapping(implicit_parallelization_info)
        selected_edges = self.selector.func(data)
        if not selected_edges:
            raise GraphUnexpectedTermination(
                "STATE {}: error in selector: {} ".format(self.name, selected_edges))
        selected_transfers = [self.transfers[i] for i, is_selected in enumerate(selected_edges)
                              if is_selected and self.transfers[i].edge.predicate(data, dynamic_keys_mapping)]
        if not selected_transfers:
            raise GraphUnexpectedTermination('\tERROR: no transfer function has been '
                                             'selected out of {} ones. Predicate values are {}. '
                                             'Selector values are {}.'.format(len(self.transfers),
                                                                              [t.edge.predicate(data, dynamic_keys_mapping) for t in self.transfers],
                                                                              selected_edges))
        return self.parallelization_policy.make_transfer_func(selected_transfers,
                                                              array_keys_mapping=self.array_keys_mapping,
                                                              implicit_parallelization_info=implicit_parallelization_info,
                                                              state=self), \
               implicit_parallelization_info

    def _activate_input_edge(self, implicit_parallelization_info=None):
        '''Count one more activated incoming edge (per branch when implicitly parallelized).'''
        if implicit_parallelization_info is None or self.is_term_state:
            self.activated_input_edges_number += 1
        else:
            # Switch to one counter per implicit branch on first use.
            if isinstance(self.activated_input_edges_number, int):
                self.activated_input_edges_number = [0 for _ in range(implicit_parallelization_info.branches_number)]
            self.activated_input_edges_number[implicit_parallelization_info.branch_i] += 1

    def _ready_to_transfer(self, implicit_parallelization_info=None):
        '''Tell whether enough incoming edges are activated to leave the state.'''
        # Looped edges are not awaited: they can only fire after leaving the state.
        required_activated_input_edges_number = self.input_edges_number - self.looped_edges_number
        if implicit_parallelization_info is None:
            return self.activated_input_edges_number == required_activated_input_edges_number
        if self.is_term_state:
            # The terminal state waits for every implicit branch.
            return self.activated_input_edges_number == implicit_parallelization_info.branches_number
        return self.activated_input_edges_number[implicit_parallelization_info.branch_i] == required_activated_input_edges_number

    def _reset_activity(self, implicit_parallelization_info=None):
        '''Reset the activation counter (decrement only for looped states) after leaving.'''
        self._branching_states_history = None
        per_branch = implicit_parallelization_info is not None and not self.is_term_state
        if self._ready_to_transfer(implicit_parallelization_info) and self._has_loop():
            if per_branch:
                self.activated_input_edges_number[implicit_parallelization_info.branch_i] -= 1
            else:
                self.activated_input_edges_number -= 1
        else:
            if per_branch:
                self.activated_input_edges_number[implicit_parallelization_info.branch_i] = 0
            else:
                self.activated_input_edges_number = 0

    def _is_looped_branch(self, branching_states_history):
        '''True if the first recorded branching history is contained in the given one.'''
        return set(self._branching_states_history).issubset(branching_states_history)

    def _has_loop(self):
        '''True if at least one incoming edge of this state is looped.'''
        return self.looped_edges_number != 0


def transfer_to_termination(data):
    '''
    Transfer function of a state without outgoing transfers.

    ``data`` is ignored. The returned None is the "no next state" marker that
    ends the traversal.
    '''
    next_state = None
    return next_state


class SerialParallelizationPolicy:
    '''
    Builds transfer functions that run the selected transfers, and any
    implicitly parallel branches, one after another in the current process.
    '''
    def __init__(self):
        pass

    def make_transfer_func(self, transfers, array_keys_mapping=None, implicit_parallelization_info=None, state=None):
        '''
        Build the transfer function for the transfers selected in ``state``.

        The returned ``_morph(data)`` follows the transfers and runs the states
        they lead to, round after round. It stops once only one pending transfer
        is left and, under implicit parallelization, the branches have been
        joined. That last transfer is applied and its output state returned.

        :param transfers: selected Transfer objects of ``state``
        :param array_keys_mapping: the state's array keys mapping; when not
            None, one implicit branch is spawned per element of the mapped array
        :param implicit_parallelization_info: ImplicitParallelizationInfo of
            the current branch, or None
        :param state: State the transfers start from (used in error messages only)
        :return: function ``data -> next state``; None means the run terminated
            or an '__EXCEPTION__' was recorded in data
        :raises BadGraphStructure: (from the returned function) if implicit
            parallelization is requested with more or fewer than one transfer,
            or with an empty mapping
        '''
        def _spawn_implicit_branches(data):
            # One branch per element of the mapped array. The branch count is
            # read through an arbitrary key; presumably all mapped arrays have
            # equal length -- TODO confirm.
            if len(transfers) != 1:
                raise BadGraphStructure('Impossible to create implicit parallelization in the state '
                                        'with {} output edges'.format(len(transfers)))
            if not array_keys_mapping:
                raise BadGraphStructure('Impossible to create implicit parallelization in the state '
                                        '{} with an empty array_keys_mapping'.format(getattr(state, 'name', None)))
            proxy_data = aux.ProxyDict(data, keys_mappings=array_keys_mapping)
            anykey = next(iter(array_keys_mapping.keys()))
            implicit_branches_number = len(proxy_data[anykey])
            branch_transfers = []
            branch_infos = []
            for branch_i in range(implicit_branches_number):
                info = ImplicitParallelizationInfo(array_keys_mapping, implicit_branches_number, branch_i)
                branch_transfers.append(partial(transfers[0].transfer,
                                                dynamic_keys_mapping=build_dynamic_keys_mapping(info)))
                branch_infos.append(info)
            return branch_transfers, branch_infos

        def _morph(data):
            if array_keys_mapping is None:
                dynamic_keys_mapping = build_dynamic_keys_mapping(implicit_parallelization_info)
                next_transfers = [partial(t.transfer, dynamic_keys_mapping=dynamic_keys_mapping) for t in transfers]
                next_impl_para_infos = [implicit_parallelization_info for _ in transfers]
            else:
                next_transfers, next_impl_para_infos = _spawn_implicit_branches(data)
            while len(next_transfers) != 1 or _requires_joint_of_implicit_parallelization(array_keys_mapping, next_impl_para_infos):
                if not next_impl_para_infos:
                    raise Exception("Morphs count on state {} is {}".format(getattr(state, 'name', None), len(next_transfers)))
                pending = list(zip(next_transfers, next_impl_para_infos))
                next_transfers = []
                next_impl_para_infos = []
                for t, impl_para_info in pending:
                    next_state = t(data)
                    if next_state is None:
                        return None
                    next_t, next_impl_para_info = _run_state(next_state, data, impl_para_info)
                    if '__EXCEPTION__' in data:
                        return None
                    # A state still waiting for other incoming edges yields no transfer.
                    if next_t is not None:
                        next_transfers.append(next_t)
                        next_impl_para_infos.append(next_impl_para_info)
            return next_transfers[0](data)
        return _morph


class BadGraphStructure(Exception):
    '''
    Raised when the graph's wiring does not allow the requested execution,
    e.g. implicit parallelization from a state with several output edges.
    '''


class GraphUnexpectedTermination(Exception):
    '''
    Raised by State.run when no outgoing transfer can be chosen. _run_state
    stores its message in data['__EXCEPTION__'].
    '''


def _requires_joint_of_implicit_parallelization(array_keys_mapping, impl_para_infos):
    if array_keys_mapping is None:
        return False
    for obj in impl_para_infos:
        if obj is not None:
            return True
    return False

383

384 385 386
def _get_trues(boolean_list):
    return [i for i, val in enumerate(boolean_list) if val == True]

387 388

def _run_state(state, data, implicit_parallelization_info=None):
389
    try:
390
        next_morphism, next_impl_para_info = state.run(data, implicit_parallelization_info)
391 392
    except GraphUnexpectedTermination as e:
        data['__EXCEPTION__'] = str(e)
393 394 395 396 397 398 399 400 401 402
        return None, None
    return next_morphism, next_impl_para_info


def build_dynamic_keys_mapping(implicit_parallelization_info=None):
    '''
    Map each key of the branch's ``array_keys_mapping`` to an
    ``aux.ArrayItemGetter`` bound to the branch index.

    :return: ``{}`` when there is no implicit parallelization
    '''
    if implicit_parallelization_info is None:
        return {}
    return {
        key: aux.ArrayItemGetter(keys_path, implicit_parallelization_info.branch_i)
        for key, keys_path in implicit_parallelization_info.array_keys_mapping.items()
    }