Commit b54ff2ad authored by Anton Pershin's avatar Anton Pershin

Added graph for task creation

parent 00e5f8f3
import pickle
from datetime import date
from typing import Sequence, Optional, Mapping, Any
from typing import Sequence, Mapping
from typing_extensions import TypedDict
import json
from comsdk.comaux import *
from comsdk.communication import BaseCommunication, LocalCommunication, SshCommunication, Host
from comsdk.distributed_storage import *
from comsdk.edge import Edge, dummy_predicate
from comsdk.edge import Func, Edge, dummy_predicate
from comsdk.graph import Graph, State
CopiesList = TypedDict('CopiesList', {'path': str, 'new_name': str})
......@@ -62,6 +62,7 @@ class Research:
:param remote_research_root: path on the remote machine where research directories are searched for
"""
self._local_research_root = local_research_roots[0]
self._local_root = os.path.dirname(self._local_research_root)
self._remote_research_root = remote_research_root
self._tasks_number = 0
self._local_comm = LocalCommunication(Host()) # local communication created automatically, no need to pass it
......@@ -132,6 +133,10 @@ class Research:
def remote_research_path(self) -> str:
return os.path.join(self._remote_research_root, self._research_dir)
@property
def local_root(self) -> str:
return self._local_root
@property
def research_dir(self) -> str:
return self._research_dir
......@@ -250,6 +255,8 @@ class Research:
task_name = self._get_task_name_by_number(task_number)
rel_task_dir = os.path.join(self._research_dir, get_task_full_name(task_number, task_name))
if at_remote_host:
if self._remote_comm is None:
raise ValueError('Cannot get a task path on the remote: remote communication is not set up')
task_path = '{}/{}'.format(self._remote_research_root, rel_task_dir)
else:
task_path = self._distr_storage.get_dir_path(rel_task_dir)
......@@ -333,13 +340,28 @@ def retrieve_trailing_float_from_task_dir(task_dir: str) -> float:
class CreateTaskEdge(Edge):
def __init__(self, res, task_name_maker, predicate=dummy_predicate):
def __init__(self, res, task_name_maker, predicate=dummy_predicate, remote=False):
self._res = res
self._task_name_maker = task_name_maker
super().__init__(predicate, self.execute)
self._remote = remote
super().__init__(predicate, Func(func=self.execute))
def execute(self, data):
task_name = self._task_name_maker(data)
task_number = self._res.create_task(task_name)
data['__WORKING_DIR__'] = self._res.get_task_path(task_number)
if self._remote:
data['__REMOTE_WORKING_DIR__'] = self._res.get_task_path(task_number, at_remote_host=True)
class CreateTaskGraph(Graph):
def __init__(self, res, task_name_maker, array_keys_mapping=None, remote=False):
s_init, s_term = self.create_branch(res, task_name_maker, array_keys_mapping=array_keys_mapping, remote=remote)
super().__init__(s_init, s_term)
@staticmethod
def create_branch(res, task_name_maker, array_keys_mapping=None, remote=False):
s_init = State('READY_FOR_TASK_CREATION', array_keys_mapping=array_keys_mapping)
s_term = State('TASK_CREATED')
s_init.connect_to(s_term, edge=CreateTaskEdge(res, task_name_maker=task_name_maker, remote=remote))
return s_init, s_term
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment