Commit b54ff2ad authored by Anton Pershin's avatar Anton Pershin

Added graph for task creation

parent 00e5f8f3
import pickle import pickle
from datetime import date from datetime import date
from typing import Sequence, Optional, Mapping, Any from typing import Sequence, Mapping
from typing_extensions import TypedDict from typing_extensions import TypedDict
import json
from comsdk.comaux import * from comsdk.comaux import *
from comsdk.communication import BaseCommunication, LocalCommunication, SshCommunication, Host from comsdk.communication import BaseCommunication, LocalCommunication, SshCommunication, Host
from comsdk.distributed_storage import * from comsdk.distributed_storage import *
from comsdk.edge import Edge, dummy_predicate from comsdk.edge import Func, Edge, dummy_predicate
from comsdk.graph import Graph, State
CopiesList = TypedDict('CopiesList', {'path': str, 'new_name': str}) CopiesList = TypedDict('CopiesList', {'path': str, 'new_name': str})
...@@ -62,6 +62,7 @@ class Research: ...@@ -62,6 +62,7 @@ class Research:
:param remote_research_root: path on the remote machine where research directories are searched for :param remote_research_root: path on the remote machine where research directories are searched for
""" """
self._local_research_root = local_research_roots[0] self._local_research_root = local_research_roots[0]
self._local_root = os.path.dirname(self._local_research_root)
self._remote_research_root = remote_research_root self._remote_research_root = remote_research_root
self._tasks_number = 0 self._tasks_number = 0
self._local_comm = LocalCommunication(Host()) # local communication created automatically, no need to pass it self._local_comm = LocalCommunication(Host()) # local communication created automatically, no need to pass it
...@@ -132,6 +133,10 @@ class Research: ...@@ -132,6 +133,10 @@ class Research:
def remote_research_path(self) -> str: def remote_research_path(self) -> str:
return os.path.join(self._remote_research_root, self._research_dir) return os.path.join(self._remote_research_root, self._research_dir)
@property
def local_root(self) -> str:
return self._local_root
@property @property
def research_dir(self) -> str: def research_dir(self) -> str:
return self._research_dir return self._research_dir
...@@ -250,6 +255,8 @@ class Research: ...@@ -250,6 +255,8 @@ class Research:
task_name = self._get_task_name_by_number(task_number) task_name = self._get_task_name_by_number(task_number)
rel_task_dir = os.path.join(self._research_dir, get_task_full_name(task_number, task_name)) rel_task_dir = os.path.join(self._research_dir, get_task_full_name(task_number, task_name))
if at_remote_host: if at_remote_host:
if self._remote_comm is None:
raise ValueError('Cannot get a task path on the remote: remote communication is not set up')
task_path = '{}/{}'.format(self._remote_research_root, rel_task_dir) task_path = '{}/{}'.format(self._remote_research_root, rel_task_dir)
else: else:
task_path = self._distr_storage.get_dir_path(rel_task_dir) task_path = self._distr_storage.get_dir_path(rel_task_dir)
...@@ -333,13 +340,28 @@ def retrieve_trailing_float_from_task_dir(task_dir: str) -> float: ...@@ -333,13 +340,28 @@ def retrieve_trailing_float_from_task_dir(task_dir: str) -> float:
class CreateTaskEdge(Edge): class CreateTaskEdge(Edge):
def __init__(self, res, task_name_maker, predicate=dummy_predicate): def __init__(self, res, task_name_maker, predicate=dummy_predicate, remote=False):
self._res = res self._res = res
self._task_name_maker = task_name_maker self._task_name_maker = task_name_maker
super().__init__(predicate, self.execute) self._remote = remote
super().__init__(predicate, Func(func=self.execute))
def execute(self, data): def execute(self, data):
task_name = self._task_name_maker(data) task_name = self._task_name_maker(data)
task_number = self._res.create_task(task_name) task_number = self._res.create_task(task_name)
data['__WORKING_DIR__'] = self._res.get_task_path(task_number) data['__WORKING_DIR__'] = self._res.get_task_path(task_number)
if self._remote:
data['__REMOTE_WORKING_DIR__'] = self._res.get_task_path(task_number, at_remote_host=True) data['__REMOTE_WORKING_DIR__'] = self._res.get_task_path(task_number, at_remote_host=True)
class CreateTaskGraph(Graph):
def __init__(self, res, task_name_maker, array_keys_mapping=None, remote=False):
s_init, s_term = self.create_branch(res, task_name_maker, array_keys_mapping=array_keys_mapping, remote=remote)
super().__init__(s_init, s_term)
@staticmethod
def create_branch(res, task_name_maker, array_keys_mapping=None, remote=False):
s_init = State('READY_FOR_TASK_CREATION', array_keys_mapping=array_keys_mapping)
s_term = State('TASK_CREATED')
s_init.connect_to(s_term, edge=CreateTaskEdge(res, task_name_maker=task_name_maker, remote=remote))
return s_init, s_term
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment