# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
"""Plugin for transport over SSH (and SFTP for file transfer)."""
# pylint: disable=too-many-lines
import glob
import io
import os
import re
from stat import S_ISDIR, S_ISREG
import click
import paramiko
from aiida.cmdline.params import options
from aiida.cmdline.params.types.path import AbsolutePathOrEmptyParamType
from aiida.common.escaping import escape_for_bash
from ..transport import Transport, TransportInternalError
__all__ = ('parse_sshconfig', 'convert_to_bool', 'SshTransport')
def parse_sshconfig(computername):
    """
    Look up the SSH client configuration that applies to a computer.

    The ``~/.ssh/config`` file of the current user is parsed with paramiko; if
    the file cannot be opened, the lookup is done on an empty configuration.

    :param computername: the computer name for which we want the configuration.
    :return: the result of ``paramiko.SSHConfig.lookup`` for ``computername``.
    """
    ssh_config = paramiko.SSHConfig()
    config_path = os.path.expanduser('~/.ssh/config')

    try:
        with open(config_path, encoding='utf8') as handle:
            ssh_config.parse(handle)
    except IOError:
        # The file could not be opened/read: keep the (empty) configuration.
        pass

    return ssh_config.lookup(computername)
def convert_to_bool(string):
    """
    Translate a textual yes/no value (as passed in the CLI) into a bool.

    The comparison is case-insensitive; the input is first converted with ``str``.

    :param string: the value to parse; one of Y/YES/T/TRUE or N/NO/F/FALSE.
    :return: the parsed bool value.
    :raise ValueError: If the value is not parsable as a bool
    """
    known_values = {
        'Y': True,
        'YES': True,
        'T': True,
        'TRUE': True,
        'N': False,
        'NO': False,
        'F': False,
        'FALSE': False,
    }
    outcome = known_values.get(str(string).upper())
    if outcome is None:
        raise ValueError('Invalid boolean value provided')
    return outcome
[docs]class SshTransport(Transport): # pylint: disable=too-many-public-methods
"""
Support connection, command execution and data transfer to remote computers via SSH+SFTP.
"""
# Keywords accepted by the connect method of paramiko.SSHClient, given as a list
# of ``(name, spec)`` tuples where ``spec`` holds the prompt, help text, type and
# default of the option (presumably consumed by the CLI that configures the
# transport - TODO confirm).
# 'password' and 'pkey' are deliberately left out, so that these data cannot get
# logged in the aiida log file.
_valid_connect_options = [
(
'username', {
'prompt': 'User name',
'help': 'Login user name on the remote machine.',
'non_interactive_default': True
}
),
(
'port',
{
'option': options.PORT,
'prompt': 'Port number',
'non_interactive_default': True,
},
),
(
'look_for_keys', {
'default': True,
'switch': True,
'prompt': 'Look for keys',
'help': 'Automatically look for private keys in the ~/.ssh folder.',
'non_interactive_default': True
}
),
(
'key_filename', {
'type': AbsolutePathOrEmptyParamType(dir_okay=False, exists=True),
'prompt': 'SSH key file',
'help': 'Absolute path to your private SSH key. Leave empty to use the path set in the SSH config.',
'non_interactive_default': True
}
),
(
'timeout', {
'type': int,
'prompt': 'Connection timeout in s',
'help': 'Time in seconds to wait for connection before giving up. Leave empty to use default value.',
'non_interactive_default': True
}
),
(
'allow_agent', {
'default': False,
'switch': True,
'prompt': 'Allow ssh agent',
'help': 'Switch to allow or disallow using an SSH agent.',
'non_interactive_default': True
}
),
(
'proxy_jump', {
'prompt':
'SSH proxy jump',
'help':
'SSH proxy jump for tunneling through other SSH hosts.'
' Use a comma-separated list of hosts of the form [user@]host[:port].'
' If user or port are not specified for a host, the user & port values from the target host are used.'
' This option must be provided explicitly and is not parsed from the SSH config file when left empty.',
'non_interactive_default':
True
}
), # Not passed to paramiko's connect: handled 'manually' when opening the connection
(
'proxy_command', {
'prompt':
'SSH proxy command',
'help':
'SSH proxy command for tunneling through a proxy server.'
' For tunneling through another SSH host, consider using the "SSH proxy jump" option instead!'
' Leave empty to parse the proxy command from the SSH config file.',
'non_interactive_default':
True
}
), # Not passed to paramiko's connect: handled 'manually' when opening the connection
(
'compress', {
'default': True,
'switch': True,
'prompt': 'Compress file transfers',
'help': 'Turn file transfer compression on or off.',
'non_interactive_default': True
}
),
(
'gss_auth', {
'default': False,
'type': bool,
'prompt': 'GSS auth',
'help': 'Enable when using GSS kerberos token to connect.',
'non_interactive_default': True
}
),
(
'gss_kex', {
'default': False,
'type': bool,
'prompt': 'GSS kex',
'help': 'GSS kex for kerberos, if not configured in SSH config file.',
'non_interactive_default': True
}
),
(
'gss_deleg_creds', {
'default': False,
'type': bool,
'prompt': 'GSS deleg_creds',
'help': 'GSS deleg_creds for kerberos, if not configured in SSH config file.',
'non_interactive_default': True
}
),
(
'gss_host', {
'prompt': 'GSS host',
'help': 'GSS host for kerberos, if not configured in SSH config file.',
'non_interactive_default': True
}
),
# The gss_* options above are for Kerberos support through python-gssapi
]
# Only the option names, in the same order as the list above.
_valid_connect_params = [i[0] for i in _valid_connect_options]
# Valid parameters for the ssh transport: all the connect options plus the
# options that configure the handling of host keys.
# For each param, a class method with name
# _convert_PARAMNAME_fromstring
# should be defined, that returns the value converted from a string to
# a correct type, or raise a ValidationError
# NOTE(review): no ``_convert_*_fromstring`` method is visible in this part of
# the file - confirm that this requirement is still current.
#
# moreover, if you want to help in the default configuration, you can
# define a _get_PARAMNAME_suggestion_string
# to return a suggestion; it must accept only one parameter, being a Computer
# instance
_valid_auth_options = _valid_connect_options + [
(
'load_system_host_keys', {
'default': True,
'switch': True,
'prompt': 'Load system host keys',
'help': 'Load system host keys from default SSH location.',
'non_interactive_default': True
}
),
(
'key_policy', {
'default': 'RejectPolicy',
'type': click.Choice(['RejectPolicy', 'WarningPolicy', 'AutoAddPolicy']),
'prompt': 'Key policy',
'help': 'SSH key policy if host is not known.',
'non_interactive_default': True
}
)
]
# Max size of the log message to print in _exec_command_internal.
# ``None`` means unlimited (the default); a subclass can set a limit to crop the
# message if too large commands are sent, clogging the outputs or logs.
_MAX_EXEC_COMMAND_LOG_SIZE = None
@classmethod
def _get_username_suggestion_string(cls, computer):
    """
    Suggest the login user name for the given computer.

    :param computer: object whose ``hostname`` is looked up in the SSH config.
    :return: the ``user`` entry of the SSH config for that host, or the name of
        the current local user if there is none.
    """
    import getpass

    ssh_options = parse_sshconfig(computer.hostname)
    # The local user name is the fallback when the SSH config does not define a user.
    fallback_user = getpass.getuser()
    return str(ssh_options.get('user', fallback_user))
@classmethod
def _get_port_suggestion_string(cls, computer):
    """
    Suggest the SSH port for the given computer.

    :param computer: object whose ``hostname`` is looked up in the SSH config.
    :return: the ``port`` entry of the SSH config for that host as a string,
        or ``'22'`` if there is none.
    """
    default_ssh_port = 22
    ssh_options = parse_sshconfig(computer.hostname)
    # Either the port configured in the .ssh/config, or the default SSH port.
    configured_port = ssh_options.get('port', default_ssh_port)
    return str(configured_port)
@classmethod
def _get_key_filename_suggestion_string(cls, computer):
    """
    Suggest the private key file for the given computer.

    :param computer: object whose ``hostname`` is looked up in the SSH config.
    :return: the (first) ``identityfile`` entry of the SSH config with ``~``
        expanded, or an empty string if no usable entry is defined.
    """
    ssh_options = parse_sshconfig(computer.hostname)

    try:
        identity_entry = ssh_options['identityfile']
    except KeyError:
        # No IdentityFile defined: return an empty string
        return ''

    if isinstance(identity_entry, str):
        chosen = identity_entry
    elif isinstance(identity_entry, (list, tuple)) and identity_entry:
        # The parser may provide a list of identity files: suggest only the first one.
        chosen = identity_entry[0]
    else:
        # Empty list or unknown type: behave as if no identityfile were defined.
        return ''

    return os.path.expanduser(chosen)
@classmethod
def _get_timeout_suggestion_string(cls, computer):
    """
    Suggest the connection timeout for the given computer.

    :param computer: object whose ``hostname`` is looked up in the SSH config.
    :return: the ``connecttimeout`` entry of the SSH config for that host as a
        string, or ``'60'`` (seconds) if there is none.
    """
    ssh_options = parse_sshconfig(computer.hostname)
    timeout = ssh_options.get('connecttimeout', '60')
    return str(timeout)
@classmethod
def _get_allow_agent_suggestion_string(cls, computer):
    """
    Suggest whether to allow the use of an SSH agent for the given computer.

    :param computer: object whose ``hostname`` is looked up in the SSH config.
    :return: the ``allow_agent`` entry of the SSH config parsed as a bool,
        or ``True`` if there is none.
    """
    ssh_options = parse_sshconfig(computer.hostname)
    raw_value = ssh_options.get('allow_agent', 'yes')
    return convert_to_bool(str(raw_value))
@classmethod
def _get_look_for_keys_suggestion_string(cls, computer):
    """
    Suggest whether to look for private keys for the given computer.

    :param computer: object whose ``hostname`` is looked up in the SSH config.
    :return: the ``look_for_keys`` entry of the SSH config parsed as a bool,
        or ``True`` if there is none.
    """
    ssh_options = parse_sshconfig(computer.hostname)
    raw_value = ssh_options.get('look_for_keys', 'yes')
    return convert_to_bool(str(raw_value))
@classmethod
def _get_proxy_command_suggestion_string(cls, computer):
    """
    Suggest the proxy command for the given computer.

    :param computer: object whose ``hostname`` is looked up in the SSH config.
    :return: the ``proxycommand`` entry of the SSH config for that host, cut
        at the first whitespace-separated piece containing ``>``; an empty
        string if no proxy command is configured.
    """
    ssh_options = parse_sshconfig(computer.hostname)
    # Either the proxy command configured in the .ssh/config, or an empty string.
    # Note: %h and %p get already automatically substituted with
    # hostname and port by the config parser!
    tokens = str(ssh_options.get('proxycommand', '')).split()

    # The first piece containing '>' redirects stderr or stdout: drop it and
    # everything after it (anything else can only be a redirection).
    cutoff = next((index for index, token in enumerate(tokens) if '>' in token), len(tokens))
    return ' '.join(tokens[:cutoff])
[docs] @classmethod
def _get_proxy_jump_suggestion_string(cls, _):
"""
Return an empty suggestion since Paramiko does not parse ProxyJump from the SSH config.
"""
return ''
[docs] @classmethod
def _get_compress_suggestion_string(cls, computer): # pylint: disable=unused-argument
"""
Return a suggestion for the specific field.
"""
return 'True'
[docs] @classmethod
def _get_load_system_host_keys_suggestion_string(cls, computer): # pylint: disable=unused-argument
"""
Return a suggestion for the specific field.
"""
return 'True'
[docs] @classmethod
def _get_key_policy_suggestion_string(cls, computer): # pylint: disable=unused-argument
"""
Return a suggestion for the specific field.
"""
return 'RejectPolicy'
@classmethod
def _get_gss_auth_suggestion_string(cls, computer):
    """
    Suggest whether to use GSS authentication for the given computer.

    :param computer: object whose ``hostname`` is looked up in the SSH config.
    :return: the ``gssapiauthentication`` entry of the SSH config parsed as a
        bool, or ``False`` if there is none.
    """
    ssh_options = parse_sshconfig(computer.hostname)
    raw_value = ssh_options.get('gssapiauthentication', 'no')
    return convert_to_bool(str(raw_value))
@classmethod
def _get_gss_kex_suggestion_string(cls, computer):
    """
    Suggest whether to use GSS key exchange for the given computer.

    :param computer: object whose ``hostname`` is looked up in the SSH config.
    :return: the ``gssapikeyexchange`` entry of the SSH config parsed as a
        bool, or ``False`` if there is none.
    """
    ssh_options = parse_sshconfig(computer.hostname)
    raw_value = ssh_options.get('gssapikeyexchange', 'no')
    return convert_to_bool(str(raw_value))
@classmethod
def _get_gss_deleg_creds_suggestion_string(cls, computer):
    """
    Suggest whether to delegate GSS credentials for the given computer.

    :param computer: object whose ``hostname`` is looked up in the SSH config.
    :return: the ``gssapidelegatecredentials`` entry of the SSH config parsed
        as a bool, or ``False`` if there is none.
    """
    ssh_options = parse_sshconfig(computer.hostname)
    raw_value = ssh_options.get('gssapidelegatecredentials', 'no')
    return convert_to_bool(str(raw_value))
@classmethod
def _get_gss_host_suggestion_string(cls, computer):
    """
    Suggest the GSS host name for the given computer.

    :param computer: object whose ``hostname`` is looked up in the SSH config.
    :return: the ``gssapihostname`` entry of the SSH config for that host as a
        string, or the hostname of the computer if there is none.
    """
    hostname = computer.hostname
    ssh_options = parse_sshconfig(hostname)
    return str(ssh_options.get('gssapihostname', hostname))
def __init__(self, *args, **kwargs):
    """
    Initialize the SshTransport class.

    No connection is opened here: the paramiko client is only created and configured.

    :param machine: the machine to connect to
    :param load_system_host_keys: (optional, default False)
        if False, do not load the system host keys
    :param key_policy: (optional, default = 'RejectPolicy', the paramiko default)
        the name of the policy to use for unknown keys; one of
        'RejectPolicy', 'WarningPolicy' and 'AutoAddPolicy'
    :raise ValueError: if ``key_policy`` is not one of the allowed names.

    Other keyword arguments whose name is listed in ``self._valid_connect_params``
    are stored and later used as connection arguments (as port, username, ...).
    """
    super().__init__(*args, **kwargs)

    self._sftp = None
    self._proxy = None
    self._proxies = []

    self._machine = kwargs.pop('machine')

    self._client = paramiko.SSHClient()
    self._load_system_host_keys = kwargs.pop('load_system_host_keys', False)
    if self._load_system_host_keys:
        self._client.load_system_host_keys()

    # 'RejectPolicy' is the paramiko default
    self._missing_key_policy = kwargs.pop('key_policy', 'RejectPolicy')
    known_policies = (
        ('RejectPolicy', paramiko.RejectPolicy),
        ('WarningPolicy', paramiko.WarningPolicy),
        ('AutoAddPolicy', paramiko.AutoAddPolicy),
    )
    for policy_name, policy_class in known_policies:
        if self._missing_key_policy == policy_name:
            self._client.set_missing_host_key_policy(policy_class())
            break
    else:
        raise ValueError(
            'Unknown value of the key policy, allowed values '
            'are: RejectPolicy, WarningPolicy, AutoAddPolicy'
        )

    # Keep only the connect parameters that were actually passed.
    self._connect_args = {name: kwargs.pop(name) for name in self._valid_connect_params if name in kwargs}
def open(self):  # pylint: disable=too-many-branches,too-many-statements
    """
    Open a SSHClient to the machine possibly using the parameters given
    in the __init__.

    If a proxy jump is configured, one SSH connection per jump host is opened and
    the connections are chained through forwarding channels; if a proxy command is
    configured, it is used as the socket of the connection.

    Also opens a sftp channel, ready to be used.
    The current working directory is set explicitly, so it is not None.

    If any step fails, the connections opened so far are closed again before the
    exception is propagated.

    :return: this transport instance.
    :raise aiida.common.InvalidOperation: if the channel is already open, or if
        the SFTP channel cannot be opened.
    :raise ValueError: if both the proxy jump and the proxy command are set, or
        if the proxy jump string cannot be parsed.
    """
    from paramiko.ssh_exception import SSHException

    from aiida.common.exceptions import InvalidOperation
    from aiida.transports.util import _DetachedProxyCommand

    if self._is_open:
        raise InvalidOperation('Cannot open the transport twice')

    # Open a SSHClient
    connection_arguments = self._connect_args.copy()
    if 'key_filename' in connection_arguments and not connection_arguments['key_filename']:
        connection_arguments.pop('key_filename')

    # The proxy options are not arguments of paramiko's connect: handle them here.
    proxyjumpstring = connection_arguments.pop('proxy_jump', None)
    proxycmdstring = connection_arguments.pop('proxy_command', None)

    if proxyjumpstring and proxycmdstring:
        raise ValueError('The SSH proxy jump and SSH proxy command options can not be used together')

    if proxyjumpstring:
        matcher = re.compile(r'^(?:(?P<username>[^@]+)@)?(?P<host>[^@:]+)(?::(?P<port>\d+))?\s*$')
        try:
            # don't use a generator here to have everything evaluated.
            # Each entry is stripped, so that whitespace around the commas does not end up in the host name.
            proxies = [matcher.match(s.strip()).groupdict() for s in proxyjumpstring.split(',')]
        except AttributeError as exc:
            raise ValueError('The given configuration for the SSH proxy jump option could not be parsed') from exc

        # proxy_jump supports a list of jump hosts, each jump host is another Paramiko SSH connection
        # but when opening a forward channel on a connection, we have to give the next hop.
        # So we go through adjacent pairs and by adding the final target to the list we make it universal.
        final_target = {
            'host': self._machine,
            'port': connection_arguments.get('port', 22),
        }
        for proxy, target in zip(proxies, proxies[1:] + [final_target]):
            proxy_connargs = connection_arguments.copy()

            if proxy['username']:
                proxy_connargs['username'] = proxy['username']
            if proxy['port']:
                proxy_connargs['port'] = int(proxy['port'])
            if not target['port']:  # the target port for the channel can not be None
                target['port'] = connection_arguments.get('port', 22)

            proxy_client = paramiko.SSHClient()
            # Register the client immediately, so that `_close_proxies` also releases it
            # when connecting or opening the forwarding channel fails.
            self._proxies.append(proxy_client)

            if self._load_system_host_keys:
                proxy_client.load_system_host_keys()
            if self._missing_key_policy == 'RejectPolicy':
                proxy_client.set_missing_host_key_policy(paramiko.RejectPolicy())
            elif self._missing_key_policy == 'WarningPolicy':
                proxy_client.set_missing_host_key_policy(paramiko.WarningPolicy())
            elif self._missing_key_policy == 'AutoAddPolicy':
                proxy_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

            try:
                proxy_client.connect(proxy['host'], **proxy_connargs)
                # A port parsed from the proxy jump string is a string: the channel needs an integer.
                connection_arguments['sock'] = proxy_client.get_transport().open_channel(
                    'direct-tcpip', (target['host'], int(target['port'])), ('', 0)
                )
            except Exception as exc:
                self.logger.error(
                    f"Error connecting to proxy '{proxy['host']}' through SSH: [{self.__class__.__name__}] {exc}, "
                    f'connect_args were: {proxy_connargs}'
                )
                self._close_proxies()  # close all since we're going to start anew on the next open() (if any)
                raise

    if proxycmdstring:
        self._proxy = _DetachedProxyCommand(proxycmdstring)
        connection_arguments['sock'] = self._proxy

    try:
        self._client.connect(self._machine, **connection_arguments)
    except Exception as exc:
        self.logger.error(
            f"Error connecting to '{self._machine}' through SSH: " + f'[{self.__class__.__name__}] {exc}, ' +
            f'connect_args were: {self._connect_args}'
        )
        self._close_proxies()
        raise

    # Open the SFTP channel, and handle error by directing customer to try another transport
    try:
        self._sftp = self._client.open_sftp()
    except SSHException as exc:
        # The SSH connection itself was established: close it, to not leave it dangling.
        self._client.close()
        self._close_proxies()
        raise InvalidOperation(
            'Error in ssh transport plugin. This may be due to the remote computer not supporting SFTP. '
            'Try setting it up with the aiida.transports:ssh_only transport from the aiida-sshonly plugin instead.'
        ) from exc

    self._is_open = True

    # Set the current directory to a explicit path, and not to None
    self._sftp.chdir(self._sftp.normalize('.'))

    return self
[docs] def _close_proxies(self):
"""Close all proxy connections (proxy_jump and proxy_command)"""
# Paramiko only closes the channel when closing the main connection, but not the connection itself.
while self._proxies:
self._proxies.pop().close()
if self._proxy:
# Paramiko should close this automatically when closing the channel,
# but since the process is started in __init__this might not happen correctly.
self._proxy.close()
self._proxy = None
def close(self):
    """
    Close the SFTP channel, and the SSHClient.

    The SSH client and the proxy connections are closed, and the transport is
    marked as closed, even if closing the SFTP channel or the client raises;
    the exception is propagated afterwards.

    :raise aiida.common.InvalidOperation: if the channel is already closed
    """
    from aiida.common.exceptions import InvalidOperation

    if not self._is_open:
        raise InvalidOperation('Cannot close the transport: it is already closed')

    try:
        self._sftp.close()
    finally:
        try:
            self._client.close()
        finally:
            # Always release the proxies and reset the state, to not leak connections.
            self._close_proxies()
            self._is_open = False
@property
def sshclient(self):
if not self._is_open:
raise TransportInternalError('Error, ssh method called for SshTransport without opening the channel first')
return self._client
@property
def sftp(self):
if not self._is_open:
raise TransportInternalError('Error, sftp method called for SshTransport without opening the channel first')
return self._sftp
[docs] def __str__(self):
"""
Return a useful string.
"""
conn_info = self._machine
try:
conn_info = f"{self._connect_args['username']}@{conn_info}"
except KeyError:
# No username explicitly defined: ignore
pass
try:
conn_info += f":{self._connect_args['port']}"
except KeyError:
# No port explicitly defined: ignore
pass
return f"{'OPEN' if self._is_open else 'CLOSED'} [{conn_info}]"
def chdir(self, path):
    """
    Change directory of the SFTP session. Emulated internally by paramiko.

    Differently from paramiko, if you pass None to chdir, nothing
    happens and the cwd is unchanged.

    After changing directory, the new directory is checked with ``stat``: if
    this fails with a 'Permission denied' error, the previous working directory
    is restored before the error is raised.

    :param path: the directory to move to, or None to stay in the current one.
    :raise IOError: if the directory cannot be changed to, or cannot be accessed.
    """
    from paramiko.sftp import SFTPError

    old_path = self.sftp.getcwd()
    if path is not None:
        try:
            self.sftp.chdir(path)
        except SFTPError as exc:
            # exc.args[0] is an error code (for instance, 20 is 'the object is not
            # a directory') and exc.args[1] the message; re-raise the message as
            # IOError. Fall back to the string of the exception if it carries no
            # separate message.
            message = exc.args[1] if len(exc.args) > 1 else str(exc)
            raise IOError(message) from exc

    # Paramiko already checked that path is a folder, otherwise I would
    # have gotten an exception. Now, I want to check that I have read
    # permissions in this folder (nothing is said on write permissions,
    # though).
    # Otherwise, if I do _exec_command_internal, that as a first operation
    # cd's in a folder, I get a wrong retval, that is an unwanted behavior.
    #
    # Note: I don't store the result of the function; if I have no
    # read permissions, this will raise an exception.
    try:
        self.stat('.')
    except IOError as exc:
        if 'Permission denied' in str(exc):
            # Restore the previous directory directly on the SFTP session: going
            # through `self.chdir` would repeat this check and could recurse
            # without end if the previous directory is not accessible either.
            self.sftp.chdir(old_path)
        raise IOError(str(exc)) from exc
[docs] def normalize(self, path='.'):
"""
Returns the normalized path (removing double slashes, etc...)
"""
return self.sftp.normalize(path)
[docs] def stat(self, path):
"""
Retrieve information about a file on the remote system. The return
value is an object whose attributes correspond to the attributes of
Python's ``stat`` structure as returned by ``os.stat``, except that it
contains fewer fields.
The fields supported are: ``st_mode``, ``st_size``, ``st_uid``,
``st_gid``, ``st_atime``, and ``st_mtime``.
:param str path: the filename to stat
:return: a `paramiko.sftp_attr.SFTPAttributes` object containing
attributes about the given file.
"""
return self.sftp.stat(path)
[docs] def lstat(self, path):
"""
Retrieve information about a file on the remote system, without
following symbolic links (shortcuts). This otherwise behaves exactly
the same as `stat`.
:param str path: the filename to stat
:return: a `paramiko.sftp_attr.SFTPAttributes` object containing
attributes about the given file.
"""
return self.sftp.lstat(path)
[docs] def getcwd(self):
"""
Return the current working directory for this SFTP session, as
emulated by paramiko. If no directory has been set with chdir,
this method will return None. But in __enter__ this is set explicitly,
so this should never happen within this class.
"""
return self.sftp.getcwd()
[docs] def makedirs(self, path, ignore_existing=False):
"""
Super-mkdir; create a leaf directory and all intermediate ones.
Works like mkdir, except that any intermediate path segment (not
just the rightmost) will be created if it does not exist.
NOTE: since os.path.split uses the separators as the host system
(that could be windows), I assume the remote computer is Linux-based
and use '/' as separators!
:param path: directory to create (string)
:param ignore_existing: if set to true, it doesn't give any error
if the leaf directory does already exist (bool)
:raise OSError: If the directory already exists.
"""
# check to avoid creation of empty dirs
path = os.path.normpath(path)
if path.startswith('/'):
to_create = path.strip().split('/')[1:]
this_dir = '/'
else:
to_create = path.strip().split('/')
this_dir = ''
for count, element in enumerate(to_create):
if count > 0:
this_dir += '/'
this_dir += element
if count + 1 == len(to_create) and self.isdir(this_dir) and ignore_existing:
return
if count + 1 == len(to_create) and self.isdir(this_dir) and not ignore_existing:
self.mkdir(this_dir)
if not self.isdir(this_dir):
self.mkdir(this_dir)
[docs] def mkdir(self, path, ignore_existing=False):
"""
Create a folder (directory) named path.
:param path: name of the folder to create
:param ignore_existing: if True, does not give any error if the directory
already exists
:raise OSError: If the directory already exists.
"""
if ignore_existing and self.isdir(path):
return
try:
self.sftp.mkdir(path)
except IOError as exc:
if os.path.isabs(path):
raise OSError(
"Error during mkdir of '{}', "
"maybe you don't have the permissions to do it, "
'or the directory already exists? ({})'.format(path, exc)
)
else:
raise OSError(
"Error during mkdir of '{}' from folder '{}', "
"maybe you don't have the permissions to do it, "
'or the directory already exists? ({})'.format(path, self.getcwd(), exc)
)
def rmtree(self, path):
    """
    Remove a file or a directory at path, recursively.

    The ``rm`` command is executed on the remote machine with the flags:
    -r: recursive removal; -f: force, makes the command non interactive;

    :param path: remote path to delete
    :return: True if the command succeeded.

    :raise ValueError: if ``path`` is empty.
    :raise IOError: if the rm execution failed.
    """
    # An invalid (empty) input is refused before anything is executed
    if not path:
        raise ValueError('Input to rmtree() must be a non empty string. ' + f'Found instead {path} as path')

    # Assuming linux rm command!
    rm_exe = 'rm'
    rm_flags = '-r -f'
    command = f'{rm_exe} {rm_flags} {escape_for_bash(path)}'

    retval, stdout, stderr = self.exec_command_wait_bytes(command)

    if retval != 0:
        self.logger.error(f"Problem executing rm. Exit code: {retval}, stdout: '{stdout}', stderr: '{stderr}'")
        raise IOError(f'Error while executing rm. Exit code: {retval}')

    if stderr.strip():
        self.logger.warning(f'There was nonempty stderr in the rm command: {stderr}')
    return True
[docs] def rmdir(self, path):
"""
Remove the folder named 'path' if empty.
"""
self.sftp.rmdir(path)
def chown(self, path, uid, gid):
    """
    Change the owner of a file.

    For now, this is not implemented for the SSH transport.

    :raise NotImplementedError: always.
    """
    raise NotImplementedError
def isdir(self, path):
    """
    Return True if the given path is a directory, False otherwise.
    Return False also if the path does not exist.

    :param path: the remote path to check.
    :raise IOError: if the path cannot be inspected for a reason different
        from it not existing.
    """
    # Return False on empty string (paramiko would map this to the local
    # folder instead)
    if not path:
        return False

    try:
        attributes = self.stat(path)
    except IOError as exc:
        # errno=2 means path does not exist: I return False
        if getattr(exc, 'errno', None) == 2:
            return False
        raise  # Typically if I don't have permissions (errno=13)

    return S_ISDIR(attributes.st_mode)
[docs] def chmod(self, path, mode):
"""
Change permissions to path
:param path: path to file
:param mode: new permission bits (integer)
"""
if not path:
raise IOError('Input path is an empty argument.')
return self.sftp.chmod(path, mode)
[docs] @staticmethod
def _os_path_split_asunder(path):
"""
Used by makedirs. Takes path (a str)
and returns a list deconcatenating the path
"""
parts = []
while True:
newpath, tail = os.path.split(path)
if newpath == path:
assert not tail
if path:
parts.append(path)
break
parts.append(tail)
path = newpath
parts.reverse()
return parts
def put(self, localpath, remotepath, callback=None, dereference=True, overwrite=True, ignore_nonexisting=False):  # pylint: disable=too-many-arguments,too-many-branches,arguments-differ
    """
    Put a file or a folder from local to remote.
    Redirects to putfile or puttree.

    If ``localpath`` contains pathname patterns, every local match is copied.

    :param localpath: an (absolute) local path
    :param remotepath: a remote path
    :param callback: passed on to putfile and puttree.
    :param dereference: follow symbolic links (boolean).
        Default = True (default behaviour in paramiko). False is not implemented.
    :param overwrite: if True overwrites files and folders (boolean).
        Default = True.
    :param ignore_nonexisting: if True, do not raise when a ``localpath``
        without patterns does not exist (boolean). Default = False.

    :raise ValueError: if local path is invalid
    :raise OSError: if the localpath does not exist
    """
    if not dereference:
        raise NotImplementedError

    if not os.path.isabs(localpath):
        raise ValueError('The localpath must be an absolute path')

    if not self.has_magic(localpath):
        # No patterns: a single folder or a single file
        if os.path.isdir(localpath):
            self.puttree(localpath, remotepath, callback, dereference, overwrite)
        elif os.path.isfile(localpath):
            destination = remotepath
            if self.isdir(remotepath):
                # copy the file inside of the existing remote folder
                destination = os.path.join(remotepath, os.path.split(localpath)[1])
            self.putfile(localpath, destination, callback, dereference, overwrite)
        elif not ignore_nonexisting:
            raise OSError(f'The local path {localpath} does not exist')
        return

    if self.has_magic(remotepath):
        raise ValueError('Pathname patterns are not allowed in the destination')

    # use the imported glob to analyze the path locally
    matches = glob.glob(localpath)

    into_directory = False
    if len(matches) > 1:
        # I can't scp more than one file on a single file
        if self.isfile(remotepath):
            raise OSError('Remote destination is not a directory')
        # I can't scp more than one file in a non existing directory
        if not self.path_exists(remotepath):
            raise OSError('Remote directory does not exist')
        # the remote path is a directory
        into_directory = True

    for source in matches:
        if not os.path.isfile(source):
            self.puttree(source, remotepath, callback, dereference, overwrite)
            continue

        if into_directory or self.isdir(remotepath):
            # the file goes inside of the remote directory, keeping its name
            destination = os.path.join(remotepath, os.path.split(source)[1])
        else:
            # one file to copy on one file
            destination = remotepath
        self.putfile(source, destination, callback, dereference, overwrite)
def putfile(self, localpath, remotepath, callback=None, dereference=True, overwrite=True):  # pylint: disable=arguments-differ
    """
    Put a file from local to remote.

    :param localpath: an (absolute) local path
    :param remotepath: a remote path
    :param callback: passed on to the ``put`` method of the SFTP channel.
    :param dereference: follow symbolic links (boolean).
        Default = True. False is not implemented.
    :param overwrite: if True overwrites files and folders (boolean).
        Default = True.
    :return: the return value of the ``put`` method of the SFTP channel.

    :raise ValueError: if local path is invalid
    :raise OSError: if the localpath does not exist,
        or unintentionally overwriting
    """
    if not dereference:
        raise NotImplementedError

    if not os.path.isabs(localpath):
        raise ValueError('The localpath must be an absolute path')

    would_overwrite = self.isfile(remotepath) and not overwrite
    if would_overwrite:
        raise OSError('Destination already exists: not overwriting it')

    channel = self.sftp
    return channel.put(localpath, remotepath, callback=callback)
def puttree(self, localpath, remotepath, callback=None, dereference=True, overwrite=True):  # pylint: disable=too-many-branches,arguments-differ,unused-argument
    """
    Put a folder recursively from local to remote.

    If ``remotepath`` does not exist yet, it is created and the content of
    ``localpath`` is copied directly into it. If ``remotepath`` is an existing
    directory, the local folder is copied *inside* it, under its own name; an
    already existing nested folder is reused instead of making the copy fail.

    :param localpath: an (absolute) local path
    :param remotepath: a remote path
    :param callback: accepted for interface compatibility; not used by this implementation
    :param dereference: follow symbolic links (boolean).
        Default = True. False is not implemented.
    :param overwrite: if True overwrites files and folders (boolean).
        Default = True

    :raise NotImplementedError: if dereference is False
    :raise ValueError: if localpath is not absolute or is not a folder
    :raise OSError: if the localpath does not exist, or trying to overwrite
    :raise IOError: if remotepath is invalid

    .. note:: setting dereference equal to True could cause infinite loops.
        see os.walk() documentation
    """
    import errno

    if not dereference:
        raise NotImplementedError

    if not os.path.isabs(localpath):
        raise ValueError('The localpath must be an absolute path')

    if not os.path.exists(localpath):
        raise OSError('The localpath does not exists')

    if not os.path.isdir(localpath):
        raise ValueError(f'Input localpath is not a folder: {localpath}')

    if not remotepath:
        raise IOError('remotepath must be a non empty string')

    if self.path_exists(remotepath) and not overwrite:
        raise OSError("Can't overwrite existing files")
    if self.isfile(remotepath):
        raise OSError('Cannot copy a directory into a file')

    if not self.isdir(remotepath):  # in this case copy things in the remotepath directly
        self.mkdir(remotepath)  # and make a directory at its place
    else:  # remotepath exists already: copy the folder inside of it!
        remotepath = os.path.join(remotepath, os.path.split(localpath)[1])
        # Reaching this branch implies overwrite=True (an existing remotepath with
        # overwrite=False was rejected above), so an existing nested folder is reused.
        if not self.isdir(remotepath):
            self.mkdir(remotepath)  # create a nested folder

    for dirpath, _, filenames in os.walk(localpath):
        # Path of the folder being visited, relative to the root of the copied tree
        this_basename = os.path.relpath(path=dirpath, start=localpath)

        remote_dir = os.path.join(remotepath, this_basename)
        try:
            self.stat(remote_dir)
        except IOError as exc:
            if exc.errno == errno.ENOENT:  # Missing remote folder: create it
                self.mkdir(remote_dir)
            else:
                raise

        for this_file in filenames:
            this_local_file = os.path.join(localpath, this_basename, this_file)
            this_remote_file = os.path.join(remotepath, this_basename, this_file)
            self.putfile(this_local_file, this_remote_file)
def get(self, remotepath, localpath, callback=None, dereference=True, overwrite=True, ignore_nonexisting=False):  # pylint: disable=too-many-branches,arguments-differ,too-many-arguments
    """
    Get a file or folder from remote to local.

    Dispatches to ``getfile`` or ``gettree`` depending on what ``remotepath``
    points to. ``remotepath`` may contain a pathname pattern, in which case
    every match returned by ``self.glob`` is retrieved.

    :param remotepath: a remote path (may contain a pathname pattern)
    :param localpath: an (absolute) local path (no patterns allowed)
    :param callback: forwarded as-is to ``getfile``/``gettree``
    :param dereference: follow symbolic links.
        Default = True. False is not implemented.
    :param overwrite: if True overwrites files and folders.
        Default = True
    :param ignore_nonexisting: if True, a pattern-free remotepath that is neither
        a file nor a folder is silently ignored instead of raising.
        Default = False

    :raise NotImplementedError: if dereference is False
    :raise ValueError: if local path is invalid
    :raise IOError: if the remotepath is not found
    :raise OSError: if several matches have to be copied into a missing local folder
    """
    if not dereference:
        raise NotImplementedError

    if not os.path.isabs(localpath):
        raise ValueError('The localpath must be an absolute path')

    # Plain remote path (no pattern): exactly one thing to fetch
    if not self.has_magic(remotepath):
        if self.isdir(remotepath):
            self.gettree(remotepath, localpath, callback, dereference, overwrite)
            return

        if self.isfile(remotepath):
            target = localpath
            if os.path.isdir(localpath):
                # destination is a folder: keep the remote file name inside it
                target = os.path.join(localpath, os.path.split(remotepath)[1])
            self.getfile(remotepath, target, callback, dereference, overwrite)
            return

        if not ignore_nonexisting:
            raise IOError(f'The remote path {remotepath} does not exist')
        return

    # Pattern in the remote path: expand it remotely
    if self.has_magic(localpath):
        raise ValueError('Pathname patterns are not allowed in the destination')

    matches = self.glob(remotepath)

    into_directory = False
    if len(matches) > 1:
        # several sources can only land in an existing local folder
        if os.path.isfile(localpath):
            raise IOError('Remote destination is not a directory')
        if not os.path.exists(localpath):
            raise OSError('Remote directory does not exist')
        into_directory = True

    for match in matches:
        if not self.isfile(match):
            self.gettree(match, localpath, callback, dereference, overwrite)
            continue

        if into_directory:
            target = os.path.join(localpath, os.path.split(match)[1])
        else:
            target = localpath
        self.getfile(match, target, callback, dereference, overwrite)
def getfile(self, remotepath, localpath, callback=None, dereference=True, overwrite=True):  # pylint: disable=arguments-differ
    """
    Get a single file from remote to local through SFTP.

    :param remotepath: a remote path
    :param localpath: an (absolute) local path
    :param callback: forwarded to the SFTP ``get`` call
    :param dereference: follow symbolic links.
        Default = True. False is not implemented.
    :param overwrite: if True overwrites an existing local file.
        Default = True

    :return: whatever the SFTP ``get`` call returns
    :raise ValueError: if local path is invalid
    :raise OSError: if unintentionally overwriting
    :raise NotImplementedError: if dereference is False
    """
    if not os.path.isabs(localpath):
        raise ValueError('localpath must be an absolute path')

    already_there = os.path.isfile(localpath)
    if already_there and not overwrite:
        raise OSError('Destination already exists: not overwriting it')

    if not dereference:
        raise NotImplementedError

    try:
        outcome = self.sftp.get(remotepath, localpath, callback)
    except IOError:
        # Workaround for bug #724 in paramiko -- remove localpath on IOError
        # (best effort: a failure of the clean-up itself is ignored)
        try:
            os.remove(localpath)
        except OSError:
            pass
        raise

    return outcome
def gettree(self, remotepath, localpath, callback=None, dereference=True, overwrite=True):  # pylint: disable=arguments-differ,unused-argument
    """
    Get a folder recursively from remote to local.

    If ``localpath`` does not exist yet, it is created and the content of
    ``remotepath`` is copied directly into it. If ``localpath`` is an existing
    directory, the remote folder is copied *inside* it, under its own name;
    already existing local folders are reused rather than making the copy fail.

    :param remotepath: a remote path
    :param localpath: an (absolute) local path
    :param callback: accepted for interface compatibility; not used by this implementation
    :param dereference: follow symbolic links.
        Default = True. False is not implemented.
    :param overwrite: if True overwrites files and folders.
        Default = True

    :raise NotImplementedError: if dereference is False
    :raise ValueError: if local path is invalid
    :raise IOError: if the remotepath is empty or is not a folder
    :raise OSError: if unintentionally overwriting
    """
    if not dereference:
        raise NotImplementedError

    if not remotepath:
        raise IOError('Remotepath must be a non empty string')
    if not localpath:
        raise ValueError('Localpaths must be a non empty string')

    if not os.path.isabs(localpath):
        raise ValueError('Localpaths must be an absolute path')

    if not self.isdir(remotepath):
        # report the offending *remote* path (the message used to show localpath)
        raise IOError(f'Input remotepath is not a folder: {remotepath}')

    if os.path.exists(localpath) and not overwrite:
        raise OSError("Can't overwrite existing files")
    if os.path.isfile(localpath):
        raise OSError('Cannot copy a directory into a file')

    if os.path.isdir(localpath):
        # localpath exists already: copy the folder inside of it!
        # Reaching this point with an existing localpath implies overwrite=True,
        # so an already existing nested folder is reused (see makedirs below).
        localpath = os.path.join(localpath, os.path.split(remotepath)[1])

    # Explicit work list instead of recursing through gettree(): every remote folder
    # maps onto exactly one local folder, even when that local folder already exists
    # (a recursive call would nest it one level deeper in that case).
    pending = [(str(remotepath), str(localpath))]
    while pending:
        remote_dir, local_dir = pending.pop()
        os.makedirs(local_dir, exist_ok=True)

        for item in self.listdir(remote_dir):
            item = str(item)
            remote_item = os.path.join(remote_dir, item)
            local_item = os.path.join(local_dir, item)
            if self.isdir(remote_item):
                pending.append((remote_item, local_item))
            else:
                self.getfile(remote_item, local_item)
def get_attribute(self, path):
    """
    Return a ``FileAttribute`` object (from ``aiida.transports.util``) for ``path``.

    The attributes are read with ``self.lstat`` and only the fields listed in
    ``FileAttribute._valid_fields`` are copied over.

    :param path: path of the file to inspect
    :return: the populated ``FileAttribute`` instance
    """
    from aiida.transports.util import FileAttribute

    remote_stat = self.lstat(path)
    attributes = FileAttribute()
    # the object returned by lstat carries more information than the aiida one:
    # copy only the fields the aiida class declares
    for field in attributes._valid_fields:  # pylint: disable=protected-access
        attributes[field] = getattr(remote_stat, field)

    return attributes
def copyfile(self, remotesource, remotedestination, dereference=False):
    """
    Copy from remote source to remote destination by delegating to ``copy``.

    ``copy`` is called without a ``recursive`` argument, i.e. with its default.

    :param remotesource: path to copy from
    :param remotedestination: path to copy to
    :param dereference: forwarded to ``copy``. Default = False.
    :return: whatever ``copy`` returns
    """
    outcome = self.copy(remotesource, remotedestination, dereference)
    return outcome
def copytree(self, remotesource, remotedestination, dereference=False):
    """
    Copy a folder from remote source to remote destination.

    Delegates to ``copy`` with ``recursive=True``.

    :param remotesource: path to copy from
    :param remotedestination: path to copy to
    :param dereference: forwarded to ``copy``. Default = False.
    :return: whatever ``copy`` returns
    """
    outcome = self.copy(remotesource, remotedestination, dereference, recursive=True)
    return outcome
def copy(self, remotesource, remotedestination, dereference=False, recursive=True):
    """
    Copy a file or a directory from remote source to remote destination.

    The copy is performed by running ``cp`` on the remote machine.
    Flags used: ``-f``: force, makes the command non interactive (always set);
    ``-r``: recursive copy; ``-L``: follows symbolic links.

    :param remotesource: file to copy from (may contain a pathname pattern)
    :param remotedestination: file to copy to (no patterns allowed)
    :param dereference: if True, copy content instead of copying the symlinks only
        Default = False.
    :param recursive: if True copy directories recursively, otherwise only copy the specified file(s)
    :type recursive: bool

    :raise ValueError: if source or destination is empty, or the destination contains a pattern
    :raise OSError: if several matches would have to be copied onto a single file
    :raise IOError: if the cp execution failed.

    .. note:: setting dereference equal to True could cause infinite loops.
    """
    # For the moment, this is hardcoded. May become a parameter
    cp_exe = 'cp'

    # In the majority of cases, we should deal with linux cp commands
    # To evaluate if we also want -p: preserves mode,ownership and timestamp
    flag_parts = ['-f']
    if recursive:
        flag_parts.append('-r')
    if dereference:
        # use -L; --dereference is not supported on mac
        flag_parts.append('-L')
    cp_flags = ' '.join(flag_parts)

    # an invalid (empty) input raises ValueError
    if not remotesource:
        raise ValueError(
            'Input to copy() must be a non empty string. ' + f'Found instead {remotesource} as remotesource'
        )

    if not remotedestination:
        raise ValueError(
            'Input to copy() must be a non empty string. ' +
            f'Found instead {remotedestination} as remotedestination'
        )

    if self.has_magic(remotedestination):
        raise ValueError('Pathname patterns are not allowed in the destination')

    if not self.has_magic(remotesource):
        self._exec_cp(cp_exe, cp_flags, remotesource, remotedestination)
        return

    sources = self.glob(remotesource)
    several = len(sources) > 1
    if several and (not self.path_exists(remotedestination) or self.isfile(remotedestination)):
        raise OSError("Can't copy more than one file in the same destination file")

    for source in sources:
        self._exec_cp(cp_exe, cp_flags, source, remotedestination)
def _exec_cp(self, cp_exe, cp_flags, src, dst):
    """
    Execute the ``cp`` command on the remote machine (helper of ``copy``).

    :param cp_exe: the executable to run
    :param cp_flags: the flags to pass, as a single string
    :param src: source path (escaped for bash before being used)
    :param dst: destination path (escaped for bash before being used)
    :raise IOError: if the command exits with a non-zero code
    """
    command = f'{cp_exe} {cp_flags} {escape_for_bash(src)} {escape_for_bash(dst)}'

    retval, stdout, stderr = self.exec_command_wait_bytes(command)

    if retval != 0:
        self.logger.error(
            "Problem executing cp. Exit code: {}, stdout: '{}', "
            "stderr: '{}', command: '{}'".format(retval, stdout, stderr, command)
        )
        raise IOError(
            'Error while executing cp. Exit code: {}, '
            "stdout: '{}', stderr: '{}', "
            "command: '{}'".format(retval, stdout, stderr, command)
        )

    # success, but the command still wrote something on stderr: only warn
    if stderr.strip():
        self.logger.warning(f'There was nonempty stderr in the cp command: {stderr}')
@staticmethod
def _local_listdir(path, pattern=None):
    """
    Acts on the local folder, for the rest, same as listdir.

    :param path: the local folder to list; a relative path is resolved against
        the current working directory when a pattern is given
    :param pattern: if given, only the entries matching this glob pattern are
        returned, expressed relative to ``path``
    :return: a list of names
    """
    if not pattern:
        return os.listdir(path)

    if path.startswith('/'):  # always this is the case in the local case
        base_dir = path
    else:
        base_dir = os.path.join(os.getcwd(), path)

    # The folder itself is matched literally: only ``pattern`` may contain wildcards
    filtered_list = glob.glob(os.path.join(glob.escape(base_dir), pattern))

    if not base_dir.endswith(os.sep):
        base_dir += os.sep

    # Strip the folder as a literal prefix. A regular-expression substitution would
    # misbehave for folders whose name contains regex metacharacters ('+', '(', ...).
    prefix_length = len(base_dir)
    return [entry[prefix_length:] if entry.startswith(base_dir) else entry for entry in filtered_list]
def listdir(self, path='.', pattern=None):
    """
    Get the list of files at path.

    :param path: default = '.'. A relative path is resolved against
        ``self.getcwd()`` when a pattern is given.
    :param pattern: returns the list of files matching pattern, expressed
        relative to ``path``.
        Unix only. (Use to emulate ``ls *`` for example)
    :return: a list of names
    """
    if not pattern:
        return self.sftp.listdir(path)

    if path.startswith('/'):
        base_dir = path
    else:
        base_dir = os.path.join(self.getcwd(), path)

    filtered_list = self.glob(os.path.join(base_dir, pattern))

    if not base_dir.endswith('/'):
        base_dir += '/'

    # Strip the folder as a literal prefix. A regular-expression substitution would
    # misbehave for folders whose name contains regex metacharacters ('+', '(', ...).
    prefix_length = len(base_dir)
    return [entry[prefix_length:] if entry.startswith(base_dir) else entry for entry in filtered_list]
def remove(self, path):
    """
    Remove the single file located at ``path`` through the SFTP client.

    :param path: path of the file to remove
    :return: whatever the SFTP ``remove`` call returns
    """
    sftp_client = self.sftp
    return sftp_client.remove(path)
def rename(self, oldpath, newpath):
    """
    Rename a file or folder from oldpath to newpath.

    The destination must not exist yet: the check used to be inverted (it
    required ``newpath`` to already exist), which made renaming to a new
    name impossible.

    :param str oldpath: existing name of the file or folder
    :param str newpath: new name for the file or folder

    :return: whatever the SFTP ``rename`` call returns
    :raises IOError: if oldpath is not found or newpath already exists
    :raises ValueError: if oldpath/newpath is not a valid string
    """
    if not oldpath:
        raise ValueError(f'Source {oldpath} is not a valid string')
    if not newpath:
        raise ValueError(f'Destination {newpath} is not a valid string')

    # the source has to be an existing file or folder
    if not self.isfile(oldpath) and not self.isdir(oldpath):
        raise IOError(f'Source {oldpath} does not exist')

    # refuse to clobber an existing destination, with an explicit message
    if self.path_exists(newpath):
        raise IOError(f'Destination {newpath} already exists')

    return self.sftp.rename(oldpath, newpath)
def isfile(self, path):
    """
    Return True if the given path is a file, False otherwise.
    Return False also if the path does not exist.

    The remote ``stat`` is issued only once per call: the same result is used
    for the debug message and for the actual check.

    :param path: the path to check
    :raise IOError: for any failure other than a missing path
        (typically missing permissions, errno=13)
    """
    import errno

    # This should not be needed for files, since an empty string should
    # be mapped by paramiko to the local directory - which is not a file -
    # but this is just to be sure
    if not path:
        return False

    try:
        attributes = self.stat(path)
        self.logger.debug(
            f"stat for path '{path}' ('{self.normalize(path)}'): {attributes} [{attributes.st_mode}]"
        )
        return S_ISREG(attributes.st_mode)
    except IOError as exc:
        if getattr(exc, 'errno', None) == errno.ENOENT:
            # the path does not exist: I return False
            return False
        raise  # Typically if I don't have permissions (errno=13)
def _exec_command_internal(self, command, combine_stderr=False, bufsize=-1):  # pylint: disable=arguments-differ
    """
    Start the specified command on the remote machine, without waiting for it.

    The command is run through the bash invocation stored in
    ``self._bash_command_str`` followed by ``-c``. If ``self.getcwd()`` is not
    None, the command is prefixed by a ``cd`` into that folder.

    For executing commands and waiting for them to finish, use
    exec_command_wait.

    :param command: the command to execute. The command is assumed to be
        already escaped using :py:func:`aiida.common.escaping.escape_for_bash`.
    :param combine_stderr: (default False) if True, combine stdout and
        stderr on the same buffer (i.e., stdout).
        Note: If combine_stderr is True, stderr will always be empty.
    :param bufsize: same meaning of the one used by paramiko.

    :return: a tuple with (stdin, stdout, stderr, channel),
        where stdin, stdout and stderr behave as file-like objects,
        plus the methods provided by paramiko, and channel is a
        paramiko.Channel object.
    """
    channel = self.sshclient.get_transport().open_session()
    channel.set_combine_stderr(combine_stderr)

    working_dir = self.getcwd()
    if working_dir is None:
        command_to_execute = command
    else:
        escaped_folder = escape_for_bash(working_dir)
        command_to_execute = (f'cd {escaped_folder} && ( {command} )')

    # only the beginning of (possibly very long) commands is logged
    self.logger.debug(f'Command to be executed: {command_to_execute[:self._MAX_EXEC_COMMAND_LOG_SIZE]}')

    # Note: The default shell will eat one level of escaping, while
    # 'bash -l -c ...' will eat another. Thus, we need to escape again.
    bash_prefix = self._bash_command_str + '-c '
    channel.exec_command(bash_prefix + escape_for_bash(command_to_execute))

    remote_stdin = channel.makefile('wb', bufsize)
    remote_stdout = channel.makefile('rb', bufsize)
    remote_stderr = channel.makefile_stderr('rb', bufsize)

    return remote_stdin, remote_stdout, remote_stderr, channel
def exec_command_wait_bytes(self, command, stdin=None, combine_stderr=False, bufsize=-1):  # pylint: disable=arguments-differ, too-many-branches
    """
    Executes the specified command and waits for it to finish.

    The command is started with ``self._exec_command_internal``; the optional
    stdin content is written to it, then stdout and stderr are read in an
    interleaved way until the exit status of the command is available.

    :param command: the command to execute
    :param stdin: (optional,default=None) can be a string, bytes or a
        file-like object.
    :param combine_stderr: (optional, default=False) see docstring of
        self._exec_command_internal()
    :param bufsize: same meaning of paramiko.

    :return: a tuple with (return_value, stdout, stderr) where stdout and stderr
        are both bytes and the return_value is an int.
    :raise ValueError: if stdin is not None and is not a string, bytes or a
        buffered/text IO object
    """
    import socket
    import time

    ssh_stdin, stdout, stderr, channel = self._exec_command_internal(command, combine_stderr, bufsize=bufsize)

    if stdin is not None:
        # Wrap plain strings/bytes so that every accepted input can be iterated line by line
        if isinstance(stdin, str):
            filelike_stdin = io.StringIO(stdin)
        elif isinstance(stdin, bytes):
            filelike_stdin = io.BytesIO(stdin)
        elif isinstance(stdin, (io.BufferedIOBase, io.TextIOBase)):
            # It seems both StringIO and BytesIO work correctly when doing ssh_stdin.write(line)?
            # (The ChannelFile is opened with mode 'b', but until now it always has been a StringIO)
            filelike_stdin = stdin
        else:
            raise ValueError('You can only pass strings, bytes, BytesIO or StringIO objects')

        for line in filelike_stdin:
            ssh_stdin.write(line)

    # I flush and close them anyway (also when no stdin was given); important to call
    # shutdown_write to avoid hangs
    ssh_stdin.flush()
    ssh_stdin.channel.shutdown_write()

    # Now I get the output
    stdout_bytes = []
    stderr_bytes = []
    # 100kB buffer (note that this should be smaller than the window size of paramiko)
    # Also, apparently if the data is coming slowly, the read() command will not unlock even for
    # times much larger than the timeout. Therefore we don't want to have very large buffers otherwise
    # you risk that a lot of output is sent to both stdout and stderr, and stderr goes beyond the
    # window size and blocks.
    # Note that this is different than the bufsize of paramiko.
    internal_bufsize = 100 * 1024

    # Set a small timeout on the channels, so that if we get data from both
    # stderr and stdout, and the connection is slow, we interleave the receive and don't hang
    # NOTE: Timeouts and sleep time below, as well as the internal_bufsize above, have been benchmarked
    # to try to optimize the overall throughput. I could get ~100MB/s on a localhost via ssh (and 3x slower
    # if compression is enabled).
    # It's important to mention that, for speed benchmarks, it's important to disable compression
    # in the SSH transport settings, as it will cap the max speed.
    stdout.channel.settimeout(0.01)
    stderr.channel.settimeout(0.01)  # Maybe redundant, as this could be the same channel.

    while True:
        chunk_exists = False

        if stdout.channel.recv_ready():  # True means that the next .read call will at least receive 1 byte
            chunk_exists = True
            try:
                piece = stdout.read(internal_bufsize)
                stdout_bytes.append(piece)
            except socket.timeout:
                # There was a timeout: I continue as there should still be data
                pass

        if stderr.channel.recv_stderr_ready():  # True means that the next .read call will at least receive 1 byte
            chunk_exists = True
            try:
                piece = stderr.read(internal_bufsize)
                stderr_bytes.append(piece)
            except socket.timeout:
                # There was a timeout: I continue as there should still be data
                pass

        # If chunk_exists, there is data (either already read and put in the std*_bytes lists, or
        # still in the buffer because of a timeout). I need to loop.
        # Otherwise, there is no data in the buffers, and I enter this block.
        if not chunk_exists:
            # Both channels have no data in the buffer
            if channel.exit_status_ready():
                # The remote execution is over

                # I think that in some corner cases there might still be some data,
                # in case the data arrived between the previous calls and this check.
                # So we do a final read. Since the execution is over, I think all data is in the buffers,
                # so we can just read the whole buffer without loops
                stdout_bytes.append(stdout.read())
                stderr_bytes.append(stderr.read())
                # And we go out of the `while True` loop
                break
            # The exit status is not ready:
            # I just put a small sleep to avoid infinite fast loops when data
            # is not available on a slow connection, and loop
            time.sleep(0.01)

    # I get the return code (blocking)
    # However, if I am here, the exit status is ready so this should be returning very quickly
    retval = channel.recv_exit_status()

    # The chunks collected above are joined into a single bytes object per stream
    return (retval, b''.join(stdout_bytes), b''.join(stderr_bytes))
def gotocomputer_command(self, remotedir):
    """
    Build the ``ssh`` command line that connects to the remote computer and
    goes directly to the given folder.

    The options are taken from ``self._connect_args`` (username, port,
    key_filename, proxy_jump, proxy_command); the trailing part of the command
    is produced by ``self._gotocomputer_string``.

    :param remotedir: the remote folder, passed to ``self._gotocomputer_string``
    :return: the command, as a string
    """
    connect_args = self._connect_args
    ssh_options = []

    # the username is added as soon as the key is present
    if 'username' in connect_args:
        ssh_options.append(f"-l {escape_for_bash(connect_args['username'])}")

    # the port is the only option added without bash escaping
    port = connect_args.get('port')
    if port:
        ssh_options.append(f'-p {port}')

    # remaining options: added only when set to a non-empty value, always escaped
    escaped_options = (
        ('key_filename', '-i '),
        ('proxy_jump', '-o ProxyJump='),
        ('proxy_command', '-o ProxyCommand='),
    )
    for key, flag in escaped_options:
        value = connect_args.get(key)
        if value:
            ssh_options.append(f'{flag}{escape_for_bash(value)}')

    joined_options = ' '.join(ssh_options)
    connect_string = self._gotocomputer_string(remotedir)

    return f'ssh -t {self._machine} {joined_options} {connect_string}'
def _symlink(self, source, dest):
    """
    Wrap SFTP symlink call without breaking API

    :param source: source of link
    :param dest: link to create
    """
    sftp_client = self.sftp
    sftp_client.symlink(source, dest)
def symlink(self, remotesource, remotedestination):
    """
    Create a symbolic link between the remote source and the remote
    destination.

    Both paths are normalised with ``os.path.normpath`` before being used. If
    the source contains a pattern, one link per match is created inside the
    destination, named after the last component of the match.

    :param remotesource: remote source. Can contain a pattern.
    :param remotedestination: remote destination
    :raise ValueError: if the source contains a pattern and the destination too
    """
    # paramiko gives some errors if path is starting with '.'
    source = os.path.normpath(remotesource)
    dest = os.path.normpath(remotedestination)

    if not self.has_magic(source):
        self._symlink(source, dest)
        return

    if self.has_magic(dest):
        # if there are patterns in dest, I don't know which name to assign
        raise ValueError('Remotedestination cannot have patterns')

    # find all files matching pattern
    for this_source in self.glob(source):
        # create the name of the link: take the last part of the path.
        # The normalised destination is used here as well, consistently with the
        # pattern-free case above.
        this_dest = os.path.join(dest, os.path.split(this_source)[-1])
        self._symlink(this_source, this_dest)
def path_exists(self, path):
    """
    Check if path exists, by calling ``self.stat`` on it.

    :param path: the path to check
    :return: True if ``stat`` succeeds, False if it fails with ENOENT
    :raise IOError: for any ``stat`` failure other than ENOENT
    """
    import errno

    try:
        self.stat(path)
    except IOError as exc:
        if exc.errno != errno.ENOENT:
            raise
        return False

    return True