Source code for cdiutils.pipeline.base

import logging
import os
import signal
import subprocess
import sys
import textwrap
import time
from abc import ABC
from functools import wraps
from typing import Callable

import numpy as np
import yaml

from cdiutils.plot.formatting import update_plot_params

# Define a custom log level for JOB
JOB_LOG_LEVEL = 25  # Between INFO (20) and WARNING (30)
logging.addLevelName(JOB_LOG_LEVEL, "JOB")


# Create a method to log at the JOB level
[docs] def job(self, message, *args, **kwargs): if self.isEnabledFor(JOB_LOG_LEVEL): self._log(JOB_LOG_LEVEL, message, args, **kwargs)
logging.Logger.job = job
[docs] class LoggerWriter: """ Custom stream redirecting stdout to logger in real-time. Captures print statements and routes them through the logging system with optional line wrapping at 79 characters. Args: logger (logging.Logger): target logger instance. level (int): logging level (e.g., logging.INFO). wrap (bool): enable line wrapping at 79 chars. Defaults to True. """
[docs] def __init__( self, logger: logging.Logger, level: int, wrap: bool = True ) -> None: self.logger = logger self.level = level self.wrap = wrap
[docs] def write(self, message: str) -> None: """ Write message to logger with optional wrapping. Args: message (str): message to log. """ if message.strip(): # only log non-empty messages if self.wrap: wrapped_message = textwrap.fill(message.strip(), width=79) else: wrapped_message = message.strip() self.logger.log(self.level, "\n" + wrapped_message + "\n")
[docs] def flush(self) -> None: """ No-op flush method for sys.stdout compatibility. Required by the file-like object interface but performs no operation for logger streams. """ pass
[docs] class JobCancelledError(Exception): """ Exception raised when user cancels a SLURM job. Triggered by keyboard interrupts (Ctrl+C) during job monitoring. """
[docs] class JobFailedError(Exception): """ Exception raised when a SLURM job fails. Indicates non-zero exit codes or failed job states detected via sacct. """
[docs] class Pipeline(ABC): """ Abstract base class for CDI data processing pipelines. Provides infrastructure for parameter management, logging, job submission (SLURM), and subprocess execution. Not intended for direct instantiation—subclass for specific applications. Args: params (dict, optional): parameter dictionary. Defaults to None. param_file_path (str, optional): path to YAML parameter file. Defaults to None. Raises: ValueError: if neither params nor param_file_path is provided. """
[docs] def __init__( self, params: dict = None, param_file_path: str = None ) -> None: """ Initialise Pipeline with parameters from dict or file. Args: params (dict, optional): parameter dictionary. Defaults to None. param_file_path (str, optional): path to YAML parameter file. Defaults to None. Raises: ValueError: if neither params nor param_file_path is provided. """ self.param_file_path = param_file_path self.params = params if params is None: if param_file_path is None: raise ValueError( "param_file_path or parameters must be provided" ) self.params = self.load_parameters() # Create the dump directory self.dump_dir = self.params["dump_dir"] self.make_dump_dir() # Initialise the logger self.logger = self._init_logger() self.interrupted = False # Flag to check for keyboard interrupt # Set the printoptions legacy to 1.21, otherwise types are printed. np.set_printoptions(legacy="1.21") # update the plot parameters update_plot_params()
[docs] def make_dump_dir(self) -> None: """ Create output directory specified in params['dump_dir']. Raises: ValueError: if dump_dir parameter is None. """ dump_dir = self.params["dump_dir"] if dump_dir is None: raise ValueError("dump_dir parameter must be set.") if os.path.isdir(dump_dir): print( "\nDump directory already exists, results will be " f"saved in:\n{dump_dir}." ) else: print(f"Creating the dump directory at: {dump_dir}") os.makedirs(dump_dir, exist_ok=True)
@staticmethod def _init_logger() -> logging.Logger: """ Initialise and configure logger for pipeline processes. Removes existing root handlers (e.g., Jupyter defaults) and sets up console logging at INFO level. Returns: logging.Logger: configured logger instance. """ # Remove all handlers associated with the root logger (Jupyter # default). for handler in logging.root.handlers[:]: logging.root.removeHandler(handler) logger = logging.getLogger("PipelineLogger") # Check if the logger already has handlers to avoid adding # multiple. if not logger.hasHandlers(): logger.setLevel(logging.DEBUG) # Console handler console_handler = logging.StreamHandler() console_handler.setLevel(logging.INFO) console_formatter = logging.Formatter( fmt="[%(levelname)s] %(message)s", ) console_handler.setFormatter(console_formatter) logger.addHandler(console_handler) return logger def _init_process_logger(self, process_name: str) -> logging.FileHandler: """ Initialise file handler for process-specific logging. Creates a new log file (overwriting any existing one) and attaches a file handler to the logger with DEBUG level and timestamped formatting. Args: process_name (str): base name for log file (without extension). Returns: logging.FileHandler: configured file handler attached to logger. """ file_handler = logging.FileHandler(f"{process_name}.log", mode="w") file_handler.setLevel(logging.DEBUG) file_format = logging.Formatter( fmt="%(asctime)s [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S", ) file_handler.setFormatter(file_format) self.logger.addHandler(file_handler) return file_handler
[docs] @staticmethod def process(func: Callable) -> Callable: """ Decorate pipeline methods to add logging and error handling. Wraps process methods with file logging, stdout redirection, and structured error reporting. Creates process-specific log files in dump_dir with format {func_name}_output.log. Args: func (Callable): pipeline method to decorate. Returns: Callable: wrapped function with logging infrastructure. Raises: Exception: re-raises any exception from decorated function after logging. Notes: Temporarily redirects sys.stdout to logger during execution to capture print statements. Original stdout is always restored in finally block. """ @wraps(func) def wrapper(self, *args, **kwargs) -> None: # Setup a new log file for this process file_handler = self._init_process_logger( f"{self.dump_dir}/{func.__name__}_output" ) msg = self.pretty_print( f"Starting process: {func.__name__}", do_print=False, return_text=True, ) self.logger.info(msg) # Redirect stdout to capture print statements in real time original_stdout = sys.stdout # Save original stdout sys.stdout = LoggerWriter(self.logger, logging.INFO) try: func(self, *args, **kwargs) self.logger.info( f"Process {func.__name__} completed successfully." ) except Exception as e: self.logger.error( f"\nError occurred in the '{func.__name__}' process:\n{e}" ) # traceback.print_exception(e) raise finally: # Restore original stdout and remove file handler sys.stdout = original_stdout self.logger.removeHandler(file_handler) file_handler.close() return wrapper
def _unwrap_logs(self) -> None: """ Disable line wrapping for logger output. Configures stdout redirection to bypass 79-character wrapping for cases requiring full-width output (e.g., tables, progress bars). """ sys.stdout = LoggerWriter(self.logger, logging.INFO, wrap=False) def _wrap_logs(self) -> None: """ Enable line wrapping for logger output. Configures stdout redirection to wrap lines at 79 characters for standard logging output. """ sys.stdout = LoggerWriter(self.logger, logging.INFO, wrap=True) def _subprocess_run( self, cmd: str | list[str] ) -> subprocess.CompletedProcess: """ Execute subprocess command with error handling. Runs command with captured stdout/stderr and validates return code. Logs errors and raises CalledProcessError on failure. Args: cmd (str | list[str]): command string or argument list. Returns: subprocess.CompletedProcess: completed process with stdout/stderr. Raises: subprocess.CalledProcessError: if command returns non-zero exit code. """ result = subprocess.run( cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, ) if result.returncode != 0: self.logger.error(f"Command {cmd} failed: {result.stderr}") raise subprocess.CalledProcessError( result.returncode, result.args, output=result.stdout, stderr=result.stderr, ) return result
[docs] def submit_job(self, job_file: str, working_dir: str) -> tuple[str, str]: """ Submit SLURM job and return job ID with output file path. Executes sbatch command in bash login shell to ensure proper environment loading. Sets up keyboard interrupt handler for job cancellation. Args: job_file (str): path to SLURM batch script. working_dir (str): directory to execute sbatch from. Returns: tuple[str, str]: job ID and absolute path to output file (slurm-{job_id}.out). Raises: subprocess.CalledProcessError: if sbatch command fails. ValueError: if job ID cannot be extracted from sbatch output. Notes: Registers SIGINT handler that calls _handle_interrupt with job_id when Ctrl+C is pressed. """ # Set up signal handler for keyboard interrupt (Ctrl + C) signal.signal( signal.SIGINT, lambda sig, frame: self._handle_interrupt(job_id) ) cmd = f"sbatch {job_file}" try: with subprocess.Popen( ["bash", "-l", "-c", cmd], stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=working_dir, # Change to this directory first text=True, # Ensures stdout/stderr are str, not bytes env=os.environ.copy(), ) as proc: stdout, stderr = proc.communicate() # Check for errors based on the return code if proc.returncode != 0: # An error occurred, log the stderr output self.logger.error( f"Error submitting job. Command returned: {stderr}" ) raise subprocess.CalledProcessError( proc.returncode, proc.args, output=stdout, stderr=stderr, ) # Extract job ID from the output job_id = self._get_job_id(stdout) if job_id: self.logger.info( f"Job submitted successfully. Job ID: {job_id}" ) output_file = f"slurm-{job_id}.out" return job_id, os.path.join(working_dir, output_file) raise ValueError( "Failed to extract job ID from sbatch output." ) except subprocess.CalledProcessError as e: # Log the error if the job submission fails self.logger.error( f"Subprocess failed with return code {e.returncode}: " f"{e.stderr}" ) raise e
@staticmethod def _get_job_id(stdout: str) -> str: """ Extract SLURM job ID from sbatch output. Parses sbatch stdout for line containing 'Submitted batch job' and returns the trailing job ID number. Args: stdout (str): sbatch command output. Returns: str: job ID as string, or None if not found. Examples: >>> _get_job_id("Submitted batch job 12345\\n") '12345' """ for line in stdout.splitlines(): if "Submitted batch job" in line: return line.split()[-1] # last element of the line return None
[docs] def is_job_running(self, job_id: str) -> bool: """ Check if SLURM job is currently running. Queries squeue for job presence. Job is considered running if its ID appears in squeue output. Args: job_id (str): SLURM job ID to check. Returns: bool: True if job is in queue, False otherwise. Raises: subprocess.CalledProcessError: if squeue command fails. """ result = self._subprocess_run(["squeue", "--job", job_id]) return job_id in result.stdout # Job is running if job_id is found
[docs] def stream_job_output(self, job_id: str, output_file: str) -> None: """ Stream SLURM job output in real-time. Waits for output file creation, then continuously reads and logs new lines until job stops running or interrupted flag is set. Logs at JOB level (custom level between INFO and WARNING). Args: job_id (str): SLURM job ID being monitored. output_file (str): path to slurm-{job_id}.out file. Raises: FileNotFoundError: if output file cannot be accessed after creation. Notes: Checks file existence every 0.5s until found. Polls running status and reads new lines with 0.5s interval. Respects self.interrupted flag for early termination. """ try: self.logger.info("Waiting for job output file...") # Wait until the output file is created (check every 2 seconds) while not os.path.exists(output_file): if self.interrupted: self.logger.info( "Job monitoring interrupted before file creation." ) return time.sleep(0.5) self.logger.info(f"Streaming job output from {output_file}:\n\n") # Keep trying to read the output file until the job is done with open(output_file, "r") as f: while not self.interrupted: # Check if the job is still in the queue before reading if not self.is_job_running(job_id): self.logger.info( f"\n\nJob {job_id} is no longer running. " "Stopping output streaming." ) break line = f.readline() if line: self.logger.job(line.strip()) else: time.sleep(0.5) # Sleep briefly before checking again except FileNotFoundError: self.logger.error(f"Output file {output_file} not found.") raise
[docs] def monitor_job( self, job_id: str, output_file: str, retries: int = 10, delay: int = 1 ) -> None: """ Monitor SLURM job and verify final completion status. Streams job output in real-time and validates final state via sacct after job leaves queue. Retries state check if job shows RUNNING but is not in squeue (handles race conditions). Args: job_id (str): SLURM job ID to monitor. output_file (str): path to slurm-{job_id}.out file. retries (int): number of sacct retries for lingering RUNNING state. Defaults to 10. delay (int): seconds between retries. Defaults to 1. Raises: JobFailedError: if job terminates with FAILED state or non-zero exit code. Notes: Successfully completed jobs have state='COMPLETED' and exit_code='0:0'. Other terminal states log a warning but do not raise exceptions. """ # Start monitoring the job and streaming output while not self.interrupted: if not self.is_job_running(job_id): self.logger.info("Checking final status...") break # Job is still running, stream the output file self.stream_job_output(job_id, output_file) # After job finishes, check final status if not self.interrupted: state, exit_code = self.get_job_state(job_id) attempt = 0 while state == "RUNNING" and attempt < retries: self.logger.info( f"Job {job_id} is still in RUNNING state but not " f"found in queue. Rechecking the state in {delay} " f"second(s)..." ) time.sleep(delay) state, exit_code = self.get_job_state(job_id) attempt += 1 if state == "COMPLETED": self.logger.info( f"Job {job_id} completed successfully with " f"exit code: {exit_code}" ) return elif state == "FAILED": raise JobFailedError( f"Job {job_id} failed with exit code: {exit_code}." f"See {output_file} for more details." ) else: self.logger.warning( f"Job {job_id} finished with unexpected state: {state}." )
[docs] def get_job_state(self, job_id: str) -> tuple[str, str]: """ Retrieve SLURM job state and exit code via sacct. Queries sacct for job status information and parses output to extract state (e.g., COMPLETED, FAILED, RUNNING) and exit code (format: signal:status). Args: job_id (str): SLURM job ID to query. Returns: tuple[str, str]: job state and exit code (e.g., ('COMPLETED', '0:0')). Raises: ValueError: if job ID not found in sacct output. subprocess.CalledProcessError: if sacct command fails. """ result = self._subprocess_run( [ "sacct", "-j", job_id, "--format=JobID,State,ExitCode", "--noheader", ] ) state, exit_code = None, None # Parse the sacct output to check the job's final state for line in result.stdout.splitlines(): if job_id in line: parts = line.split() if len(parts) >= 3: state = parts[1] exit_code = parts[2] break if state is None or exit_code is None: raise ValueError(f"Job {job_id} not found in sacct.") return state, exit_code
[docs] def cancel_job(self, job_id: str) -> None: """ Cancel running SLURM job via scancel. Args: job_id (str): SLURM job ID to cancel. Raises: subprocess.CalledProcessError: if scancel command fails. """ try: self.logger.info(f"\n\nCancelling job {job_id}...") _ = self._subprocess_run(["scancel", job_id]) self.logger.info(f"Job {job_id} cancelled successfully.") except subprocess.CalledProcessError as e: self.logger.error(f"Failed to cancel job {job_id}: {e.stderr}") raise
def _handle_interrupt(self, job_id: str) -> None: """ Handle keyboard interrupt by cancelling job. Sets interrupted flag, cancels SLURM job via scancel, and raises JobCancelledError to terminate monitoring. Args: job_id (str): SLURM job ID to cancel. Raises: JobCancelledError: always raised after job cancellation. Notes: Called by SIGINT handler registered in submit_job. Sets self.interrupted=True to signal monitoring loops. """ self.interrupted = True # Set flag to interrupt monitoring self.cancel_job(job_id) raise JobCancelledError( f"Keyboard interruption. Job {job_id} was cancelled by the user." )
[docs] def load_parameters(self, file_path: str = None) -> dict: """ Load pipeline parameters from YAML configuration file. Uses yaml.full_load() to support Python-specific types like tuples that are serialised by yaml.dump(). Args: file_path (str, optional): path to YAML parameter file. Defaults to None (uses self.param_file_path). Returns: dict: loaded parameter dictionary. Raises: FileNotFoundError: if parameter file does not exist. yaml.YAMLError: if file contains invalid YAML. """ if file_path is None: file_path = self.param_file_path with open(file_path, "r", encoding="utf8") as file: params = yaml.full_load(file) return params
[docs] @staticmethod def pretty_print( text: str, max_char_per_line: int = 79, do_print: bool = True, return_text: bool = False, ) -> None | str: """ Format text with decorative star border. Creates a framed message with star borders and centred text wrapped to specified line width. Useful for logging section headers or important messages. Args: text (str): text to format. max_char_per_line (int): maximum line width including border. Defaults to 79. do_print (bool): whether to print formatted text. Defaults to True. return_text (bool): whether to return formatted string. Defaults to False. Returns: None | str: formatted text if return_text=True, else None. Examples: >>> pretty_print("Hello World", max_char_per_line=30) ****************************** * Hello World * ****************************** """ pretty_text = "\n".join( [ "", "*" * (max_char_per_line), *[ f"* {w[::-1].center(max_char_per_line - 4)[::-1]} *" for w in textwrap.wrap(text, width=max_char_per_line - 4) ], "*" * max_char_per_line, "", ] ) if do_print: print(pretty_text) if return_text: return pretty_text return None