Skip to content

API Documentation

wandb_preempt.Checkpointer

Checkpointer(run_id: str, model: Module, optimizer: Union[Optimizer, None], lr_scheduler: Optional[LRScheduler] = None, scaler: Optional[GradScaler] = None, savedir: str = 'checkpoints', verbose: bool = False)

Class for storing, loading, and removing checkpoints.

A run can be marked as pre-empted by sending a SIGUSR1 or SIGTERM signal to the Python process.

How to use this class:

  • Create an instance of this class: `checkpointer = Checkpointer(...)`.
  • At the end of each epoch, call `checkpointer.step()` to save a checkpoint. If the job received the SIGUSR1 or SIGTERM signal, the checkpointer requeues the Slurm job (if running under Slurm) and exits the script with error code 1 at the end of that checkpointing step.

Set up a checkpointer.

Parameters:

  • run_id (str) –

    A unique identifier for this run.

  • model (Module) –

    The model that is trained and checkpointed.

  • optimizer (Union[Optimizer, None]) –

    The optimizer that is used for training and should be checkpointed. Use None to explicitly ignore the optimizer. This can be useful if your optimizer does not implement .state_dict and .load_state_dict.

  • lr_scheduler (Optional[LRScheduler], default: None ) –

    The learning rate scheduler that is used for training. If None, no learning rate scheduler is assumed. Default: None.

  • scaler (Optional[GradScaler], default: None ) –

    The gradient scaler that is used when training in mixed precision. If None, no gradient scaler is assumed. Default: None.

  • savedir (str, default: 'checkpoints' ) –

    Directory to store checkpoints in. Default: 'checkpoints'.

  • verbose (bool, default: False ) –

    Whether to print messages about saving and loading checkpoints. Default: False.

Source code in wandb_preempt/checkpointer.py
def __init__(
    self,
    run_id: str,
    model: Module,
    optimizer: Union[Optimizer, None],
    lr_scheduler: Optional[LRScheduler] = None,
    scaler: Optional[GradScaler] = None,
    savedir: str = "checkpoints",
    verbose: bool = False,
) -> None:
    """Set up a checkpointer.

    Registers `self.mark_preempted` as the handler for `SIGUSR1` and `SIGTERM`,
    creates `savedir` if it does not exist yet, and, when running under SLURM,
    writes the Python PID to `<SLURM_JOB_ID>.pid` in the current working
    directory.

    Args:
        run_id: A unique identifier for this run.
        model: The model that is trained and checkpointed.
        optimizer: The optimizer that is used for training and should be
            checkpointed. Use `None` to explicitly ignore the optimizer. This can
            be useful if your optimizer does not implement `.state_dict` and
            `.load_state_dict`.
        lr_scheduler: The learning rate scheduler that is used for training. If
            `None`, no learning rate scheduler is assumed. Default: `None`.
        scaler: The gradient scaler that is used when training in mixed precision.
            If `None`, no gradient scaler is assumed. Default: `None`.
        savedir: Directory to store checkpoints in. Default: `'checkpoints'`.
        verbose: Whether to print messages about saving and loading checkpoints.
            Default: `False`
    """
    self.time_created = time()
    self.run_id = run_id
    self.model = model
    self.optimizer = optimizer
    self.lr_scheduler = lr_scheduler
    self.scaler = scaler
    self.verbose = verbose
    self.marked_preempted = False
    self.step_count = 0
    self.num_resumes = 0

    # Set up signal handler listening for SIGUSR1, when we receive this signal,
    # we mark the job as about to be pre-empted.
    # Similarly, try to gracefully end if we receive the SIGTERM signal.
    signal(SIGUSR1, self.mark_preempted)
    signal(SIGTERM, self.mark_preempted)

    self.savedir = path.abspath(savedir)
    self.maybe_print(f"Creating checkpoint directory: {self.savedir}.")
    makedirs(self.savedir, exist_ok=True)

    # Detect whether we are running inside a SLURM session.
    job_id = getenv("SLURM_JOB_ID")
    array_id = getenv("SLURM_ARRAY_JOB_ID")
    task_id = getenv("SLURM_ARRAY_TASK_ID")
    self.maybe_print(
        f"SLURM job ID: {job_id}, array ID: {array_id}, task ID: {task_id}"
    )
    # Detection is keyed on the job ID alone because both the checkpoint
    # sub-folder and the PID file below are named after it. An array/task
    # variable without `SLURM_JOB_ID` gives us no name for either, so it must
    # not be treated as a SLURM run (SLURM sets `SLURM_JOB_ID` for every job,
    # including array tasks).
    self.uses_slurm = job_id is not None

    # We will create sub-folders in the directory supplied by the user where
    # checkpoints are stored. If we are on SLURM, we will use the `SLURM_JOB_ID`
    # variable as name, otherwise we will use the formatted day.
    self.savedir_job = path.join(
        self.savedir,
        f"{job_id if self.uses_slurm else date.today()}",
    )

    # write Python PID to a file so it can be read by the signal handler from the
    # sbatch script, because it has to send a kill signal with SIGUSR1 to that PID.
    if self.uses_slurm:
        filename = f"{job_id}.pid"
        pid = str(getpid())
        self.maybe_print(f"Writing PID {pid} to file {filename}.")
        with open(filename, "w") as f:
            f.write(pid)

load_latest_checkpoint

load_latest_checkpoint(weights_only: bool = True, **kwargs) -> Tuple[Union[int, None], Dict]

Load the latest checkpoint and set random number generator states.

Updates the model, optimizer, lr scheduler, and gradient scaler states passed at initialization.

Parameters:

  • weights_only (bool, default: True ) –

    Whether to only unpickle objects that are safe to unpickle. If True, the only types that will be loaded are tensors, primitive types, dictionaries and types added via torch.serialization.add_safe_globals(). See torch.load for more information. Default: True.

  • **kwargs

    Additional keyword arguments to pass to the torch.load function.

Returns:

  • loaded_step ( Union[int, None] ) –

    The index of the checkpoint that was loaded, or None if no checkpoint was found.

  • extra_info ( Dict ) –

    Extra information that was passed by the user to the step function when the checkpoint was saved, or an empty dictionary if there is no extra information.

Source code in wandb_preempt/checkpointer.py
def load_latest_checkpoint(
    self, weights_only: bool = True, **kwargs
) -> Tuple[Union[int, None], Dict]:
    """Load the latest checkpoint and set random number generator states.

    Updates the model, optimizer, lr scheduler, and gradient scaler states
    passed at initialization.

    Args:
        weights_only: Whether to only unpickle objects that are safe to unpickle.
            If `True`, the only types that will be loaded are tensors, primitive
            types, dictionaries and types added via
            [`torch.serialization.add_safe_globals()`](https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals).
            See
            [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html)
            for more information.
            Default: `True`.
        **kwargs: Additional keyword arguments to pass to the
            [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html)
            function (e.g. `map_location`).

    Returns:
        loaded_step: The index of the checkpoint that was loaded, or `None` if no
            checkpoint was found.
        extra_info: Extra information that was passed by the user to the `step`
            function when the checkpoint was saved, or an empty dictionary if there
            is no extra information.
    """
    loadpath = self.latest_checkpoint()
    if loadpath is None:
        self.maybe_print("No checkpoint found. Starting from scratch.")
        return None, {}

    self.maybe_print(f"Loading checkpoint {loadpath}.")

    data = load(loadpath, weights_only=weights_only, **kwargs)
    self.maybe_print("Loading model.")
    self.model.load_state_dict(data["model"])
    if self.optimizer is not None:
        self.maybe_print("Loading optimizer.")
        self.optimizer.load_state_dict(data["optimizer"])
    if self.lr_scheduler is not None:
        self.maybe_print("Loading lr scheduler.")
        self.lr_scheduler.load_state_dict(data["lr_scheduler"])
    if self.scaler is not None:
        self.maybe_print("Loading gradient scaler.")
        self.scaler.load_state_dict(data["scaler"])

    self.step_count = data["checkpoint_step"] + 1
    self.num_resumes = data["resumes"] + 1

    # restore random number generator states for all devices
    self.maybe_print("Setting RNG states.")
    for dev, rng_state in data["rng_states"].items():
        # Both `set_rng_state` variants expect a CPU `torch.ByteTensor`; a
        # `map_location` passed through `kwargs` may have moved the saved
        # states to another device, so bring them back first (no-op on CPU).
        rng_state = rng_state.cpu()
        if "cuda" in dev:
            cuda.set_rng_state(rng_state, dev)
        else:
            set_rng_state(rng_state)

    # N.B. We return the checkpoint step index of the saved file that was loaded,
    # but the checkpointer.step_count is one larger than that because we increment
    # it after saving - it tracks the index of the next checkpoint to be saved.
    # A checkpoint without an `extra_info` entry yields `{}`, as documented.
    return data["checkpoint_step"], data.get("extra_info", {})

step

step(extra_info: Optional[Dict] = None)

Perform a checkpointing step.

Save the checkpoint. If the run was marked as pre-empted, requeue the job and exit the training script after saving.

Parameters:

  • extra_info (Optional[Dict], default: None ) –

    Additional information to save in the checkpoint. This dictionary is returned when loading the latest checkpoint with Checkpointer.load_latest_checkpoint. By default, an empty dictionary is saved.

Source code in wandb_preempt/checkpointer.py
def step(self, extra_info: Optional[Dict] = None):
    """Save a checkpoint and shut down if a pre-emption signal was received.

    The new checkpoint is written first, then stale checkpoints are removed
    while the latest one is kept. If the run has been marked as pre-empted,
    `preempt_wandb_run` and `maybe_requeue_slurm_job` are called and the
    process exits with status 1. Otherwise the step counter is advanced.

    Args:
        extra_info: Additional information to save in the checkpoint. This
            dictionary is returned when loading the latest checkpoint with
            [`Checkpointer.load_latest_checkpoint`](../api/#wandb_preempt.Checkpointer.load_latest_checkpoint).
            By default, an empty dictionary is saved.
    """
    payload = extra_info if extra_info is not None else {}
    self.save_checkpoint(payload)
    # Drop stale checkpoints, keeping only the latest one.
    self.remove_checkpoints(keep_latest=True)

    if not self.marked_preempted:
        # `step_count` tracks the index the next checkpoint will be saved under.
        self.step_count += 1
        return

    # A signal arrived during this step: hand over and terminate.
    self.maybe_print("Run was marked as pre-empted via signal.")
    self.preempt_wandb_run()
    self.maybe_requeue_slurm_job()
    self.maybe_print("Exiting with error code 1.")
    exit(1)

remove_checkpoints

remove_checkpoints(keep_latest: bool = False)

Remove checkpoints.

Parameters:

  • keep_latest (bool, default: False ) –

    Whether to keep the latest checkpoint. Default: False.

Raises:

  • RuntimeError

    If a non-.pt file is found in the checkpoint directory.

Source code in wandb_preempt/checkpointer.py
def remove_checkpoints(self, keep_latest: bool = False):
    """Remove checkpoints.

    Every candidate path is checked before any file is deleted. A stray
    non-`.pt` path therefore aborts the call without leaving the directory
    half-cleaned.

    Args:
        keep_latest: Whether to keep the latest checkpoint. Default: `False`.

    Raises:
        RuntimeError: If a non-`.pt` file is found in the checkpoint directory.
            In that case no checkpoint is removed.
    """
    checkpoints = list(
        self.old_checkpoints() if keep_latest else self.all_checkpoints()
    )
    # Validate up front: the guard must fire before the first `remove`,
    # not after some checkpoints have already been deleted.
    for checkpoint in checkpoints:
        if not checkpoint.endswith(".pt"):
            raise RuntimeError(f"Was asked to delete a non-.pt-file: {checkpoint}.")
    for checkpoint in checkpoints:
        self.maybe_print(f"Removing checkpoint {checkpoint}.")
        remove(checkpoint)