communication.py 19.2 KB
Newer Older
1 2 3 4 5 6
import functools
import json
import logging
import os
import os.path
import shlex
import shutil
import socket
import subprocess
from abc import ABCMeta, abstractmethod
from stat import S_ISDIR

import paramiko

import comsdk.comaux as aux
15

16

17 18 19 20 21 22
class Host(object):
    '''
    Class storing all necessary information about the host of execution.

    Attributes:
        programs: dict mapping a program name to the directory it is launched
                  from, or to None if the program was registered without a path
        commands: dict mapping a command name to the registered command
    '''
    def __init__(self):
        self.programs = {}
        self.commands = {}

    def add_program(self, prog_name,
                    path_to_prog=None,
                    ):
        '''
        Registers program prog_name. path_to_prog is the directory containing
        the program; leave it None if the program should be launched by its
        bare name.
        '''
        self.programs[prog_name] = path_to_prog

    def add_command(self, cmd_name, cmd):
        '''
        Registers command cmd under the name cmd_name.
        '''
        self.commands[cmd_name] = cmd

    def get_program_launch_path(self, prog_name):
        '''
        Returns the string used to launch prog_name: the registered directory
        joined with the program name, or the bare program name if no directory
        was registered.

        Raises ValueError if prog_name has never been registered.
        '''
        if prog_name not in self.programs:
            raise ValueError(f'Program "{prog_name}" is not recognized. '
                             'Please add this program to "custom_programs" '
                             'in the corresponding host in the config file '
                             'if you want to use it.')
        path_to_prog = self.programs[prog_name]
        if path_to_prog is None:
            return prog_name
        return self.join_path(path_to_prog, prog_name)

    def join_path(self, *path_list):
        '''
        Joins path components according to the rules of the local OS.
        Subclasses override this for hosts with another path convention.
        '''
        return os.path.join(*path_list)

49

50 51 52 53 54 55 56 57
class RemoteHost(Host):
    '''
    RemoteHost extends Host including information about ssh host and the number of cores.

    job_setter and job_finished_checker are passed to
    aux.load_function_from_module; the loaded callables are exposed as
    set_job_id and check_task_finished. The original specifications are kept
    so that the host can be pickled: the loaded callables are not part of the
    pickled state and are re-loaded on unpickling.
    '''
    def __init__(self, ssh_host, cores, sge_template_name, job_setter, job_finished_checker):
        super().__init__()
        self.ssh_host = ssh_host
        self.cores = cores
        self.sge_template_name = sge_template_name
        self._load_job_functions(job_setter, job_finished_checker)

    def _load_job_functions(self, job_setter, job_finished_checker):
        # Keep the specifications next to the loaded callables: __getstate__
        # serializes the former, __setstate__ rebuilds the latter
        self._job_setter = job_setter
        self._job_finished_checker = job_finished_checker
        self.set_job_id = aux.load_function_from_module(job_setter)
        self.check_task_finished = aux.load_function_from_module(job_finished_checker)

    def __getstate__(self):
        return {
            'ssh_host': self.ssh_host,
            'cores': self.cores,
            'programs': self.programs,
            'commands': self.commands,
            'sge_template_name': self.sge_template_name,
            'job_setter': self._job_setter,
            'job_finished_checker': self._job_finished_checker,
        }

    def __setstate__(self, state):
        self.ssh_host = state['ssh_host']
        self.cores = state['cores']
        self.programs = state['programs']
        # States produced before 'commands' was serialized do not have this key
        self.commands = state.get('commands', {})
        self.sge_template_name = state['sge_template_name']
        # Also restores _job_setter/_job_finished_checker, which __getstate__
        # needs if the restored host is pickled again
        self._load_job_functions(state['job_setter'], state['job_finished_checker'])

    def join_path(self, *path_list):
        # For RemoteHost, we assume that it is posix-based
        return '/'.join(path_list)

86

87 88 89 90 91 92 93
# Decorator
def enable_sftp(func):
    '''
    Method decorator: calls self._init_sftp() before every call of the
    decorated method and then returns whatever the method returns.
    '''
    @functools.wraps(func)  # keep the name and docstring of the decorated method
    def wrapped_func(self, *args, **kwds):
        self._init_sftp()
        return func(self, *args, **kwds)
    return wrapped_func

94

95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120
class BaseCommunication(metaclass=ABCMeta):
    '''
    Abstract interface for the simplest access to a machine: running a command
    line on it and copying/removing files on it. A concrete subclass picks the
    actual transport (e.g., OS API or ssh).

    Terminology: the machine being accessed is called the communicated machine
    (it may well be the local machine itself), whereas the machine running
    this code is the local machine.

    Files can be exchanged either between the local machine and the
    communicated machine or within the communicated machine. Only copying
    depends on this distinction, so copy() takes a mode: from_local,
    from_remote or all_remote.
    '''

    def __init__(self, host, machine_name):
        self.machine_name = machine_name
        self.host = host

    @abstractmethod
    def execute(self, command, working_dir=None):
        '''
        Executes command line command on the communicated machine, in
        working_dir if it is given
        '''

    @abstractmethod
    def copy(self, from_, to_, mode='from_local', show_msg=False):
        '''
        Copies from_ to to_. The meaning of both paths depends on mode:
        (1) from_local (default) -> from_ is a local path, to_ is a path on the communicated machine
        (2) from_remote -> from_ is a path on the communicated machine, to_ is a local path
        (3) all_remote -> both from_ and to_ are paths on the communicated machine

        Supported combinations of from_ and to_:
        (1) dir to dir
        (2) file to dir
        (3) file to file
        '''

    @abstractmethod
    def rm(self, target):
        '''
        Removes target, be it a dir or a file
        '''

    def execute_program(self, prog_name, args_str, working_dir=None, chaining_command_at_start='',
                        chaining_command_at_end=''):
        '''
        Executes program prog_name, registered on the host, with arguments
        args_str. The chaining commands are put verbatim before and after the
        program call. Returns what execute() returns.
        '''
        launch_path = self.host.get_program_launch_path(prog_name)
        return self.execute(
            f'{chaining_command_at_start} {launch_path} {args_str} {chaining_command_at_end}',
            working_dir,
        )

    def _print_copy_msg(self, from_, to_):
        msg = '\tCopying %s to %s' % (from_, to_)
        print(msg)

    def _print_exec_msg(self, cmd, is_remote):
        # The machine name is shown only for remote execution
        where = ''
        if is_remote:
            where = '@' + self.machine_name
        msg = '\tExecuting %s: %s' % (where, cmd)
        print(msg)

155

156 157 158 159 160 161 162 163 164
class LocalCommunication(BaseCommunication):
    '''
    Communication with the local machine itself: commands are run through the
    local shell and files are handled by local filesystem helpers.
    '''
    def __init__(self, local_host, machine_name='laptop'):
        super().__init__(local_host, machine_name)

    @classmethod
    def create_from_config(cls):
        '''
        Builds the communication from section LOCAL_HOST of
        config_research.json located in the current working directory.
        '''
        with open('config_research.json', 'r') as f:
            conf = json.load(f)
        local_host = Host()
        _add_programs_and_commands_from_config(local_host, conf['LOCAL_HOST'])
        return cls(local_host)

    def execute(self, command, working_dir=None):
        '''
        Runs command in the local shell, in working_dir if it is given.
        The output of the command is not captured (it goes to the console),
        so a pair of empty lists is always returned.
        '''
        if working_dir is None:
            command_line = command
        elif os.name == 'nt':
            # /d switches the drive too, so dirs on any drive as well as
            # relative dirs are handled uniformly
            command_line = 'cd /d {} && {}'.format(working_dir, command)
        else:
            command_line = 'cd {}; {}'.format(working_dir, command)
        subprocess.run(command_line, shell=True)
        return [], []

    def copy(self, from_, to_, mode='from_local', show_msg=False):
        '''
        Any mode is ignored since the copying shall be within a local machine anyway
        '''
        if show_msg:
            self._print_copy_msg(from_, to_)
        # NOTE(review): the bare name cp is not defined in this module;
        # assuming the helper lives in comsdk.comaux next to rm -- confirm
        return aux.cp(from_, to_)

    def rm(self, target):
        aux.rm(target)
195

196

197
class SshCommunication(BaseCommunication):
    '''
    Communication with a remote machine based on paramiko: command lines are
    executed through an SSH client and files are transferred through an SFTP
    client opened lazily on top of it.
    '''
    def __init__(self, remote_host, username, password, machine_name='', pkey=None, execute_after_connection=None):
        '''
        Stores the connection parameters and connects immediately.

        remote_host must be a RemoteHost (TypeError otherwise). pkey, if given,
        is passed to paramiko as a key filename and used instead of password.
        execute_after_connection, if given, is a command line executed right
        after each connection and prepended to every executed command line.
        '''
        if not isinstance(remote_host, RemoteHost):
            raise TypeError('Only RemoteHost can be used to build SshCommunication')
        self.host = remote_host
        self.username = username
        self.password = password
        self.pkey = pkey
        self.execute_after_connection = execute_after_connection
        self.ssh_client = paramiko.SSHClient()
        self.sftp_client = None
        super().__init__(self.host, machine_name)
        self.connect()
        paramiko.util.log_to_file('paramiko.log')

    @classmethod
    def create_from_config(cls, host_sid):
        '''
        Builds the communication from entry host_sid of section REMOTE_HOSTS
        of config_research.json located in the current working directory.
        Keys password, pkey and execute_after_connection are optional.
        '''
        with open('config_research.json', 'r') as f:
            conf = json.load(f)
        hostconf = conf['REMOTE_HOSTS'][host_sid]
        remote_host = RemoteHost(ssh_host=hostconf['ssh_host'],
                                 cores=hostconf['max_cores'],
                                 sge_template_name=hostconf['sge_template_name'],
                                 job_setter=hostconf['job_setter'],
                                 job_finished_checker=hostconf['job_finished_checker'])
        _add_programs_and_commands_from_config(remote_host, hostconf)
        return cls(remote_host, username=hostconf['username'],
                   password=hostconf.get('password'),
                   machine_name=host_sid,
                   pkey=hostconf.get('pkey'),
                   execute_after_connection=hostconf.get('execute_after_connection'))

    def __getstate__(self):
        # NOTE: the password ends up in the pickled state as plain text
        return {
            'host': self.host.__getstate__(),
            'username': self.username,
            'password': self.password,
            'machine_name': self.machine_name,
            'pkey': self.pkey,
            'execute_after_connection': self.execute_after_connection,
        }

    def __setstate__(self, state):
        # Clients cannot be pickled, so the object is re-initialized (and thus
        # re-connected) from the stored parameters
        remote_host = RemoteHost.__new__(RemoteHost)
        remote_host.__setstate__(state['host'])
        self.__init__(remote_host, state['username'], state['password'],
                      machine_name=state.get('machine_name', ''),  # absent in states pickled by older versions
                      pkey=state['pkey'],
                      execute_after_connection=state['execute_after_connection'])

    def execute(self, command, working_dir=None):
        '''
        Executes command on the remote machine, in working_dir if it is given.
        If the command cannot be sent because of a connection problem, the
        connection is rebooted and the command is re-sent until it gets through.

        Output lines are printed as they arrive. Returns a pair
        (stdout lines, stderr lines).
        '''
        if self.ssh_client is None:
            raise Exception('Remote host is not set')

        self._print_exec_msg(command, is_remote=True)
        command_line = command if working_dir is None else 'cd {}; {}'.format(working_dir, command)
        if self.execute_after_connection is not None:
            command_line = f'{self.execute_after_connection}; {command_line}'
        print(command_line)

        while True:
            try:
                stdin, stdout, stderr = self.ssh_client.exec_command(command_line)
            except (OSError, socket.timeout, socket.error, paramiko.sftp.SFTPError) as e:
                print('\t\tMSG: Catched {} exception while executing "{}"'.format(type(e).__name__, command_line))
                print('\t\tMSG: It says: {}'.format(e))
                print('\t\tMSG: Reboot SSH client')
                self.reboot()
            else:
                break

        # The streams can be read only once, so the lines are collected while
        # being printed instead of being re-read afterwards
        out_lines = []
        for line in stdout:
            print('\t\t' + line.strip('\n'))
            out_lines.append(line)
        err_lines = []
        for line in stderr:
            print('\t\t' + line.strip('\n'))
            err_lines.append(line)
        return out_lines, err_lines

    def copy(self, from_, to_, mode='from_local', show_msg=False):
        '''
        Copies from_ to to_ according to mode (from_local, from_remote or
        all_remote). Returns the path of the copy for modes from_local and
        from_remote and None for mode all_remote. Raises Exception if mode is
        not one of these three.
        '''
        if self.ssh_client is None:
            raise Exception('Remote host is not set')
        self._init_sftp()

        new_path = None
        if mode == 'from_local':
            new_path = self._copy_from_local(from_, to_, show_msg)
        elif mode == 'from_remote':
            new_path = self._copy_from_remote(from_, to_, show_msg)
        elif mode == 'all_remote':
            if show_msg:
                self._print_copy_msg(self.machine_name + ':' + from_, self.machine_name + ':' + to_)
            self._mkdirp(to_)
            self.execute('cp -r %s %s' % (from_, to_))
        else:
            raise Exception("Incorrect mode '%s'" % mode)
        return new_path

    def rm(self, target):
        '''
        Removes remote target (dir or file) by executing rm -r
        '''
        if self.ssh_client is None:
            raise Exception('Remote host is not set')
        self._init_sftp()
        self.execute('rm -r %s' % target)

    @enable_sftp
    def mkdir(self, path):
        self.sftp_client.mkdir(path)

    @enable_sftp
    def listdir(self, path_on_remote):
        return self.sftp_client.listdir(path_on_remote)

    @enable_sftp
    def _chdir(self, path=None):
        self.sftp_client.chdir(path)

    def _mkdirp(self, path):
        '''
        Creates remote dir path together with all its missing parent dirs
        '''
        path_list = path.split('/')
        cur_dir = ''
        if (path_list[0] == '') or (path_list[0] == '~'): # path is absolute or relative to user's home dir => don't need to check obvious
            cur_dir = path_list.pop(0) + '/'
        start_creating = False # just to exclude unnecessary stat() calls when we catch non-existing dir
        for dir_ in path_list:
            if dir_ == '': # trailing slash or double slash, can skip
                continue
            cur_dir += dir_
            if start_creating or (not self._is_remote_dir(cur_dir)):
                self.mkdir(cur_dir)
                start_creating = True
            cur_dir += '/'

    @enable_sftp
    def _open(self, filename, mode='r'):
        return self.sftp_client.open(filename, mode)

    @enable_sftp
    def _get(self, remote_path, local_path):
        '''
        Downloads remote_path to local_path. On a connection problem, the
        connection is rebooted, the partially downloaded file is removed and
        the download is retried. If the file is not found, the partially
        downloaded file is removed and FileNotFoundError is re-raised.
        '''
        while True:
            try:
                res = self.sftp_client.get(remote_path, local_path)
            except FileNotFoundError:
                logging.error('Cannot find file or directory "{}" => interrupt downloading'.format(remote_path))
                if os.path.exists(local_path):
                    aux.rm(local_path)
                raise
            except (socket.timeout, socket.error, paramiko.sftp.SFTPError) as e:
                print('\t\tMSG: Catched {} exception while getting "{}"'.format(type(e).__name__, remote_path))
                print('\t\tMSG: It says: {}'.format(e))
                print('\t\tMSG: Reboot SSH client')
                self.reboot()
                if os.path.exists(local_path):
                    aux.rm(local_path)
            else:
                return res

    @enable_sftp
    def _put(self, local_path, remote_path):
        '''
        Uploads local_path to remote_path. On a connection problem, the
        connection is rebooted, the remote file is removed and the upload is
        retried. If the file is not found, the remote file is removed and
        FileNotFoundError is re-raised.
        '''
        while True:
            try:
                res = self.sftp_client.put(local_path, remote_path)
            except FileNotFoundError:
                logging.error('Cannot find file or directory "{}" => interrupt uploading'.format(local_path))
                self.rm(remote_path)
                raise
            except (socket.timeout, socket.error, paramiko.sftp.SFTPError) as e:
                print('\t\tMSG: Catched {} exception while putting "{}"'.format(type(e).__name__, remote_path))
                print('\t\tMSG: It says: {}'.format(e))
                print('\t\tMSG: Reboot SSH client')
                self.reboot()
                self.rm(remote_path)
            else:
                return res

    @enable_sftp
    def _is_remote_dir(self, path):
        # stat() raises IOError for a non-existing path, which is then
        # reported as "not a dir"
        try:
            return S_ISDIR(self.sftp_client.stat(path).st_mode)
        except IOError:
            return False

    def _copy_from_local(self, from_, to_, show_msg=False):
        '''
        Recursively uploads local file or dir from_ into remote dir to_.
        Returns the remote path of the copy. Raises CommunicationError if
        from_ is neither a file nor a dir.
        '''
        new_path_on_remote = to_ + '/' + os.path.basename(from_)
        if os.path.isfile(from_):
            self._mkdirp(to_)
            if show_msg:
                self._print_copy_msg(from_, self.machine_name + ':' + to_)
            self._put(from_, new_path_on_remote)
        elif os.path.isdir(from_):
            self.mkdir(new_path_on_remote)
            for dir_or_file in os.listdir(from_):
                self._copy_from_local(os.path.join(from_, dir_or_file), new_path_on_remote, show_msg)
        else:
            raise CommunicationError("Path %s does not exist" % from_)
        return new_path_on_remote

    def _copy_from_remote(self, from_, to_, show_msg=False):
        '''
        Recursively downloads remote file or dir from_ into local dir to_.
        Returns the local path of the copy.
        '''
        new_path_on_local = os.path.join(to_, os.path.basename(from_))
        if not self._is_remote_dir(from_):
            if show_msg:
                self._print_copy_msg(self.machine_name + ':' + from_, to_)
            self._get(from_, new_path_on_local)
        else:
            os.mkdir(new_path_on_local)
            for dir_or_file in self.listdir(from_):
                self._copy_from_remote(from_ + '/' + dir_or_file, new_path_on_local, show_msg)
        return new_path_on_local

    def disconnect(self):
        if self.sftp_client is not None:
            self.sftp_client.close()
            self.sftp_client = None
        self.ssh_client.close()

    def connect(self):
        '''
        Connects the SSH client to the remote host, retrying for as long as
        the attempts time out, and then runs execute_after_connection if it
        is set.
        '''
        self.ssh_client.load_system_host_keys()
        self.ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

        # read ssh config. We assume that all necessary re-routing are done there via ProxyCommand
        # only ProxyCommand is read; password should be passed explicitly to SshCommunication
        ssh_config = paramiko.SSHConfig()
        user_config_file = os.path.expanduser("~/.ssh/config")
        if os.path.exists(user_config_file):
            with open(user_config_file) as f:
                ssh_config.parse(f)
        proxy_command = ssh_config.lookup(self.host.ssh_host).get('proxycommand')

        while True:
            # A fresh proxy is made for every attempt so that a retry does not
            # reuse the socket of the attempt that has just failed
            sock = paramiko.ProxyCommand(proxy_command) if proxy_command is not None else None
            try:
                if self.pkey is not None: # if a private key is given, first attempt to connect using it
                    self.ssh_client.connect(self.host.ssh_host, username=self.username, key_filename=self.pkey, timeout=10, sock=sock)
                else: # otherwise connect via password
                    print(self.host.ssh_host, self.username)
                    self.ssh_client.connect(self.host.ssh_host, username=self.username, password=self.password, look_for_keys=False, allow_agent=False, timeout=10, sock=sock)
            except socket.timeout as e:
                print('\t\tMSG: Catched {} exception while connecting'.format(type(e).__name__))
                print('\t\tMSG: It says: {}'.format(e))
            else:
                break

        # Raising the re-keying thresholds is a security degradation (otherwise we get
        # "paramiko.ssh_exception.SSHException: Key-exchange timed out waiting for key negotiation")
        transport = self.ssh_client.get_transport()
        transport.packetizer.REKEY_BYTES = pow(2, 40) # 1TB max
        transport.packetizer.REKEY_PACKETS = pow(2, 40)

        if self.execute_after_connection is not None:
            self.execute(self.execute_after_connection)

    def reboot(self):
        '''
        Drops the current connection and establishes a new one (SSH and SFTP)
        '''
        self.disconnect()
        self.connect()
        self._init_sftp()

    def _init_sftp(self):
        # The SFTP client is opened on first demand only
        if self.sftp_client is None:
            self.sftp_client = self.ssh_client.open_sftp()
            self.sftp_client.get_channel().settimeout(10)

470 471 472 473

class CommunicationError(Exception):
    '''
    Error raised by communications, e.g. when a local path to be copied does not exist.
    '''

474 475

def _add_programs_and_commands_from_config(host, hostconf):
    '''
    Registers on host all programs and commands listed in host config section
    hostconf. Three optional keys are read:
    (1) custom_programs -> dict mapping a path to a list of program names located there
    (2) env_programs -> list of program names registered without any path
    (3) custom_commands -> dict mapping a command name to a command
    '''
    programs_by_location = hostconf.get('custom_programs', {})
    for location, prog_names in programs_by_location.items():
        for prog_name in prog_names:
            host.add_program(prog_name, location)

    for prog_name in hostconf.get('env_programs', ()):
        host.add_program(prog_name)

    named_commands = hostconf.get('custom_commands', {})
    for cmd_name, cmd in named_commands.items():
        host.add_command(cmd_name, cmd)
487