Commit dc6980a9 authored by Anton Pershin's avatar Anton Pershin

Added support of custom commands in config

parent 35017770
......@@ -8,21 +8,27 @@ import json
import socket
from stat import S_ISDIR
from abc import ABCMeta, abstractmethod
import logging
import comsdk.comaux as aux
class Host(object):
'''
Class storing all necessary information about the host of execution.
'''
def __init__(self):
self.programs = {}
self.commands = {}
def add_program(self, prog_name,
path_to_prog=None,
):
self.programs[prog_name] = path_to_prog
def add_command(self, cmd_name, cmd):
self.commands[cmd_name] = cmd
def get_program_launch_path(self, prog_name):
path_to_prog = self.programs[prog_name]
if path_to_prog is not None:
......@@ -30,6 +36,7 @@ class Host(object):
else:
return prog_name
class RemoteHost(Host):
'''
RemoteHost extends Host including information about ssh host and the number of cores.
......@@ -62,6 +69,7 @@ class RemoteHost(Host):
self.set_job_id = aux.load_function_from_module(state['job_setter'])
self.check_task_finished = aux.load_function_from_module(state['job_finished_checker'])
# Decorator
def enable_sftp(func):
def wrapped_func(self, *args, **kwds):
......@@ -69,6 +77,7 @@ def enable_sftp(func):
return func(self, *args, **kwds)
return wrapped_func
class BaseCommunication(metaclass=ABCMeta):
'''
BaseCommunication is an abstract class which can be used to implement the simplest access to a machine.
......@@ -128,6 +137,7 @@ class BaseCommunication(metaclass=ABCMeta):
where = '@' + self.machine_name if is_remote else ''
print('\tExecuting %s: %s' % (where, cmd))
class LocalCommunication(BaseCommunication):
def __init__(self, local_host, machine_name='laptop'):
super(LocalCommunication, self).__init__(local_host, machine_name)
......@@ -137,7 +147,7 @@ class LocalCommunication(BaseCommunication):
with open('config_research.json', 'r') as f:
conf = json.load(f)
local_host = Host()
_add_programs_from_config(local_host, conf['LOCAL_HOST'])
_add_programs_and_commands_from_config(local_host, conf['LOCAL_HOST'])
return LocalCommunication(local_host)
def execute(self, command, working_dir=None):
......@@ -160,6 +170,7 @@ class LocalCommunication(BaseCommunication):
def rm(self, target):
aux.rm(target)
class SshCommunication(BaseCommunication):
def __init__(self, remote_host, username, password, machine_name='', pkey=None):
if not isinstance(remote_host, RemoteHost):
......@@ -185,7 +196,7 @@ class SshCommunication(BaseCommunication):
sge_template_name=hostconf['sge_template_name'],
job_setter=hostconf['job_setter'],
job_finished_checker=hostconf['job_finished_checker'])
_add_programs_from_config(remote_host, hostconf)
_add_programs_and_commands_from_config(remote_host, hostconf)
return SshCommunication(remote_host, username=hostconf['username'],
password=hostconf['password'] if 'password' in hostconf else None,
machine_name=host_sid,
......@@ -305,6 +316,7 @@ class SshCommunication(BaseCommunication):
res = self.sftp_client.get(remote_path, local_path)
received = True
except FileNotFoundError as e:
logging.error('Cannot find file or directory "{}" => interrupt downloading'.format(remote_path))
if os.path.exists(local_path):
aux.rm(local_path)
raise
......@@ -329,6 +341,7 @@ class SshCommunication(BaseCommunication):
res = self.sftp_client.put(local_path, remote_path)
received = True
except FileNotFoundError as e:
logging.error('Cannot find file or directory "{}" => interrupt uploading'.format(local_path))
self.rm(remote_path)
raise
except (socket.timeout, socket.error, paramiko.sftp.SFTPError) as e:
......@@ -425,7 +438,8 @@ class SshCommunication(BaseCommunication):
class CommunicationError(Exception):
pass
def _add_programs_from_config(host, hostconf):
def _add_programs_and_commands_from_config(host, hostconf):
if 'custom_programs' in hostconf:
paths = hostconf['custom_programs']
for path, programs in paths.items():
......@@ -434,4 +448,7 @@ def _add_programs_from_config(host, hostconf):
if 'env_programs' in hostconf:
for program in hostconf['env_programs']:
host.add_program(program)
if 'custom_commands' in hostconf:
for cmd_name, cmd in hostconf['custom_commands'].items():
host.add_command(cmd_name, cmd)
......@@ -5,6 +5,9 @@
"@path_to_binaries@": ["@bin1@", "@bin2@", ...],
...
}
"custom_commands": {
"@command_name@": "@command itself@"
},
},
"REMOTE_HOSTS": {
"@remote_host_sid@": {
......@@ -19,6 +22,9 @@
"@path_to_binaries@": ["@bin1@", "@bin2@", ...],
...
},
"custom_commands": {
"@command_name@": "@command itself@"
},
"sge_template_name": "...",
"job_setter": "...",
"job_finished_checker": "..."
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment