# communication.py
import functools
import json
import logging
import os
import os.path
import shlex
import shutil
import socket
import subprocess
from abc import ABCMeta, abstractmethod
from stat import S_ISDIR

import paramiko

import comsdk.comaux as aux


class Host(object):
    '''
    Class storing all necessary information about the host of execution.

    Attributes:
        programs: dict mapping a program name to the directory it lives in, or to None
                  when the program is launched by its bare name
        commands: dict mapping a command name to the stored command
    '''
    def __init__(self):
        self.programs = {}
        self.commands = {}

    def add_program(self, prog_name,
                    path_to_prog=None,
                    ):
        '''
        Registers program prog_name. path_to_prog is the directory containing the program;
        None (default) means that the program is launched by its name only.
        '''
        self.programs[prog_name] = path_to_prog

    def add_command(self, cmd_name, cmd):
        '''
        Registers command cmd under the name cmd_name.
        '''
        self.commands[cmd_name] = cmd

    def get_program_launch_path(self, prog_name):
        '''
        Returns the string used to launch prog_name: '<path_to_prog>/<prog_name>' if a directory
        was registered for the program and prog_name itself otherwise.

        Raises KeyError if prog_name has not been registered via add_program().
        '''
        path_to_prog = self.programs[prog_name]
        if path_to_prog is None:
            return prog_name
        # '/' is used deliberately instead of os.path.join: the host may be a remote machine
        return path_to_prog + '/' + prog_name


class RemoteHost(Host):
    '''
    RemoteHost extends Host including information about ssh host and the number of cores.

    Attributes (in addition to those of Host):
        ssh_host: host name passed to the ssh client when connecting
        cores: number of cores of the remote machine
        sge_template_name: stored as given (presumably the name of an SGE script template -- TODO confirm)
        set_job_id: callable loaded by aux.load_function_from_module(job_setter)
        check_task_finished: callable loaded by aux.load_function_from_module(job_finished_checker)
    '''
    def __init__(self, ssh_host, cores, sge_template_name, job_setter, job_finished_checker):
        super().__init__()
        self.ssh_host = ssh_host
        self.cores = cores
        self.sge_template_name = sge_template_name
        self._load_job_functions(job_setter, job_finished_checker)

    def _load_job_functions(self, job_setter, job_finished_checker):
        '''
        Loads the job-related callables and remembers the specifications they were loaded from.
        The specifications (not the callables) are what is stored in the pickled state.
        '''
        self.set_job_id = aux.load_function_from_module(job_setter)
        self.check_task_finished = aux.load_function_from_module(job_finished_checker)
        self._job_setter = job_setter
        self._job_finished_checker = job_finished_checker

    def __getstate__(self):
        '''
        Returns the picklable state: loaded callables are replaced by their specifications.
        '''
        return {
            'ssh_host': self.ssh_host,
            'cores': self.cores,
            'programs': self.programs,
            'commands': self.commands,
            'sge_template_name': self.sge_template_name,
            'job_setter': self._job_setter,
            'job_finished_checker': self._job_finished_checker,
        }

    def __setstate__(self, state):
        '''
        Restores the host from state produced by __getstate__() and reloads the callables.
        '''
        self.ssh_host = state['ssh_host']
        self.cores = state['cores']
        self.programs = state['programs']
        # states saved before 'commands' was added to __getstate__ do not have this key
        self.commands = state.get('commands', {})
        self.sge_template_name = state['sge_template_name']
        # also restores _job_setter/_job_finished_checker so that the restored object
        # can be pickled again
        self._load_job_functions(state['job_setter'], state['job_finished_checker'])


# Decorator
def enable_sftp(func):
    '''
    Decorator for methods which need an open SFTP client: calls self._init_sftp() and then
    the decorated method with the same arguments, returning its result.
    '''
    @functools.wraps(func)  # keep the name and docstring of the decorated method
    def wrapped_func(self, *args, **kwds):
        self._init_sftp()
        return func(self, *args, **kwds)
    return wrapped_func


class BaseCommunication(metaclass=ABCMeta):
    '''
    BaseCommunication is an abstract class which can be used to implement the simplest access to a machine.
    A concrete class ought to use a concrete method of communication (e.g., OS API or ssh) allowing to access
    the filesystem (copy and remove files) and execute a command line on the machine.

    Since the machine can, in particular, be the local machine itself, while communication is always established
    between the local machine and the machine being accessed, we fix the terminology: the machine being accessed
    is called the communicated machine, whereas the machine running this code remains the local machine.

    Generally, two types of file exchange are possible:
    (1) between the local machine and a communicated machine,
    (2) within a communicated machine.
    Since for now only copying implies this division, we introduce so-called 'modes of copying': from_local,
    from_remote and all_remote (see copy()).
    '''

    def __init__(self, host, machine_name):
        '''
        host: object describing the communicated machine; it must provide get_program_launch_path()
        machine_name: name of the machine used in printed messages
        '''
        self.host = host
        self.machine_name = machine_name

    @abstractmethod
    def execute(self, command, working_dir=None):
        '''
        Executes command on the communicated machine, in working_dir if it is given
        '''
        pass

    @abstractmethod
    def copy(self, from_, to_, mode='from_local', show_msg=False):
        '''
        Copies from_ to to_ which are interpreted according to mode:
        (1) from_local (default) -> from_ is local path, to_ is a path on a communicated machine
        (2) from_remote -> from_ is a path on a communicated machine, to_ local path
        (3) all_remote -> from_ and to_ are paths on a communicated machine

        from_ and to_ can be dirs or files according to the following combinations:
        (1) from_ is dir, to_ is dir
        (2) from_ is file, to_ is dir
        (3) from_ is file, to_ is file
        '''
        pass

    @abstractmethod
    def rm(self, target):
        '''
        Removes target which can be a dir or file
        '''
        pass

    def execute_program(self, prog_name, args_str, working_dir=None):
        '''
        Executes program prog_name registered in the host with the argument string args_str.
        Returns whatever execute() returns.
        '''
        prog_path = self.host.get_program_launch_path(prog_name)
        command = '{} {}'.format(prog_path, args_str)
        return self.execute(command, working_dir)

    def _print_copy_msg(self, from_, to_):
        print('\tCopying %s to %s' % (from_, to_))

    def _print_exec_msg(self, cmd, is_remote):
        # the machine name is shown only for remote execution
        where = '@' + self.machine_name if is_remote else ''
        print('\tExecuting %s: %s' % (where, cmd))


class LocalCommunication(BaseCommunication):
    '''
    Communication with the local machine itself: commands are run through the local shell
    and files are copied/removed within the local filesystem.
    '''
    def __init__(self, local_host, machine_name='laptop'):
        super(LocalCommunication, self).__init__(local_host, machine_name)

    @classmethod
    def create_from_config(cls):
        '''
        Builds LocalCommunication from section LOCAL_HOST of config_research.json
        located in the current working directory.
        '''
        with open('config_research.json', 'r') as f:
            conf = json.load(f)
        local_host = Host()
        _add_programs_and_commands_from_config(local_host, conf['LOCAL_HOST'])
        return LocalCommunication(local_host)

    def execute(self, command, working_dir=None):
        '''
        Runs command in the shell (after cd to working_dir if it is given) and waits for it to finish.
        Output is not captured and the exit status is discarded: two empty lists are always returned
        in place of the stdout/stderr lines.
        '''
        command_line = command if working_dir is None else 'cd {}; {}'.format(working_dir, command)
        # with shell=True the command must be a single string, not a list
        subprocess.call(command_line, shell=True)
        return [], []

    def copy(self, from_, to_, mode='from_local', show_msg=False):
        '''
        Any mode is ignored since the copying shall be within a local machine anyway
        '''
        if show_msg:
            self._print_copy_msg(from_, to_)
        # NOTE(review): the original called a bare cp() which is not defined in this module;
        # assumes comsdk.comaux provides cp() next to rm() -- TODO confirm
        return aux.cp(from_, to_)

    def rm(self, target):
        aux.rm(target)


class SshCommunication(BaseCommunication):
    '''
    Communication with a remote machine via ssh/sftp based on paramiko.

    Commands are executed through paramiko.SSHClient.exec_command() and files are transferred
    through an SFTP client which is opened lazily on first use. When execute(), _get() or _put()
    catch a network-related exception, the connection is re-established and the operation is
    retried; the number of retries is not limited.
    '''
    def __init__(self, remote_host, username, password, machine_name='', pkey=None):
        '''
        Creates the ssh client and connects to the remote host immediately.

        remote_host: RemoteHost describing the machine
        username: login name on the remote machine
        password: password; used only if pkey is None
        machine_name: name of the machine used in printed messages
        pkey: passed to paramiko as key_filename; if given, it is used instead of the password

        Raises TypeError if remote_host is not a RemoteHost.
        '''
        if not isinstance(remote_host, RemoteHost):
            raise TypeError('Only RemoteHost can be used to build SshCommunication')
        super().__init__(remote_host, machine_name)
        self.username = username
        self.password = password
        self.pkey = pkey
        self.ssh_client = paramiko.SSHClient()
        self.sftp_client = None
        self.connect()
        paramiko.util.log_to_file('paramiko.log')

    @classmethod
    def create_from_config(cls, host_sid):
        '''
        Builds SshCommunication from section REMOTE_HOSTS/<host_sid> of config_research.json
        located in the current working directory. Keys 'password' and 'pkey' are optional.
        '''
        with open('config_research.json', 'r') as f:
            conf = json.load(f)
        hostconf = conf['REMOTE_HOSTS'][host_sid]
        remote_host = RemoteHost(ssh_host=hostconf['ssh_host'],
                                 cores=hostconf['max_cores'],
                                 sge_template_name=hostconf['sge_template_name'],
                                 job_setter=hostconf['job_setter'],
                                 job_finished_checker=hostconf['job_finished_checker'])
        _add_programs_and_commands_from_config(remote_host, hostconf)
        return SshCommunication(remote_host, username=hostconf['username'],
                                             password=hostconf.get('password'),
                                             machine_name=host_sid,
                                             pkey=hostconf.get('pkey'))

    def __getstate__(self):
        '''
        Returns the picklable state (the live ssh/sftp clients are not included).
        '''
        return {
            'host': self.host.__getstate__(),
            'username': self.username,
            'password': self.password,
            'pkey': self.pkey,
            'machine_name': self.machine_name,
        }

    def __setstate__(self, state):
        '''
        Restores the object from state and reconnects to the remote host.
        '''
        remote_host = RemoteHost.__new__(RemoteHost)
        remote_host.__setstate__(state['host'])
        # states saved before 'machine_name' was added to __getstate__ do not have this key
        self.__init__(remote_host, state['username'], state['password'],
                      machine_name=state.get('machine_name', ''), pkey=state['pkey'])

    def execute(self, command, working_dir=None):
        '''
        Executes command on the remote machine (after cd to working_dir if it is given).
        Returns a pair (list of stdout lines, list of stderr lines).
        If sending the command fails with a network-related exception, the connection is
        re-established and the command is sent again.
        '''
        if self.ssh_client is None:
            raise Exception('Remote host is not set')

        command_line = command if working_dir is None else 'cd {}; {}'.format(working_dir, command)
        while True:
            try:
                stdin, stdout, stderr = self.ssh_client.exec_command(command_line)
            except (OSError, socket.timeout, socket.error, paramiko.sftp.SFTPError) as e:
                print('\t\tMSG: Catched {} exception while executing "{}"'.format(type(e).__name__, command_line))
                print('\t\tMSG: It says: {}'.format(e))
                print('\t\tMSG: Reboot SSH client')
                self.reboot()
            else:
                # reading is done outside of the guarded call: its errors are not retried
                return stdout.readlines(), stderr.readlines()

    def copy(self, from_, to_, mode='from_local', show_msg=False):
        '''
        Copies from_ to to_ according to mode (from_local, from_remote or all_remote,
        see BaseCommunication.copy). Returns the path of the copy for modes from_local and
        from_remote.

        Raises Exception if mode is not one of the three modes.
        '''
        if self.ssh_client is None:
            raise Exception('Remote host is not set')
        self._init_sftp()

        new_path = None
        if mode == 'from_local':
            new_path = self._copy_from_local(from_, to_, show_msg)
        elif mode == 'from_remote':
            new_path = self._copy_from_remote(from_, to_, show_msg)
        elif mode == 'all_remote':
            if show_msg:
                self._print_copy_msg(self.machine_name + ':' + from_, self.machine_name + ':' + to_)
            self._mkdirp(to_)
            self.execute('cp -r %s %s' % (from_, to_))
            # NOTE(review): None is returned in this mode, unlike in the two other modes
        else:
            raise Exception("Incorrect mode '%s'" % mode)
        return new_path

    def rm(self, target):
        '''
        Removes target (dir or file) on the remote machine by means of "rm -r"
        '''
        if self.ssh_client is None:
            raise Exception('Remote host is not set')
        self._init_sftp()
        self.execute('rm -r %s' % target)

    @enable_sftp
    def mkdir(self, path):
        self.sftp_client.mkdir(path)

    @enable_sftp
    def listdir(self, path_on_remote):
        return self.sftp_client.listdir(path_on_remote)

    @enable_sftp
    def _chdir(self, path=None):
        self.sftp_client.chdir(path)

    def _mkdirp(self, path):
        '''
        Creates directory path on the remote machine together with all missing parent directories
        '''
        path_list = path.split('/')
        cur_dir = ''
        # path is absolute or relative to user's home dir => the first component need not be checked
        if path_list[0] in ('', '~'):
            cur_dir = path_list.pop(0) + '/'
        # once a non-existing dir is met, all the deeper ones do not exist either => skip stat() calls
        start_creating = False
        for dir_ in path_list:
            if dir_ == '': # trailing slash or double slash, can skip
                continue
            cur_dir += dir_
            if start_creating or not self._is_remote_dir(cur_dir):
                self.mkdir(cur_dir)
                start_creating = True
            cur_dir += '/'

    @enable_sftp
    def _open(self, filename, mode='r'):
        return self.sftp_client.open(filename, mode)

    @enable_sftp
    def _get(self, remote_path, local_path):
        '''
        Downloads remote_path to local_path. On a network-related exception the partially
        downloaded file is removed, the connection is re-established and the download is retried.

        Raises FileNotFoundError (after removing local_path if it exists) if the file cannot be found.
        '''
        while True:
            try:
                res = self.sftp_client.get(remote_path, local_path)
            except FileNotFoundError:
                logging.error('Cannot find file or directory "{}" => interrupt downloading'.format(remote_path))
                if os.path.exists(local_path):
                    aux.rm(local_path)
                raise
            except (socket.timeout, socket.error, paramiko.sftp.SFTPError) as e:
                print('\t\tMSG: Catched {} exception while getting "{}"'.format(type(e).__name__, remote_path))
                print('\t\tMSG: It says: {}'.format(e))
                print('\t\tMSG: Reboot SSH client')
                self.reboot()
                if os.path.exists(local_path):
                    aux.rm(local_path)
            else:
                return res

    @enable_sftp
    def _put(self, local_path, remote_path):
        '''
        Uploads local_path to remote_path. On a network-related exception the connection is
        re-established, remote_path is removed and the upload is retried.

        Raises FileNotFoundError (after removing remote_path) if the file cannot be found.
        '''
        while True:
            try:
                res = self.sftp_client.put(local_path, remote_path)
            except FileNotFoundError:
                logging.error('Cannot find file or directory "{}" => interrupt uploading'.format(local_path))
                self.rm(remote_path)
                raise
            except (socket.timeout, socket.error, paramiko.sftp.SFTPError) as e:
                print('\t\tMSG: Catched {} exception while putting "{}"'.format(type(e).__name__, remote_path))
                print('\t\tMSG: It says: {}'.format(e))
                print('\t\tMSG: Reboot SSH client')
                self.reboot()
                self.rm(remote_path)
            else:
                return res

    @enable_sftp
    def _is_remote_dir(self, path):
        '''
        Returns True if path is an existing directory on the remote machine.
        False is returned both for files and for paths which cannot be stat()-ed.
        '''
        try:
            return S_ISDIR(self.sftp_client.stat(path).st_mode)
        except IOError:
            return False

    def _copy_from_local(self, from_, to_, show_msg=False):
        '''
        Recursively uploads local file or dir from_ into remote dir to_.
        Returns the path of the copy on the remote machine.

        Raises CommunicationError if from_ does not exist.
        '''
        new_path_on_remote = to_ + '/' + os.path.basename(from_)
        if os.path.isfile(from_):
            self._mkdirp(to_)
            if show_msg:
                self._print_copy_msg(from_, self.machine_name + ':' + to_)
            self._put(from_, new_path_on_remote)
        elif os.path.isdir(from_):
            # make sure the parent dir exists, exactly as it is done for files
            self._mkdirp(to_)
            self.mkdir(new_path_on_remote)
            for dir_or_file in os.listdir(from_):
                self._copy_from_local(os.path.join(from_, dir_or_file), new_path_on_remote, show_msg)
        else:
            raise CommunicationError("Path %s does not exist" % from_)
        return new_path_on_remote

    def _copy_from_remote(self, from_, to_, show_msg=False):
        '''
        Recursively downloads remote file or dir from_ into local dir to_.
        Returns the path of the copy on the local machine.
        '''
        new_path_on_local = os.path.join(to_, os.path.basename(from_))
        if not self._is_remote_dir(from_):
            if show_msg:
                self._print_copy_msg(self.machine_name + ':' + from_, to_)
            self._get(from_, new_path_on_local)
        else:
            os.mkdir(new_path_on_local)
            for dir_or_file in self.listdir(from_):
                self._copy_from_remote(from_ + '/' + dir_or_file, new_path_on_local, show_msg)
        return new_path_on_local

    def disconnect(self):
        '''
        Closes the sftp client (if it was opened) and the ssh client
        '''
        if self.sftp_client is not None:
            self.sftp_client.close()
            self.sftp_client = None
        self.ssh_client.close()

    def connect(self):
        '''
        Connects the ssh client to the remote host, retrying for as long as the attempts time out.
        ProxyCommand for the host, if any, is taken from ~/.ssh/config.
        '''
        self.ssh_client.load_system_host_keys()
        # NOTE: with AutoAddPolicy keys of unknown hosts are accepted without verification
        self.ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

        # read ssh config. We assume that all necessary re-routing is done there via ProxyCommand
        # only ProxyCommand is read; password should be passed explicitly to SshCommunication
        ssh_config = paramiko.SSHConfig()
        user_config_file = os.path.expanduser("~/.ssh/config")
        if os.path.exists(user_config_file):
            with open(user_config_file) as f:
                ssh_config.parse(f)

        user_config = ssh_config.lookup(self.host.ssh_host)
        sock = None
        if 'proxycommand' in user_config:
            sock = paramiko.ProxyCommand(user_config['proxycommand'])

        if self.pkey is not None: # if a private key is given, connect using it
            auth_kwds = {'key_filename': self.pkey}
        else: # otherwise connect via the password only
            auth_kwds = {'password': self.password, 'look_for_keys': False, 'allow_agent': False}

        while True:
            try:
                self.ssh_client.connect(self.host.ssh_host, username=self.username, timeout=10, sock=sock, **auth_kwds)
            except socket.timeout as e:
                print('\t\tMSG: Catched {} exception while connecting'.format(type(e).__name__))
                print('\t\tMSG: It says: {}'.format(e))
            else:
                break

        # Raise the re-keying thresholds (2^40 bytes / 2^40 packets). This is a security degradation,
        # but otherwise we get
        # "paramiko.ssh_exception.SSHException: Key-exchange timed out waiting for key negotiation"
        transport = self.ssh_client.get_transport()
        transport.packetizer.REKEY_BYTES = pow(2, 40)
        transport.packetizer.REKEY_PACKETS = pow(2, 40)

    def reboot(self):
        '''
        Re-establishes the ssh connection and reopens the sftp client
        '''
        self.disconnect()
        self.connect()
        self._init_sftp()

    def _init_sftp(self):
        '''
        Opens the sftp client with a 10 second channel timeout unless it is already opened
        '''
        if self.sftp_client is None:
            self.sftp_client = self.ssh_client.open_sftp()
            self.sftp_client.get_channel().settimeout(10)


class CommunicationError(Exception):
    '''
    Raised when a communication operation cannot be carried out,
    e.g., when a local path to be copied does not exist.
    '''


def _add_programs_and_commands_from_config(host, hostconf):
    '''
    Registers programs and commands listed in hostconf (a host section of the config) in host.

    The following optional keys of hostconf are read:
        custom_programs: dict mapping a directory to the list of programs located in it
        env_programs: list of programs launched by their bare names
        custom_commands: dict mapping a command name to the command
    '''
    for path, programs in hostconf.get('custom_programs', {}).items():
        for program in programs:
            host.add_program(program, path)
    for program in hostconf.get('env_programs', ()):
        host.add_program(program)
    for cmd_name, cmd in hostconf.get('custom_commands', {}).items():
        host.add_command(cmd_name, cmd)