# SPDX-FileCopyrightText: 2024 Geoffrey Lentner
# SPDX-License-Identifier: Apache-2.0
"""SSH-based cluster implementation."""
# type annotations
from __future__ import annotations
from typing import Type, List, Iterable, Tuple, IO, Final
# standard libs
import re
import sys
import time
import shlex
import secrets
from subprocess import Popen
# external libs
from cmdkit.config import ConfigurationError, Namespace
# internal libs
from hypershell.core.config import config, blame
from hypershell.core.thread import Thread
from hypershell.core.logging import Logger, HOSTNAME
from hypershell.core.template import DEFAULT_TEMPLATE
from hypershell.client import DEFAULT_DELAY, DEFAULT_SIGNALWAIT
from hypershell.submit import DEFAULT_BUNDLEWAIT
from hypershell.server import ServerThread, DEFAULT_PORT, DEFAULT_BUNDLESIZE, DEFAULT_ATTEMPTS
# public interface
__all__ = ['run_ssh', 'SSHCluster', 'NodeList', 'DEFAULT_REMOTE_EXE']
# initialize logger
log = Logger.with_name('hypershell.cluster')
# NOTE: retain old name for remote executable (for now)
DEFAULT_REMOTE_EXE: Final[str] = 'hyper-shell'
"""Default remote executable name."""
[docs]
def run_ssh(**options) -> None:
"""
Run cluster with remote clients via SSH until completion.
All function arguments are forwarded directly into the
:class:`~hypershell.cluster.ssh.SSHCluster` thread.
Example:
>>> from hypershell.cluster import run_ssh
>>> run_ssh(
... nodelist=['a00.cluster', 'a01.cluster', 'a02.cluster'],
... restart_mode=True, max_retries=2, eager=True,
... client_timeout=600, task_timeout=300, capture=True
... )
See Also:
- :class:`~hypershell.cluster.ssh.SSHCluster`
"""
thread = SSHCluster.new(**options)
try:
thread.join()
except Exception:
thread.stop()
raise
[docs]
class SSHCluster(Thread):
"""
Run server with individual external ssh clients.
Args:
source (Iterable[str], optional):
Any iterable of command-line tasks.
A new `source` results in a :class:`~hypershell.submit.SubmitThread` populating
either the database or the queue directly depending on `in_memory`.
nodelist (List[str], required):
List of hostnames for launching clients.
See also: :class:`~hypershell.cluster.ssh.NodeList`.
num_tasks (int, optional):
Number of parallel task executor threads.
See :const:`~hypershell.client.DEFAULT_NUM_TASKS`.
template (str, optional):
Template command pattern.
See :const:`~hypershell.client.DEFAULT_TEMPLATE`.
bundlesize (int optional):
Size of task bundles returned to server.
See :const:`~hypershell.server.DEFAULT_BUNDLESIZE`.
bundlewait (int optional):
Waiting period in seconds before forcing return of task bundle to server.
See :const:`~hypershell.submit.DEFAULT_BUNDLEWAIT`.
bind (tuple, optional):
Bind address for server with port number (default: 0.0.0.0).
See :const:`~hypershell.server.DEFAULT_PORT`.
launcher (str, optional):
Launcher program used to bring up clients on other hosts.
Defaults to 'ssh'. Use `launcher_args` to provide command options.
launcher_args (List[str], optional):
Additional command-line arguments for launcher program.
remote_exe (str, optional):
Program name or path for remote executable.
See :const:`~hypershell.cluster.ssh.DEFAULT_REMOTE_EXE`.
export_env (bool, optional):
If enabled, embed local configuration as environment variables
and forward with client launch command to remote hosts.
in_memory (bool, optional):
If True, revert to basic in-memory queue.
no_confirm (bool, optional):
Disable client confirmation of tasks received.
forever_mode (bool, optional):
Regardless of `source`, if enabled schedule forever.
Conflicts with `restart_mode` and `in_memory`. Default is `False`.
restart_mode (bool, optional):
If `source` is empty, this option allows for the server to continue
with scheduling from the database until complete.
Conflicts with `in_memory`. Default is `False`.
max_retries (int, optional):
Number of allowed task retries.
See :const:`~hypershell.server.DEFAULT_ATTEMPTS`.
eager (bool, optional):
When enabled tasks are retried immediately ahead scheduling new tasks.
See :const:`~hypershell.server.DEFAULT_EAGER_MODE`.
redirect_failures (IO, optional):
Open file-like object to write failed tasks.
delay_start (float, optional):
Delay in seconds before connecting to server.
See :const:`~hypershell.client.DEFAULT_DELAY`.
capture (bool, optional):
Isolate task <stdout> and <stderr> in discrete files.
Defaults to `False`.
client_timeout (int, optional):
Timeout in seconds before disconnecting from server.
By default, the client waits for server tor request disconnect.
task_timeout (int, optional):
Task-level walltime limit in seconds.
By default, the client waits indefinitely on tasks.
task_signalwait (int, optional):
Signal escalation waiting period in seconds on task timeout.
See :const:`~hypershell.client.DEFAULT_SIGNALWAIT`.
Example:
>>> from hypershell.cluster import SSHCluster
>>> cluster = SSHCluster.new(
... nodelist=['a00.cluster', 'a01.cluster', 'a02.cluster'],
... restart_mode=True, max_retries=2, eager=True,
... client_timeout=600, task_timeout=300, capture=True
... )
>>> cluster.join()
See Also:
- :class:`~hypershell.server.ServerThread`
- :meth:`~hypershell.cluster.ssh.run_ssh`
"""
server: ServerThread
clients: List[Popen]
client_argv: List[List[str]]
def __init__(self: SSHCluster,
source: Iterable[str] = None,
nodelist: List[str] = None,
num_tasks: int = 1,
template: str = DEFAULT_TEMPLATE,
bundlesize: int = DEFAULT_BUNDLESIZE,
bundlewait: int = DEFAULT_BUNDLEWAIT,
bind: Tuple[str, int] = ('0.0.0.0', DEFAULT_PORT),
launcher: str = 'ssh',
launcher_args: List[str] = None,
remote_exe: str = DEFAULT_REMOTE_EXE,
export_env: bool = False,
in_memory: bool = False,
no_confirm: bool = False,
forever_mode: bool = False,
restart_mode: bool = False,
max_retries: int = DEFAULT_ATTEMPTS,
eager: bool = False,
redirect_failures: IO = None,
delay_start: float = DEFAULT_DELAY,
capture: bool = False,
client_timeout: int = None,
task_timeout: int = None,
task_signalwait: int = DEFAULT_SIGNALWAIT) -> None:
"""Initialize server and client threads."""
if nodelist is None:
raise AttributeError('Expected nodelist')
auth = secrets.token_hex(64)
self.server = ServerThread(source=source,
bundlesize=bundlesize,
bundlewait=bundlewait,
in_memory=in_memory,
no_confirm=no_confirm,
address=bind,
auth=auth,
max_retries=max_retries,
eager=eager,
forever_mode=forever_mode,
restart_mode=restart_mode,
redirect_failures=redirect_failures)
launcher = shlex.split(launcher)
launcher_env = shlex.split('' if not export_env else compile_env())
if launcher_args is None:
launcher_args = shlex.split(config.ssh.get('args', ''))
client_args = []
if capture:
client_args.append('--capture')
if no_confirm:
client_args.append('--no-confirm')
if client_timeout is not None:
client_args.extend(['-T', str(client_timeout)])
if task_timeout is not None:
client_args.extend(['-W', str(task_timeout)])
self.client_argv = [
[*launcher, *launcher_args, host, *launcher_env, remote_exe, 'client', '-H', HOSTNAME,
'-p', str(bind[1]), '-N', str(num_tasks), '-b', str(bundlesize), '-w', str(bundlewait),
'-t', f'\'{template}\'', '-k', auth, '-d', str(delay_start),
'-S', str(task_signalwait), *client_args]
for host in nodelist
]
super().__init__(name='hypershell-cluster')
def run_with_exceptions(self: SSHCluster) -> None:
"""Start child threads, wait."""
self.server.start()
while not self.server.queue.ready:
time.sleep(0.1)
self.clients = []
for argv in self.client_argv:
log.debug(f'Launching client: {argv}')
self.clients.append(Popen(argv, stdout=sys.stdout, stderr=sys.stderr))
for client in self.clients:
client.wait()
self.server.join()
[docs]
def stop(self: SSHCluster, wait: bool = False, timeout: int = None) -> None:
"""Stop child threads before main thread."""
self.server.stop(wait=wait, timeout=timeout)
for client in self.clients:
client.terminate()
super().stop(wait=wait, timeout=timeout)
class NodeList(list):
"""A list of hostnames."""
group_pattern: re.Pattern = re.compile(r',(?![^[]*])')
range_pattern: re.Pattern = re.compile(r'\[((\d+|\d+-\d+)(,(\d+|\d+-\d+))*)]')
@classmethod
def from_cmdline(cls: Type[NodeList], arg: str = None) -> NodeList:
"""Smart initialization via some command-line `arg`."""
if not arg:
return cls.from_config()
if ',' in arg or cls.range_pattern.search(arg):
return cls.from_pattern(arg)
else:
return cls([arg, ])
@classmethod
def from_config(cls: Type[NodeList], groupname: str = None) -> NodeList:
"""Load list of hostnames from configuration file."""
if 'ssh' not in config:
raise ConfigurationError('No `ssh` section found in configuration')
if 'nodelist' not in config.ssh:
raise ConfigurationError('No `ssh.nodelist` section found in configuration')
label = blame(config, 'ssh', 'nodelist')
nodelist = config.ssh.nodelist
if groupname is None:
if isinstance(nodelist, str):
return cls.from_pattern(nodelist)
elif isinstance(nodelist, list):
return cls(nodelist)
elif isinstance(nodelist, dict):
raise ConfigurationError(f'SSH group unspecified but multiple groups in `ssh.nodelist` ({label})')
else:
raise ConfigurationError(f'Expected list for `ssh.nodelist` ({label})')
if isinstance(nodelist, dict):
nodelist = nodelist.get(groupname, None)
if not nodelist:
raise ConfigurationError(f'No list \'{groupname}\' found in `ssh.nodelist` section ({label})')
elif isinstance(nodelist, str):
return cls.from_pattern(nodelist)
elif isinstance(nodelist, list):
return cls(nodelist)
else:
raise ConfigurationError(f'Expected list for `ssh.nodelist.{groupname}` ({label})')
else:
raise ConfigurationError(f'Expected either list or section for `ssh.nodelist` ({label})')
@classmethod
def from_pattern(cls: Type[NodeList], pattern: str) -> NodeList:
"""Expand a `pattern` to multiple hostnames."""
pattern = pattern.strip(',')
if match := cls.group_pattern.search(pattern):
idx = match.start()
return cls(cls.from_pattern(pattern[0:idx]) + cls.from_pattern(pattern[idx+1:]))
else:
if match := cls.range_pattern.search(pattern):
range_spec = match.group()
prefix, suffix = pattern.split(range_spec)
segments = cls.expand_pattern(range_spec)
return cls([f'{prefix}{segment}{suffix}' for segment in segments])
else:
return cls([pattern, ])
@staticmethod
def expand_pattern(spec: str) -> List[str]:
"""Take a range spec (e.g., [4,6-8]) and expand to [4,6,7,8]."""
spec = spec.strip('[]')
result = []
for group in spec.split(','):
if '-' in group:
start_chars, stop_chars = group.strip().split('-')
apply_padding = str if start_chars[0] != '0' else lambda s: str(s).zfill(len(start_chars))
start, stop = int(start_chars), int(stop_chars)
if stop < start:
start, stop = stop, start
result.extend([apply_padding(value) for value in range(start, stop + 1)])
else:
result.extend([group, ])
return result
def compile_env() -> str:
"""Build environment variable argument expansion for remote client launch command."""
return ' '.join([f'{key}="{value}"' for key, value in Namespace.from_env('HYPERSHELL').items()])