# SPDX-FileCopyrightText: 2024 Geoffrey Lentner
# SPDX-License-Identifier: Apache-2.0
"""
Connect to server and run tasks.
Example:
>>> from hypershell.client import run_client
>>> run_client(num_tasks=4, address=('<IP ADDRESS>', 8080), auth='<secret>')
Embed a `ClientThread` in your application directly. Call `stop()` to stop early.
Clients cannot connect to a remote machine unless you set the server's `bind` address
to 0.0.0.0 (as opposed to localhost, which is the default).
Example:
>>> from hypershell.client import ClientThread
>>> client_thread = ClientThread.new(num_tasks=4, address=('<IP ADDRESS>', 8080), auth='<secret>')
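Join the thread to block until the server requests disconnect, or call `stop()` to halt
early (a minimal sketch; lifecycle management is left to the embedding application):
>>> client_thread.join()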
Note:
In order for the `ClientThread` to actively monitor the state set by `stop` and
halt execution (a requirement because of how CPython does threading), the implementation
uses a finite state machine. *You should not instantiate this machine directly*.
Warning:
Because the `ClientThread` checks state actively to decide whether to halt, it may take
a few moments before it shuts down on its own. If your main program exits, however,
the thread will be stopped regardless because it runs as a `daemon`.
"""
# type annotations
from __future__ import annotations
from typing import List, Tuple, Optional, Callable, Dict, IO, Type, Final
from types import TracebackType
# standard libs
import os
import sys
import time
import json
import random
import functools
from enum import Enum
from datetime import datetime, timedelta
from queue import Queue, Empty as QueueEmpty, Full as QueueFull
from subprocess import Popen, TimeoutExpired
from socket import gaierror
from dataclasses import dataclass
from multiprocessing import AuthenticationError, cpu_count
# external libs
from cmdkit.app import Application, exit_status
from cmdkit.cli import Interface, ArgumentError
from cmdkit.config import Namespace
# internal libs
from hypershell.data.model import Task
from hypershell.core.heartbeat import Heartbeat, ClientState
from hypershell.core.platform import default_path
from hypershell.core.config import default, config, load_task_env, SSH_GROUPS
from hypershell.core.fsm import State, StateMachine
from hypershell.core.thread import Thread
from hypershell.core.signal import check_signal, SIGNAL_MAP, SIGUSR1, SIGUSR2, SIGINT
from hypershell.core.queue import QueueClient, QueueConfig
from hypershell.core.logging import HOSTNAME, INSTANCE, Logger
from hypershell.core.template import Template, DEFAULT_TEMPLATE
from hypershell.core.exceptions import (handle_exception, handle_disconnect,
handle_address_unknown, HostAddressInfo, get_shared_exception_mapping)
# public interface
__all__ = ['run_client', 'ClientThread', 'ClientApp', 'ClientInfo',
'DEFAULT_BUNDLESIZE', 'DEFAULT_BUNDLEWAIT', 'DEFAULT_TEMPLATE',
'DEFAULT_NUM_TASKS', 'DEFAULT_DELAY', 'DEFAULT_SIGNALWAIT',
'DEFAULT_HEARTRATE', 'DEFAULT_HOST', 'DEFAULT_PORT', 'DEFAULT_AUTH',
'set_client_standalone']
# initialize logger
log = Logger.with_name(__name__)
# NOTE:
# The UNIX signal facility works on stand-alone server/client, but when running a LocalCluster with
# a client as a local thread, the USR1/USR2 signals prevent clients from sending the proper finalization
# messages. This flag is set by LocalCluster to prevent greedy client-side shutdown behavior.
CLIENT_STANDALONE_MODE: bool = True
def set_client_standalone(mode: bool) -> None:
"""Set global flag to prevent greedy shutdown from USR1/USR2 signals."""
global CLIENT_STANDALONE_MODE
CLIENT_STANDALONE_MODE = mode
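# Example (sketch, not taken from LocalCluster itself): an embedding would disable
# standalone mode before creating its in-process client so that USR1/USR2 signals do
# not trigger a greedy client-side shutdown, e.g.,
#
#     from hypershell.client import ClientThread, set_client_standalone
#     set_client_standalone(False)
#     client = ClientThread.new(num_tasks=4)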
@dataclass
class ClientInfo:
"""Client instance ID/hostname and task ID mapping."""
client_id: str
client_host: str
task_ids: List[str]
@classmethod
def from_dict(cls: Type[ClientInfo], data: dict) -> ClientInfo:
"""Initialize from existing dictionary."""
return cls(**data)
def to_dict(self: ClientInfo) -> dict:
"""Export to dictionary."""
return {'client_id': self.client_id, 'client_host': self.client_host, 'task_ids': self.task_ids}
def pack(self: ClientInfo) -> bytes:
"""Serialize data."""
return json.dumps(self.to_dict()).encode('utf-8')
@classmethod
def unpack(cls: Type[ClientInfo], data: bytes) -> ClientInfo:
"""Deserialize from raw `data`."""
return cls.from_dict(json.loads(data.decode('utf-8')))
@classmethod
def from_tasks(cls: Type[ClientInfo], tasks: List[Task]) -> ClientInfo:
"""Initialize from list of existing Task instances."""
return cls(client_id=INSTANCE, client_host=HOSTNAME,
task_ids=[task.id for task in tasks])
def transpose(self: ClientInfo) -> List[Dict[str, str]]:
"""Represent as list of dicts for database update."""
return [{'id': task_id, 'client_id': self.client_id, 'client_host': self.client_host}
for task_id in self.task_ids]
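# Example (sketch, placeholder values): ClientInfo round-trips through pack/unpack via JSON,
#
#     info = ClientInfo(client_id='abc123', client_host='node-01', task_ids=['t1', 't2'])
#     assert ClientInfo.unpack(info.pack()) == info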
class SchedulerState(State, Enum):
"""Finite states for scheduler."""
START = 0
GET_REMOTE = 1
UNPACK = 2
PUT_CONFIRM = 3
POP_TASK = 4
PUT_LOCAL = 5
FINAL = 6
HALT = 7
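# Scheduler state flow (derived from the actions below): GET_REMOTE -> UNPACK ->
# [PUT_CONFIRM] -> POP_TASK -> PUT_LOCAL, popping tasks one at a time onto the local
# queue and returning to GET_REMOTE once the bundle is exhausted. A None bundle
# (disconnect) or an idle period exceeding `timeout` transitions to FINAL.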
class ClientScheduler(StateMachine):
"""Receive task bundles from server and schedule locally."""
queue: QueueClient
local: Queue[Optional[Task]]
bundle: List[bytes]
client_info: Optional[bytes]
no_confirm: bool
timeout: Optional[timedelta]
previous_received: datetime
task: Task
tasks: List[Task]
state = SchedulerState.START
states = SchedulerState
def __init__(self: ClientScheduler,
queue: QueueClient,
local: Queue[Optional[Task]],
no_confirm: bool = False,
timeout: int = None) -> None:
"""Assign remote queue client and local task queue."""
self.queue = queue
self.local = local
self.bundle = []
self.tasks = []
self.client_info = None
self.no_confirm = no_confirm
self.timeout = None if not timeout else timedelta(seconds=timeout)
self.previous_received = datetime.now()
@functools.cached_property
def actions(self: ClientScheduler) -> Dict[SchedulerState, Callable[[], SchedulerState]]:
return {
SchedulerState.START: self.start,
SchedulerState.GET_REMOTE: self.get_remote,
SchedulerState.UNPACK: self.unpack_bundle,
SchedulerState.PUT_CONFIRM: self.put_confirm,
SchedulerState.POP_TASK: self.pop_task,
SchedulerState.PUT_LOCAL: self.put_local,
SchedulerState.FINAL: self.finalize,
}
def start(self: ClientScheduler) -> SchedulerState:
"""Jump to GET_REMOTE state."""
timeout_label = self.timeout or 'no'
log.debug(f'Started (scheduler: {timeout_label} timeout)')
return SchedulerState.GET_REMOTE
def get_remote(self: ClientScheduler) -> SchedulerState:
"""Get the next task bundle from the server."""
if check_signal() in (SIGUSR1, SIGUSR2) and CLIENT_STANDALONE_MODE:
log.warning(f'Signal interrupt ({SIGNAL_MAP[check_signal()]})')
return SchedulerState.FINAL
try:
self.bundle = self.queue.scheduled.get(timeout=2)
self.queue.scheduled.task_done()
self.previous_received = datetime.now()
if self.bundle is not None:
log.debug(f'Received {len(self.bundle)} tasks ({HOSTNAME}: {INSTANCE})')
return SchedulerState.UNPACK
else:
log.debug('Disconnect received')
return SchedulerState.FINAL
except QueueEmpty:
waited = datetime.now() - self.previous_received
if self.timeout is None or waited < self.timeout:
return SchedulerState.GET_REMOTE
else:
log.debug(f'Timeout reached ({waited})')
return SchedulerState.FINAL
def unpack_bundle(self: ClientScheduler) -> SchedulerState:
"""Unpack latest bundle of tasks."""
self.tasks = [Task.unpack(data) for data in self.bundle]
if not self.no_confirm:
self.client_info = ClientInfo.from_tasks(self.tasks).pack()
return SchedulerState.PUT_CONFIRM
else:
return SchedulerState.POP_TASK
def put_confirm(self: ClientScheduler) -> SchedulerState:
"""Put confirmation details back on remote queue."""
try:
self.queue.confirmed.put(self.client_info, timeout=2)
log.debug(f'Confirmed {len(self.tasks)} tasks ({HOSTNAME}: {INSTANCE})')
return SchedulerState.POP_TASK
except QueueFull:
return SchedulerState.PUT_CONFIRM
def pop_task(self: ClientScheduler) -> SchedulerState:
"""Pop next task off current task list."""
try:
self.task = self.tasks.pop(0)
return SchedulerState.PUT_LOCAL
except IndexError:
return SchedulerState.GET_REMOTE
def put_local(self: ClientScheduler) -> SchedulerState:
"""Put latest task on the local task queue."""
try:
self.local.put(self.task, timeout=1)
return SchedulerState.POP_TASK
except QueueFull:
return SchedulerState.PUT_LOCAL
@staticmethod
def finalize() -> SchedulerState:
"""Stop scheduler."""
log.debug('Done (scheduler)')
return SchedulerState.HALT
class ClientSchedulerThread(Thread):
"""Run client scheduler in dedicated thread."""
def __init__(self: ClientSchedulerThread,
queue: QueueClient,
local: Queue[Optional[bytes]],
no_confirm: bool = False,
timeout: int = None) -> None:
"""Initialize machine."""
super().__init__(name='hypershell-client-scheduler')
self.machine = ClientScheduler(queue=queue, local=local, no_confirm=no_confirm, timeout=timeout)
def run_with_exceptions(self: ClientSchedulerThread) -> None:
"""Run machine."""
self.machine.run()
def stop(self: ClientSchedulerThread, wait: bool = False, timeout: int = None) -> None:
"""Stop machine."""
log.warning('Stopping (scheduler)')
self.machine.halt()
super().stop(wait=wait, timeout=timeout)
DEFAULT_BUNDLESIZE: Final[int] = default.client.bundlesize
"""Default size of task bundles."""
DEFAULT_BUNDLEWAIT: Final[int] = default.client.bundlewait
"""Default waiting period before forcing task bundle push."""
class CollectorState(State, Enum):
"""Finite states of collector."""
START = 0
GET_LOCAL = 1
CHECK_BUNDLE = 2
PACK_BUNDLE = 3
PUT_REMOTE = 4
FINAL = 5
HALT = 6
class ClientCollector(StateMachine):
"""Collect finished tasks and bundle for outgoing queue."""
tasks: List[Task]
bundle: List[bytes]
queue: QueueClient
local: Queue[Optional[Task]]
bundlesize: int
bundlewait: int
previous_send: datetime
state = CollectorState.START
states = CollectorState
def __init__(self: ClientCollector, queue: QueueClient, local: Queue[Optional[Task]],
bundlesize: int = DEFAULT_BUNDLESIZE, bundlewait: int = DEFAULT_BUNDLEWAIT) -> None:
"""Collect tasks from local queue of finished tasks and push them to the server."""
self.tasks = []
self.bundle = []
self.local = local
self.queue = queue
self.bundlesize = bundlesize
self.bundlewait = bundlewait
@functools.cached_property
def actions(self: ClientCollector) -> Dict[CollectorState, Callable[[], CollectorState]]:
return {
CollectorState.START: self.start,
CollectorState.GET_LOCAL: self.get_local,
CollectorState.CHECK_BUNDLE: self.check_bundle,
CollectorState.PACK_BUNDLE: self.pack_bundle,
CollectorState.PUT_REMOTE: self.put_remote,
CollectorState.FINAL: self.finalize,
}
def start(self: ClientCollector) -> CollectorState:
"""Jump to GET_LOCAL state."""
log.debug('Started (collector)')
self.previous_send = datetime.now()
return CollectorState.GET_LOCAL
def get_local(self: ClientCollector) -> CollectorState:
"""Get the next task from the local completed task queue."""
try:
task = self.local.get(timeout=1)
self.local.task_done()
if task:
self.tasks.append(task)
return CollectorState.CHECK_BUNDLE
else:
return CollectorState.FINAL
except QueueEmpty:
return CollectorState.CHECK_BUNDLE
def check_bundle(self: ClientCollector) -> CollectorState:
"""Check state of task bundle and proceed with return if necessary."""
wait_time = (datetime.now() - self.previous_send)
since_last = wait_time.total_seconds()
if len(self.tasks) >= self.bundlesize:
log.trace(f'Bundle size reached ({len(self.tasks)} tasks)')
return CollectorState.PACK_BUNDLE
elif since_last >= self.bundlewait:
log.trace(f'Bundle wait exceeded ({wait_time})')
return CollectorState.PACK_BUNDLE
else:
return CollectorState.GET_LOCAL
def pack_bundle(self: ClientCollector) -> CollectorState:
"""Pack tasks into bundle before pushing back to server."""
self.bundle = [task.pack() for task in self.tasks]
return CollectorState.PUT_REMOTE
def put_remote(self: ClientCollector) -> CollectorState:
"""Push out bundle of completed tasks."""
try:
if self.bundle:
self.queue.completed.put(self.bundle, timeout=2)
log.trace(f'Bundle returned ({len(self.bundle)} tasks)')
self.tasks.clear()
self.bundle.clear()
self.previous_send = datetime.now()
else:
log.trace('Bundle empty')
return CollectorState.GET_LOCAL
except QueueFull:
return CollectorState.PUT_REMOTE
def finalize(self: ClientCollector) -> CollectorState:
"""Push out any remaining tasks and halt."""
self.put_remote()
log.debug('Done (collector)')
return CollectorState.HALT
class ClientCollectorThread(Thread):
"""Run client collector within dedicated thread."""
def __init__(self: ClientCollectorThread, queue: QueueClient, local: Queue[Optional[bytes]],
bundlesize: int = DEFAULT_BUNDLESIZE, bundlewait: int = DEFAULT_BUNDLEWAIT) -> None:
"""Initialize machine."""
super().__init__(name='hypershell-client-collector')
self.machine = ClientCollector(queue=queue, local=local, bundlesize=bundlesize, bundlewait=bundlewait)
def run_with_exceptions(self: ClientCollectorThread) -> None:
"""Run machine."""
self.machine.run()
def stop(self: ClientCollectorThread, wait: bool = False, timeout: int = None) -> None:
"""Stop machine."""
log.warning('Stopping (collector)')
self.machine.halt()
super().stop(wait=wait, timeout=timeout)
DEFAULT_SIGNALWAIT: Final[int] = default.task.signalwait
"""Default signal escalation wait period in seconds."""
def task_env(task: Task) -> Dict[str, str]:
"""Build environment dictionary for the given `task`."""
task_data = task.to_json()
try:
# We have to flatten tag data separately, otherwise we'd have TASK_TAG='{...}'
tag_data = Namespace(task_data.pop('tag')).to_env().flatten(prefix='TASK_TAG')
except Exception: # noqa: any exception
tag_data = {}
return {
**os.environ,
**load_task_env(),
**Namespace.from_dict(task_data).to_env().flatten(prefix='TASK'),
**tag_data,
'TASK_CWD': config.task.cwd,
'TASK_OUTPATH': os.path.join(default_path.lib, 'task', f'{task.id}.out'),
'TASK_ERRPATH': os.path.join(default_path.lib, 'task', f'{task.id}.err'),
}
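# Note: beyond TASK_CWD/TASK_OUTPATH/TASK_ERRPATH above, the remaining TASK_* names come
# from flattening the task's JSON fields (presumably yielding, e.g., TASK_ID and TASK_ARGS),
# plus TASK_TAG_* entries for any tag data.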
class TaskState(State, Enum):
"""Finite states for task executor."""
START = 0
GET_LOCAL = 1
CREATE_TASK = 2
START_TASK = 3
WAIT_TASK = 4
CHECK_TASK = 5
WAIT_SIGNAL = 6
STOP_TASK = 7
TERM_TASK = 8
KILL_TASK = 9
PUT_LOCAL = 10
FINAL = 11
HALT = 12
class TaskExecutor(StateMachine):
"""Run tasks locally."""
id: int
task: Task
process: Popen
template: Template
redirect_output: IO
redirect_errors: IO
capture: bool
elapsed: timedelta
timeout: Optional[int]
signalwait: int
stop_requested: Optional[datetime]
attempted_sigint: bool
attempted_sigterm: bool
attempted_sigkill: bool
inbound: Queue[Optional[Task]]
outbound: Queue[Optional[Task]]
state = TaskState.START
states = TaskState
def __init__(self: TaskExecutor,
id: int,
inbound: Queue[Optional[Task]],
outbound: Queue[Optional[Task]],
template: str = DEFAULT_TEMPLATE,
redirect_output: IO = None,
redirect_errors: IO = None,
capture: bool = False,
timeout: int = None,
signalwait: int = DEFAULT_SIGNALWAIT) -> None:
"""Initialize task executor."""
self.id = id
self.template = Template(template)
self.inbound = inbound
self.outbound = outbound
self.redirect_output = redirect_output or sys.stdout
self.redirect_errors = redirect_errors or sys.stderr
self.capture = capture
self.timeout = timeout
self.signalwait = signalwait
@functools.cached_property
def actions(self: TaskExecutor) -> Dict[TaskState, Callable[[], TaskState]]:
return {
TaskState.START: self.start,
TaskState.GET_LOCAL: self.get_local,
TaskState.CREATE_TASK: self.create_task,
TaskState.START_TASK: self.start_task,
TaskState.WAIT_TASK: self.wait_task,
TaskState.CHECK_TASK: self.check_task,
TaskState.WAIT_SIGNAL: self.wait_signal,
TaskState.STOP_TASK: self.stop_task,
TaskState.TERM_TASK: self.term_task,
TaskState.KILL_TASK: self.kill_task,
TaskState.PUT_LOCAL: self.put_local,
TaskState.FINAL: self.finalize,
}
def start(self: TaskExecutor) -> TaskState:
"""Jump to GET_LOCAL state."""
log.debug(f'Started (executor-{self.id})')
return TaskState.GET_LOCAL
def get_local(self: TaskExecutor) -> TaskState:
"""Get the next task from the local queue of new tasks."""
try:
self.task = self.inbound.get(timeout=1)
self.inbound.task_done()
return TaskState.CREATE_TASK if self.task else TaskState.FINAL
except QueueEmpty:
return TaskState.GET_LOCAL
def create_task(self: TaskExecutor) -> TaskState:
"""Expand template and initialize task command-line."""
try:
self.task.client_id = INSTANCE
self.task.client_host = HOSTNAME
self.task.command = self.template.expand(self.task.args)
return TaskState.START_TASK
except Exception as error:
log.error(f'{error.__class__.__name__}: {error}')
self.task.start_time = datetime.now().astimezone()
self.task.completion_time = datetime.now().astimezone()
self.task.exit_status = -1
return TaskState.PUT_LOCAL
def start_task(self: TaskExecutor) -> TaskState:
"""Start current task locally."""
# NOTE: enforce tz aware submit_time (in case of sqlite backend)
self.task.start_time = datetime.now().astimezone()
self.task.waited = int((self.task.start_time - self.task.submit_time.astimezone()).total_seconds())
env = task_env(self.task)
if self.capture:
self.task.outpath = env['TASK_OUTPATH']
self.task.errpath = env['TASK_ERRPATH']
self.redirect_output = open(self.task.outpath, mode='w')
self.redirect_errors = open(self.task.errpath, mode='w')
self.stop_requested = None
self.attempted_sigint = False
self.attempted_sigterm = False
self.attempted_sigkill = False
self.process = Popen(self.task.command, shell=True,
stdout=self.redirect_output, stderr=self.redirect_errors,
cwd=config.task.cwd, env=env)
log.info(f'Running task ({self.task.id})')
log.debug(f'Running task ({self.task.id})[{self.process.pid}]: {self.task.command}')
return TaskState.WAIT_TASK
def wait_task(self: TaskExecutor) -> TaskState:
"""Wait for current task to complete."""
try:
self.task.exit_status = self.process.wait(timeout=1)
self.task.completion_time = datetime.now().astimezone()
self.task.duration = int((self.task.completion_time - self.task.start_time).total_seconds())
log.debug(f'Completed task ({self.task.id})')
if self.capture:
self.redirect_output.close()
self.redirect_errors.close()
return TaskState.PUT_LOCAL
except TimeoutExpired:
# Only display time elapsed to the nearest second
self.elapsed = timedelta(seconds=round((datetime.now().astimezone() -
self.task.start_time).total_seconds()))
log.trace(f'Waiting on task ({self.task.id}: {self.elapsed})')
if self.stop_requested:
return TaskState.WAIT_SIGNAL
else:
return TaskState.CHECK_TASK
def check_task(self: TaskExecutor) -> TaskState:
"""Check for timeout or interrupts."""
if check_signal() == SIGUSR2: # NOTE: regardless of CLIENT_STANDALONE_MODE
log.warning(f'Signal interrupt (SIGUSR2: executor-{self.id})')
self.stop_requested = datetime.now()
return TaskState.WAIT_SIGNAL
elif self.timeout is None or self.elapsed.total_seconds() < self.timeout:
return TaskState.WAIT_TASK
else:
log.warning(f'Task exceeded walltime limit ({self.elapsed})')
self.stop_requested = datetime.now()
return TaskState.WAIT_SIGNAL
def wait_signal(self: TaskExecutor) -> TaskState:
"""Wait on interrupts."""
if self.attempted_sigint is False:
return TaskState.STOP_TASK
elif (datetime.now() - self.stop_requested).total_seconds() < 1 * self.signalwait:
return TaskState.WAIT_TASK
elif self.attempted_sigterm is False:
log.error(f'Interrupt ignored ({self.task.id})')
return TaskState.TERM_TASK
elif (datetime.now() - self.stop_requested).total_seconds() < 2 * self.signalwait:
return TaskState.WAIT_TASK
elif self.attempted_sigkill is False:
log.error(f'Terminate ignored ({self.task.id})')
return TaskState.KILL_TASK
elif (datetime.now() - self.stop_requested).total_seconds() < 3 * self.signalwait:
return TaskState.WAIT_TASK
else:
log.critical(f'Process ignored SIGKILL ({self.task.id}: {self.process.pid})')
log.critical(f'Shutting down executor ({self.id})')
return TaskState.FINAL
def stop_task(self: TaskExecutor) -> TaskState:
"""Send SIGINT to task process."""
log.debug(f'Sending SIGINT ({self.task.id}: {self.process.pid})')
self.process.send_signal(SIGINT)
self.attempted_sigint = True
return TaskState.WAIT_TASK
def term_task(self: TaskExecutor) -> TaskState:
"""Send SIGTERM to task process."""
log.debug(f'Sending SIGTERM ({self.task.id}: {self.process.pid})')
self.process.terminate()
self.attempted_sigterm = True
return TaskState.WAIT_TASK
def kill_task(self: TaskExecutor) -> TaskState:
"""Send SIGKILL or halt executor if ignored."""
log.debug(f'Sending SIGKILL ({self.task.id}: {self.process.pid})')
self.process.kill()
self.attempted_sigkill = True
return TaskState.WAIT_TASK
def put_local(self: TaskExecutor) -> TaskState:
"""Put completed task on outbound queue."""
try:
self.outbound.put(self.task, timeout=1)
return TaskState.GET_LOCAL
except QueueFull:
return TaskState.PUT_LOCAL
def finalize(self: TaskExecutor) -> TaskState:
"""Push out any remaining tasks and halt."""
log.debug(f'Done (executor-{self.id})')
if self.redirect_output is not sys.stdout:
self.redirect_output.close()
if self.redirect_errors is not sys.stderr:
self.redirect_errors.close()
return TaskState.HALT
class TaskThread(Thread):
"""Run task executor within dedicated thread."""
id: int
def __init__(self: TaskThread,
id: int,
inbound: Queue[Optional[str]],
outbound: Queue[Optional[str]],
template: str = DEFAULT_TEMPLATE,
capture: bool = False,
redirect_output: IO = None,
redirect_errors: IO = None,
timeout: int = None,
signalwait: int = DEFAULT_SIGNALWAIT) -> None:
"""Initialize task executor."""
self.id = id
super().__init__(name=f'hypershell-executor-{id}')
self.machine = TaskExecutor(id=id, inbound=inbound, outbound=outbound, template=template,
redirect_output=redirect_output, redirect_errors=redirect_errors,
capture=capture, timeout=timeout, signalwait=signalwait)
def run_with_exceptions(self: TaskThread) -> None:
"""Run machine."""
self.machine.run()
def stop(self: TaskThread, wait: bool = False, timeout: int = None) -> None:
"""Stop machine."""
log.warning(f'Stopping (executor-{self.id})')
self.machine.halt()
super().stop(wait=wait, timeout=timeout)
class HeartbeatState(State, Enum):
"""Finite states for heartbeat machine."""
START = 0
SUBMIT = 1
WAIT = 2
FINAL = 3
HALT = 4
DEFAULT_HEARTRATE: Final[int] = default.client.heartrate
"""Period in seconds to wait between heartbeats."""
class ClientHeartbeat(StateMachine):
"""Register heartbeats with remote server."""
queue: QueueClient
heartrate: timedelta
previous: datetime = None
no_wait: bool = False
client_state: ClientState = ClientState.RUNNING
state = HeartbeatState.START
states = HeartbeatState
def __init__(self: ClientHeartbeat, queue: QueueClient, heartrate: int = DEFAULT_HEARTRATE) -> None:
"""Initialize heartbeat machine."""
self.queue = queue
self.previous = datetime.now()
self.heartrate = timedelta(seconds=heartrate)
@functools.cached_property
def actions(self: ClientHeartbeat) -> Dict[HeartbeatState, Callable[[], HeartbeatState]]:
return {
HeartbeatState.START: self.start,
HeartbeatState.SUBMIT: self.submit,
HeartbeatState.WAIT: self.wait,
HeartbeatState.FINAL: self.finalize,
}
@staticmethod
def start() -> HeartbeatState:
"""Jump to SUBMIT state."""
log.debug('Started (heartbeat)')
return HeartbeatState.SUBMIT
def submit(self: ClientHeartbeat) -> HeartbeatState:
"""Publish new heartbeat to remote queue."""
try:
client_state = self.client_state # atomic
heartbeat = Heartbeat.new(state=client_state)
self.queue.heartbeat.put(heartbeat.pack(), timeout=2)
if client_state is ClientState.RUNNING:
log.trace(f'Heartbeat - running ({heartbeat.host}: {heartbeat.uuid})')
return HeartbeatState.WAIT
else:
log.trace(f'Heartbeat - final ({heartbeat.host}: {heartbeat.uuid})')
return HeartbeatState.FINAL
except QueueFull:
return HeartbeatState.SUBMIT
def wait(self: ClientHeartbeat) -> HeartbeatState:
"""Wait until next needed heartbeat."""
if self.no_wait:
return HeartbeatState.SUBMIT
now = datetime.now()
if (now - self.previous) < self.heartrate:
time.sleep(1)
return HeartbeatState.WAIT
else:
self.previous = now
return HeartbeatState.SUBMIT
@staticmethod
def finalize() -> HeartbeatState:
"""Stop heartbeats."""
log.debug('Done (heartbeat)')
return HeartbeatState.HALT
class ClientHeartbeatThread(Thread):
"""Run heartbeat machine within dedicated thread."""
def __init__(self: ClientHeartbeatThread, queue: QueueClient, heartrate: int = DEFAULT_HEARTRATE) -> None:
"""Initialize heartbeat machine."""
super().__init__(name='hypershell-heartbeat')
self.machine = ClientHeartbeat(queue=queue, heartrate=heartrate)
def run_with_exceptions(self: ClientHeartbeatThread) -> None:
"""Run machine."""
self.machine.run()
def signal_finished(self: ClientHeartbeatThread) -> None:
"""Set client state to communicate completion."""
self.machine.client_state = ClientState.FINISHED
self.machine.no_wait = True
def stop(self: ClientHeartbeatThread, wait: bool = False, timeout: int = None) -> None:
"""Stop machine."""
log.warning('Stopping (heartbeat)')
self.machine.halt()
super().stop(wait=wait, timeout=timeout)
DEFAULT_NUM_TASKS: Final[int] = 1
"""Default number of task executors per client."""
# We do not delay connecting to the server unless explicitly specified
DEFAULT_DELAY: Final[int] = 0
"""Default delay in seconds on client startup."""
DEFAULT_HOST: Final[str] = QueueConfig.host
"""Default host for server connection."""
DEFAULT_PORT: Final[int] = QueueConfig.port
"""Default port for server connection."""
DEFAULT_AUTH: Final[str] = QueueConfig.auth
"""Default authentication key for server (**DO NOT USE THIS**)."""
class ClientThread(Thread):
"""
Run client within dedicated thread.
Run until either disconnect requested from server or `client_timeout` reached.
Args:
num_tasks (int, optional):
Number of parallel task executor threads.
See :const:`DEFAULT_NUM_TASKS`.
bundlesize (int, optional):
Size of task bundles returned to server.
See :const:`DEFAULT_BUNDLESIZE`.
bundlewait (int, optional):
Waiting period in seconds before forcing return of task bundle to server.
See :const:`DEFAULT_BUNDLEWAIT`.
address (tuple, optional):
Server host address and port number.
See :const:`DEFAULT_HOST` and :const:`DEFAULT_PORT`.
auth (str, optional):
Server authentication key.
See :const:`DEFAULT_AUTH`.
template (str, optional):
Template command pattern. See :const:`DEFAULT_TEMPLATE`.
redirect_output (IO, optional):
Optional file-like object for <stdout> redirect.
redirect_errors (IO, optional):
Optional file-like object for <stderr> redirect.
heartrate (int, optional):
Period in seconds to wait between heartbeats.
See :const:`DEFAULT_HEARTRATE`.
capture (bool, optional):
Isolate task <stdout> and <stderr> in discrete files.
Defaults to `False`.
delay_start (float, optional):
Delay in seconds before connecting to server.
See :const:`DEFAULT_DELAY`.
no_confirm (bool, optional):
Disable client confirmation of tasks received.
client_timeout (int, optional):
Timeout in seconds before disconnecting from server.
By default, the client waits for the server to request disconnect.
task_timeout (int, optional):
Task-level walltime limit in seconds.
By default, the client waits indefinitely on tasks.
task_signalwait (int, optional):
Signal escalation waiting period in seconds on task timeout.
See :const:`DEFAULT_SIGNALWAIT`.
Example:
>>> from hypershell.client import ClientThread
>>> client = ClientThread.new(num_tasks=16, address=('localhost', 54321),
... auth='my-secret-key', capture=True)
>>> client.join()
See Also:
- :meth:`run_client`
"""
client: QueueClient
num_tasks: int
delay_start: float
no_confirm: bool
inbound: Queue[Optional[Task]]
outbound: Queue[Optional[Task]]
scheduler: ClientSchedulerThread
collector: ClientCollectorThread
executors: List[TaskThread]
def __init__(self: ClientThread,
num_tasks: int = DEFAULT_NUM_TASKS,
bundlesize: int = DEFAULT_BUNDLESIZE,
bundlewait: int = DEFAULT_BUNDLEWAIT,
address: Tuple[str, int] = (DEFAULT_HOST, DEFAULT_PORT),
auth: str = DEFAULT_AUTH,
template: str = DEFAULT_TEMPLATE,
redirect_output: IO = None,
redirect_errors: IO = None,
heartrate: int = DEFAULT_HEARTRATE,
capture: bool = False,
delay_start: float = DEFAULT_DELAY,
no_confirm: bool = False,
client_timeout: int = None,
task_timeout: int = None,
task_signalwait: int = DEFAULT_SIGNALWAIT) -> None:
"""Initialize queue manager and child threads."""
super().__init__(name='hypershell-client')
self.num_tasks = num_tasks
self.delay_start = delay_start
self.no_confirm = no_confirm
self.client = QueueClient(config=QueueConfig(host=address[0], port=address[1], auth=auth))
self.inbound = Queue(maxsize=bundlesize)
self.outbound = Queue(maxsize=bundlesize)
self.scheduler = ClientSchedulerThread(queue=self.client, local=self.inbound,
no_confirm=no_confirm, timeout=client_timeout)
self.heartbeat = ClientHeartbeatThread(queue=self.client, heartrate=heartrate)
self.collector = ClientCollectorThread(queue=self.client, local=self.outbound,
bundlesize=bundlesize, bundlewait=bundlewait)
self.executors = [TaskThread(id=count+1,
inbound=self.inbound, outbound=self.outbound,
redirect_output=redirect_output, redirect_errors=redirect_errors,
template=template, capture=capture, timeout=task_timeout,
signalwait=task_signalwait)
for count in range(num_tasks)]
def run_with_exceptions(self: ClientThread) -> None:
"""Start child threads, wait."""
log.debug(f'Started ({self.num_tasks} executors)')
self.wait_start()
with self.client:
self.start_threads()
self.wait_scheduler()
self.wait_executors()
self.wait_collector()
self.wait_heartbeat()
log.debug('Done')
def wait_start(self: ClientThread) -> None:
"""Wait constant period or random interval."""
if self.delay_start == 0:
return
if self.delay_start > 0:
log.debug(f'Waiting ({self.delay_start} seconds)')
time.sleep(self.delay_start)
else:
delay = random.uniform(0, -1 * self.delay_start)
log.debug(f'Waiting random ({delay:.1f} seconds)')
time.sleep(delay)
def start_threads(self: ClientThread) -> None:
"""Start child threads."""
self.scheduler.start()
self.collector.start()
self.heartbeat.start()
for executor in self.executors:
executor.start()
def wait_scheduler(self: ClientThread) -> None:
"""Wait for all tasks to be completed."""
log.trace('Waiting (scheduler)')
self.scheduler.join()
def wait_collector(self: ClientThread) -> None:
"""Signal collector to halt."""
log.trace('Waiting (collector)')
self.outbound.put(None)
self.collector.join()
def wait_executors(self: ClientThread) -> None:
"""Send disconnect signal to each task executor thread."""
for _ in self.executors:
self.inbound.put(None) # signal executors to shut down
for thread in self.executors:
log.trace(f'Waiting (executor-{thread.id})')
thread.join()
def wait_heartbeat(self: ClientThread) -> None:
"""Signal HALT on heartbeat."""
log.trace('Waiting (heartbeat)')
self.heartbeat.signal_finished()
self.heartbeat.join()
def stop(self: ClientThread, wait: bool = False, timeout: int = None) -> None:
"""Stop child threads before main thread."""
log.warning('Stopping')
self.scheduler.stop(wait=wait, timeout=timeout)
self.collector.stop(wait=wait, timeout=timeout)
super().stop(wait=wait, timeout=timeout)
def run_client(num_tasks: int = DEFAULT_NUM_TASKS,
bundlesize: int = DEFAULT_BUNDLESIZE,
bundlewait: int = DEFAULT_BUNDLEWAIT,
address: Tuple[str, int] = (DEFAULT_HOST, DEFAULT_PORT),
auth: str = DEFAULT_AUTH,
template: str = DEFAULT_TEMPLATE,
redirect_output: IO = None,
redirect_errors: IO = None,
capture: bool = False,
heartrate: int = DEFAULT_HEARTRATE,
delay_start: float = DEFAULT_DELAY,
no_confirm: bool = False,
client_timeout: int = None,
task_timeout: int = None,
task_signalwait: int = DEFAULT_SIGNALWAIT) -> None:
"""
Run client until disconnect signal received or `client_timeout` reached.
Args:
num_tasks (int, optional):
Number of parallel task executor threads.
See :const:`DEFAULT_NUM_TASKS`.
bundlesize (int, optional):
Size of task bundles returned to server.
See :const:`DEFAULT_BUNDLESIZE`.
bundlewait (int, optional):
Waiting period in seconds before forcing return of task bundle to server.
See :const:`DEFAULT_BUNDLEWAIT`.
address (tuple, optional):
Server host address and port number.
See :const:`DEFAULT_HOST` and :const:`DEFAULT_PORT`.
auth (str, optional):
Server authentication key.
See :const:`DEFAULT_AUTH`.
template (str, optional):
Template command pattern. See :const:`DEFAULT_TEMPLATE`.
redirect_output (IO, optional):
Optional file-like object for <stdout> redirect.
redirect_errors (IO, optional):
Optional file-like object for <stderr> redirect.
heartrate (int, optional):
Period in seconds to wait between heartbeats.
See :const:`DEFAULT_HEARTRATE`.
capture (bool, optional):
Isolate task <stdout> and <stderr> in discrete files.
Defaults to `False`.
delay_start (float, optional):
Delay in seconds before connecting to server.
See :const:`DEFAULT_DELAY`.
no_confirm (bool, optional):
Disable client confirmation of tasks received.
client_timeout (int, optional):
Timeout in seconds before disconnecting from server.
By default, the client waits for the server to request disconnect.
task_timeout (int, optional):
Task-level walltime limit in seconds.
By default, the client waits indefinitely on tasks.
task_signalwait (int, optional):
Signal escalation waiting period in seconds on task timeout.
See :const:`DEFAULT_SIGNALWAIT`.
Example:
>>> from hypershell.client import run_client
>>> run_client(num_tasks=16, address=('localhost', 54321),
... auth='my-secret-key', capture=True)
See Also:
- :meth:`ClientThread`
"""
thread = ClientThread.new(num_tasks=num_tasks,
bundlesize=bundlesize,
bundlewait=bundlewait,
address=address,
auth=auth,
template=template,
capture=capture,
redirect_output=redirect_output,
redirect_errors=redirect_errors,
heartrate=heartrate,
delay_start=delay_start,
no_confirm=no_confirm,
client_timeout=client_timeout,
task_timeout=task_timeout,
task_signalwait=task_signalwait)
try:
thread.join()
except Exception:
thread.stop()
raise
APP_NAME = 'hs client'
APP_USAGE = f"""\
Usage:
hs client [-h] [-N NUM] [-t CMD] [-b SIZE] [-w SEC] [-H ADDR] [-p PORT] [-k KEY]
[--capture | [-o PATH] [-e PATH]] [--no-confirm] [-d SEC] [-T SEC] [-W SEC] [-S SEC]
Launch client directly, run tasks in parallel.\
"""
APP_HELP = f"""\
{APP_USAGE}
Tasks are pulled off of the shared queue in bundles from the server and run
locally within the same shell as the client. By default the bundle size is one,
which gives greater responsiveness at small scales. It is recommended
to match these parameters with those of the server.
Options:
-N, --num-tasks NUM Number of tasks to run in parallel (default: {DEFAULT_NUM_TASKS}).
-t, --template CMD Command-line template pattern (default: "{DEFAULT_TEMPLATE}").
-b, --bundlesize SIZE Bundle size for finished tasks (default: {DEFAULT_BUNDLESIZE}).
-w, --bundlewait SEC Seconds to wait before flushing tasks (default: {DEFAULT_BUNDLEWAIT}).
-H, --host ADDR Hostname for server.
-p, --port NUM Port number for server.
-k, --auth KEY Cryptographic key to connect to server.
-d, --delay-start SEC Seconds to wait before start-up (default: {DEFAULT_DELAY}).
--no-confirm Disable confirmation of task bundle received.
-o, --output PATH Redirect task output (default: <stdout>).
-e, --errors PATH Redirect task errors (default: <stderr>).
-c, --capture Capture individual task <stdout> and <stderr>.
-T, --timeout SEC Automatically shut down if no tasks received (default: never).
-W, --task-timeout SEC Task-level walltime limit (default: none).
-S, --signalwait SEC Task-level signal escalation wait period (default: {DEFAULT_SIGNALWAIT}).
-h, --help Show this message and exit.\
"""
class ClientApp(Application):
"""Run individual client directly."""
name = APP_NAME
interface = Interface(APP_NAME, APP_USAGE, APP_HELP)
num_tasks: int = DEFAULT_NUM_TASKS
interface.add_argument('-N', '--num-tasks', type=int, default=num_tasks)
host: str = config.server.bind
interface.add_argument('-H', '--host', default=host)
port: int = config.server.port
interface.add_argument('-p', '--port', type=int, default=port)
auth: str = config.server.auth
interface.add_argument('-k', '--auth', default=auth)
template: str = DEFAULT_TEMPLATE
interface.add_argument('-t', '--template', default=template)
bundlesize: int = config.submit.bundlesize
interface.add_argument('-b', '--bundlesize', type=int, default=bundlesize)
bundlewait: int = config.submit.bundlewait
interface.add_argument('-w', '--bundlewait', type=int, default=bundlewait)
delay_start: float = DEFAULT_DELAY
interface.add_argument('-d', '--delay-start', type=float, default=delay_start)
task_timeout: int = config.task.timeout
client_timeout: int = config.client.timeout
interface.add_argument('-T', '--timeout', type=int, default=client_timeout, dest='client_timeout')
interface.add_argument('-W', '--task-timeout', type=int, default=task_timeout, dest='task_timeout')
task_signalwait: int = config.task.signalwait
interface.add_argument('-S', '--task-signalwait', type=int, default=task_signalwait, dest='task_signalwait')
no_confirm: bool = False
interface.add_argument('--no-confirm', action='store_true')
output_path: str = None
errors_path: str = None
interface.add_argument('-o', '--output', default=None, dest='output_path')
interface.add_argument('-e', '--errors', default=None, dest='errors_path')
capture: bool = False
interface.add_argument('-c', '--capture', action='store_true')
# Hidden options used as helpers for shell completion
interface.add_argument('--available-cores', action='version', version=str(cpu_count()))
interface.add_argument('--available-ssh-groups', action='version', version='\n'.join(SSH_GROUPS))
exceptions = {
EOFError: functools.partial(handle_disconnect, logger=log),
ConnectionResetError: functools.partial(handle_disconnect, logger=log),
ConnectionRefusedError: functools.partial(handle_exception, logger=log, status=exit_status.runtime_error),
AuthenticationError: functools.partial(handle_exception, logger=log, status=exit_status.runtime_error),
HostAddressInfo: functools.partial(handle_address_unknown, logger=log, status=exit_status.runtime_error),
**get_shared_exception_mapping(__name__),
}
def run(self: ClientApp) -> None:
"""Run client."""
try:
self.check_args()
run_client(num_tasks=self.num_tasks,
bundlesize=self.bundlesize,
bundlewait=self.bundlewait,
address=(self.host, self.port),
auth=self.auth,
template=self.template,
redirect_output=self.output_stream,
redirect_errors=self.errors_stream,
capture=self.capture,
delay_start=self.delay_start,
no_confirm=self.no_confirm,
heartrate=config.client.heartrate,
client_timeout=self.client_timeout,
task_timeout=self.task_timeout,
task_signalwait=self.task_signalwait)
except gaierror:
raise HostAddressInfo(f'Could not resolve host \'{self.host}\'')
def check_args(self: ClientApp) -> None:
"""Check for logical errors in command-line arguments."""
if self.capture and (self.output_path or self.errors_path):
raise ArgumentError('Cannot specify --capture with either --output or --errors')
if self.client_timeout is not None and self.client_timeout <= 0:
raise ArgumentError('Client --timeout should be a positive integer')
if self.task_timeout is not None and self.task_timeout <= 0:
raise ArgumentError('Client --task-timeout should be a positive integer')
@functools.cached_property
def output_stream(self: ClientApp) -> IO:
"""IO stream for task outputs."""
return sys.stdout if not self.output_path else open(self.output_path, mode='w')
@functools.cached_property
def errors_stream(self: ClientApp) -> IO:
"""IO stream for task errors."""
return sys.stderr if not self.errors_path else open(self.errors_path, mode='w')
def __exit__(self: ClientApp,
exc_type: Optional[Type[Exception]],
exc_val: Optional[Exception],
exc_tb: Optional[TracebackType]) -> None:
"""Close IO streams if necessary."""
if self.output_stream is not sys.stdout:
self.output_stream.close()
if self.errors_stream is not sys.stderr:
self.errors_stream.close()