Walkthrough
Overview
This section explains how to create and launch a pre-emptable wandb sweep on a SLURM cluster. You need to set up three files:
- A training script (e.g. train.py) that the sweep will execute multiple times with different hyper-parameters.
- A wandb sweep configuration file (e.g. sweep.yaml) that defines the hyper-parameter search space.
- A SLURM launch script (e.g. launch.sh) which we will use to submit runs to the cluster.
The repository's example directory contains examples for each of these files, and here we will demonstrate how to make them work together. We will operate inside the example directory, so let's navigate to it:
```bash
# If you haven't already, clone the repository first
git clone git@github.com:f-dangel/wandb_preempt.git && cd wandb_preempt
# Install the package from the repository, including the extra dependencies of the example
pip install -e ".[example]"
# Then navigate to the example directory within the repo
cd example
```
Training Script
First up, we need to write a training script that we will sweep over. The sweep will call this script with different hyper-parameters to find the ones that work best.
For demonstration purposes, we will train a small CNN on MNIST using SGD, and our goal is to find a good learning rate through random search using a wandb sweep. To keep things simple and cheap, we fix the batch size and use a (very) small number of epochs. We also use a learning rate scheduler and mixed-precision training with a gradient scaler. These are overkill for MNIST, of course, but they matter when training large models, so we include them here to show that they are checkpointed too. In summary, the sweep will call the training script using the following pattern:
```bash
python train.py --epochs=20 --lr_max=X
```
where X is a floating-point learning rate that the sweep searches over.
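To make the calling convention concrete, here is a minimal sketch of how a sampled set of hyper-parameters turns into a command line. The wandb agent does this itself; build_command is a hypothetical helper, and the --key=value style with alphabetically sorted keys is an assumption that happens to match the agent log in the local-run section (python train.py --epochs=20 --lr_max=0.0020...).

```python
def build_command(config: dict, program: str = "train.py") -> list[str]:
    """Hypothetical helper: turn a sampled sweep config into a CLI call.

    Mirrors the '--key=value' style seen in the agent log; this is not wandb's code.
    """
    # Sort keys so the argument order is deterministic (epochs before lr_max).
    args = [f"--{key}={value}" for key, value in sorted(config.items())]
    return ["python", program, *args]


if __name__ == "__main__":
    cmd = build_command({"lr_max": 0.002046505897436452, "epochs": 20})
    print(" ".join(cmd))
```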
The training script also contains code to checkpoint the training loop at the end of each epoch, so we can resume training after our pre-emptable job is interrupted. Roughly speaking, we create a Checkpointer that is responsible for saving and loading checkpoints, listening for signals from SLURM, and requeuing the job on the cluster if it runs out of time before training has finished. Every time we call the checkpointer's .step method, it saves a checkpoint and checks whether SLURM has sent a signal indicating that our job is about to be killed; if so, it halts training pre-emptively and requeues the job.
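The pattern behind checkpointer.step() can be sketched with Python's standard signal module: the signal handler only records that SIGUSR1 arrived, and the training loop checks that record right after saving each checkpoint. This is a simplified illustration, not the wandb_preempt implementation; PreemptionFlag and train_epochs are hypothetical names, and the pre-emption and requeue steps are only indicated by a comment.

```python
import signal
from typing import Callable


class PreemptionFlag:
    """Hypothetical helper that records whether SIGUSR1 has arrived."""

    def __init__(self, signum: int = signal.SIGUSR1) -> None:
        self.requested = False
        # Python runs this handler in the main thread, between bytecode instructions.
        signal.signal(signum, self._handle)

    def _handle(self, signum, frame) -> None:
        # Only set a flag here; the real work happens at the next checkpoint.
        self.requested = True


def train_epochs(
    num_epochs: int, flag: PreemptionFlag, save_checkpoint: Callable[[int], None]
) -> int:
    """Train epoch by epoch, checkpoint after each one, and stop early if asked.

    Returns the index of the last completed epoch.
    """
    last_epoch = -1
    for epoch in range(num_epochs):
        # ... forward/backward passes for one epoch would go here ...
        save_checkpoint(epoch)  # a fresh checkpoint exists before we check the flag
        last_epoch = epoch
        if flag.requested:
            # The real Checkpointer would now pre-empt the wandb run,
            # requeue the SLURM job, and exit the training script.
            break
    return last_epoch
```

Deferring the real work to the checkpoint boundary means the job only stops right after a complete checkpoint has been written.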
For more details, see the training script below.
Details of the training script example/train.py
```python
#!/usr/bin/env python
"""Train a simple CNN on MNIST using checkpoints, integrated with Weights & Biases.

The changes required to integrate checkpointing with wandb are tagged with 'NOTE'.
"""

from argparse import ArgumentParser

import wandb
from torch import autocast, bfloat16, cuda, device, manual_seed
from torch.cuda.amp import GradScaler
from torch.nn import Conv2d, CrossEntropyLoss, Flatten, Linear, ReLU, Sequential
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

from wandb_preempt.checkpointer import Checkpointer

LOGGING_INTERVAL = 50  # Num batches between logging to stdout and wandb
VERBOSE = True  # Enable verbose output


def get_parser():
    r"""Create argument parser."""
    parser = ArgumentParser("Train a simple CNN on MNIST using SGD.")
    parser.add_argument(
        "--lr_max", type=float, default=0.01, help="Learning rate. Default: %(default)s"
    )
    parser.add_argument(
        "--epochs", type=int, default=20, help="Number of epochs. Default: %(default)s"
    )
    parser.add_argument(
        "--batch_size", type=int, default=256, help="Batch size. Default: %(default)s"
    )
    parser.add_argument(
        "--checkpoint_dir", type=str, default="checkpoints", help="Checkpoint save dir."
    )
    return parser


def main(args):
    r"""Train model."""
    manual_seed(0)  # make deterministic
    DEV = device("cuda" if cuda.is_available() else "cpu")

    # NOTE: Allow runs to resume by passing 'allow' to wandb
    run = wandb.init(resume="allow")

    # Set up the data, neural net, loss function, and optimizer
    train_dataset = MNIST("./data", train=True, download=True, transform=ToTensor())
    train_loader = DataLoader(
        dataset=train_dataset, batch_size=args.batch_size, shuffle=True
    )
    model = Sequential(
        Conv2d(1, 3, kernel_size=5, stride=2),
        ReLU(),
        Flatten(),
        Linear(432, 50),
        ReLU(),
        Linear(50, 10),
    ).to(DEV)
    loss_func = CrossEntropyLoss().to(DEV)
    print(f"Using SGD with learning rate {args.lr_max}.")
    optimizer = SGD(model.parameters(), lr=args.lr_max)
    lr_scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs)
    scaler = GradScaler()

    # NOTE: Set up a check-pointer which will load and save checkpoints.
    # Pass the run ID to obtain unique file names for the checkpoints.
    checkpointer = Checkpointer(
        run.id,
        model,
        optimizer,
        lr_scheduler=lr_scheduler,
        scaler=scaler,
        savedir=args.checkpoint_dir,
        verbose=VERBOSE,
    )

    # NOTE: If existing, load model, optimizer, and learning rate scheduler state from
    # latest checkpoint, set random number generator states. If there was no checkpoint
    # to load, it does nothing and returns `None` for the step count.
    checkpoint_index, _ = checkpointer.load_latest_checkpoint()

    # Select the remaining epochs to train
    start_epoch = 0 if checkpoint_index is None else checkpoint_index + 1

    # training
    for epoch in range(start_epoch, args.epochs):
        model.train()
        for step, (inputs, target) in enumerate(train_loader):
            optimizer.zero_grad()

            with autocast(device_type="cuda", dtype=bfloat16):
                output = model(inputs.to(DEV))
                loss = loss_func(output, target.to(DEV))

            if step % LOGGING_INTERVAL == 0:
                print(f"Epoch {epoch}, Step {step}, Loss {loss.item():.5e}")
                wandb.log(
                    {
                        "loss": loss.item(),
                        "lr": optimizer.param_groups[0]["lr"],
                        "loss_scale": scaler.get_scale(),
                        "epoch": epoch,
                        "resumes": checkpointer.num_resumes,
                    }
                )

            scaler.scale(loss).backward()
            scaler.step(optimizer)  # update neural network parameters
            scaler.update()  # update the gradient scaler

        lr_scheduler.step()  # update learning rate

        # NOTE Put validation code here
        # eval(model, ...)

        # NOTE Call checkpointer.step() at the end of the epoch to save a
        # checkpoint. If SLURM sent us a signal that our time for this job is
        # running out, it will now also take care of pre-empting the wandb job
        # and requeuing the SLURM job, killing the current python training script
        # to resume with the requeued job.
        checkpointer.step()

    wandb.finish()
    # NOTE Remove all created checkpoints once we are done training. If you want to
    # keep the trained model, remove this line.
    checkpointer.remove_checkpoints()


if __name__ == "__main__":
    # Run as a script
    parser = get_parser()
    args = parser.parse_args()
    main(args)
```
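The resume rule in main can be illustrated in isolation. The training output in the local-run section shows checkpoint files named like 2zoz0rl8_00000019.pt, i.e. the run ID followed by a zero-padded, 8-digit epoch index. Assuming that naming scheme (inferred from the log, not from the wandb_preempt source), the sketch below finds the newest index among a list of file names and applies the same start-epoch rule as train.py. latest_index and start_epoch are hypothetical helpers; the real load_latest_checkpoint also restores the training state described in the comments of train.py.

```python
import re
from typing import Iterable, Optional


def latest_index(filenames: Iterable[str], run_id: str) -> Optional[int]:
    """Return the highest checkpoint index saved for run_id, or None if there is none.

    Assumes file names like '<run_id>_00000019.pt' (8-digit, zero-padded index).
    """
    pattern = re.compile(rf"^{re.escape(run_id)}_(\d{{8}})\.pt$")
    indices = [int(m.group(1)) for name in filenames if (m := pattern.match(name))]
    return max(indices, default=None)


def start_epoch(checkpoint_index: Optional[int]) -> int:
    """Same rule as train.py: start fresh, or continue after the last saved epoch."""
    return 0 if checkpoint_index is None else checkpoint_index + 1
```

With 20 epochs, a run whose newest checkpoint has index 19 gets a start epoch of 20, so range(20, 20) is empty and the training loop is skipped entirely.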
Sweep Configuration
Our next goal will be to define and create a sweep.
For that, we need to write a .yaml file which specifies how the training script is called, and what the search space looks like. To learn more, take a look at the Weights & Biases documentation.
The following configuration file defines a random search over the learning rate, using a log-uniform density for the search space.
By default, the example config will create a new project called example-preemptable-sweep owned by your default wandb entity (controlled by your Default team setting).
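To make "random search with a log-uniform density" concrete, the sketch below writes a sweep configuration as a Python dict and draws learning rates uniformly in log space. The dict only mirrors the general shape of a wandb sweep config (wandb also accepts such a dict via wandb.sweep); the bounds 1e-4 and 1e-1 and the log_uniform_values distribution are assumptions, not the contents of example/sweep.yaml, and the sampler is our own illustration rather than wandb's code.

```python
import math
import random

# Hypothetical sweep configuration written as a Python dict.
# The learning-rate bounds are assumptions; check example/sweep.yaml for the real ones.
SWEEP_CONFIG = {
    "program": "train.py",
    "method": "random",
    "project": "example-preemptable-sweep",
    "parameters": {
        "epochs": {"value": 20},
        "lr_max": {"distribution": "log_uniform_values", "min": 1e-4, "max": 1e-1},
    },
}


def sample_log_uniform(low: float, high: float, rng: random.Random) -> float:
    """Draw x such that log(x) is uniform on [log(low), log(high)]."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


def sample_config(config: dict, rng: random.Random) -> dict:
    """Sample one set of hyper-parameters (handles only the two kinds used above)."""
    sampled = {}
    for name, spec in config["parameters"].items():
        if "value" in spec:
            sampled[name] = spec["value"]  # fixed parameter
        else:
            sampled[name] = sample_log_uniform(spec["min"], spec["max"], rng)
    return sampled


if __name__ == "__main__":
    rng = random.Random(0)
    for _ in range(3):
        print(sample_config(SWEEP_CONFIG, rng))
```

Sampling in log space spends as many trials on [1e-4, 1e-3] as on [1e-2, 1e-1], which suits learning rates whose useful values span orders of magnitude.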
Details of the sweep configuration example/sweep.yaml
Let's create a sweep using this configuration:
```bash
wandb sweep sweep.yaml
```
Each sweep has its own ID (you can have more than one sweep in the same project), so to launch jobs for the sweep we have just created, we need to use the correct sweep ID. To do this, note the command in the last line of the output, which has the form wandb agent <entity>/<project>/<sweep ID>, and copy it for later use.
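If you prefer to script this step instead of copying by hand, a small helper can pull the sweep path (entity/project/sweep ID) out of that line and build the agent command with --count=1. This is a hypothetical convenience that assumes only that the line contains "wandb agent" followed by a three-part path; it does not depend on any other detail of wandb's output format.

```python
import re

# Matches 'wandb agent <entity>/<project>/<sweep ID>' anywhere in a line.
_AGENT_PATH = re.compile(r"wandb agent ([^\s/]+/[^\s/]+/[^\s/]+)")


def agent_command(output_line: str, count: int = 1) -> str:
    """Hypothetical helper: turn the copied line into a 'wandb agent --count=N' call."""
    match = _AGENT_PATH.search(output_line)
    if match is None:
        raise ValueError(f"no 'wandb agent <path>' found in: {output_line!r}")
    return f"wandb agent --count={count} {match.group(1)}"
```

For the sweep used in this walkthrough, the result is wandb agent --count=1 f-dangel-team/example-preemptable-sweep/4m89qo6r, the same command the launch script runs.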
Open the wandb web interface, where you should see that the sweep now exists:

Optional Step: Local Run
To make sure the configuration file works, I will execute a single run locally on my machine as a sanity check. This step is optional, and not recommended if your machine's hardware is not powerful enough. Note that I use the command from above, but add the --count=1 flag to carry out only a single run:
```bash
wandb agent --count=1 f-dangel-team/example-preemptable-sweep/4m89qo6r
```
Training script output
wandb: Starting wandb agent 🕵️
2024-09-11 15:50:13,390 - wandb.wandb_agent - INFO - Running runs: []
2024-09-11 15:50:13,596 - wandb.wandb_agent - INFO - Agent received command: run
2024-09-11 15:50:13,596 - wandb.wandb_agent - INFO - Agent starting run with config:
epochs: 20
lr_max: 0.002046505897436452
2024-09-11 15:50:13,597 - wandb.wandb_agent - INFO - About to run command: /usr/bin/env python train.py --epochs=20 --lr_max=0.002046505897436452
wandb: Currently logged in as: f-dangel (f-dangel-team). Use `wandb login --relogin` to force relogin
2024-09-11 15:50:18,609 - wandb.wandb_agent - INFO - Running runs: ['2zoz0rl8']
wandb: Tracking run with wandb version 0.17.9
wandb: Run data is saved locally in ~/wandb_preempt/example/wandb/run-20240911_155016-2zoz0rl8
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run unique-sweep-1
wandb: ⭐️ View project at https://wandb.ai/f-dangel-team/example-preemptable-sweep
wandb: 🧹 View sweep at https://wandb.ai/f-dangel-team/example-preemptable-sweep/sweeps/4m89qo6r
wandb: 🚀 View run at https://wandb.ai/f-dangel-team/example-preemptable-sweep/runs/2zoz0rl8
Using SGD with learning rate 0.002046505897436452.
[0.0 s | 2024-09-11 15:50:19.844072] Creating checkpoint directory: ~/wandb_preempt/example/checkpoints.
[0.0 s | 2024-09-11 15:50:19.844166] SLURM job ID: None, array ID: None, task ID: None
[0.0 s | 2024-09-11 15:50:19.844522] No checkpoint found. Starting from scratch.
Epoch 0, Step 0, Loss 2.31327e+00
Epoch 0, Step 50, Loss 2.31634e+00
Epoch 0, Step 100, Loss 2.30524e+00
Epoch 0, Step 150, Loss 2.29960e+00
Epoch 0, Step 200, Loss 2.30427e+00
[3.0 s | 2024-09-11 15:50:22.801220] Saving checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000000.pt.
Epoch 1, Step 0, Loss 2.29656e+00
Epoch 1, Step 50, Loss 2.30689e+00
Epoch 1, Step 100, Loss 2.29662e+00
Epoch 1, Step 150, Loss 2.28267e+00
Epoch 1, Step 200, Loss 2.30082e+00
[5.7 s | 2024-09-11 15:50:25.576881] Saving checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000001.pt.
[5.7 s | 2024-09-11 15:50:25.578131] Removing checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000000.pt.
Epoch 2, Step 0, Loss 2.28091e+00
Epoch 2, Step 50, Loss 2.27995e+00
Epoch 2, Step 100, Loss 2.27886e+00
Epoch 2, Step 150, Loss 2.27613e+00
Epoch 2, Step 200, Loss 2.27744e+00
[8.5 s | 2024-09-11 15:50:28.328779] Saving checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000002.pt.
[8.5 s | 2024-09-11 15:50:28.329942] Removing checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000001.pt.
Epoch 3, Step 0, Loss 2.27019e+00
Epoch 3, Step 50, Loss 2.27712e+00
Epoch 3, Step 100, Loss 2.26028e+00
Epoch 3, Step 150, Loss 2.25132e+00
Epoch 3, Step 200, Loss 2.25152e+00
[11.2 s | 2024-09-11 15:50:31.057194] Saving checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000003.pt.
[11.2 s | 2024-09-11 15:50:31.058385] Removing checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000002.pt.
Epoch 4, Step 0, Loss 2.23886e+00
Epoch 4, Step 50, Loss 2.24897e+00
Epoch 4, Step 100, Loss 2.23878e+00
Epoch 4, Step 150, Loss 2.21464e+00
Epoch 4, Step 200, Loss 2.21080e+00
[14.0 s | 2024-09-11 15:50:33.822246] Saving checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000004.pt.
[14.0 s | 2024-09-11 15:50:33.823408] Removing checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000003.pt.
Epoch 5, Step 0, Loss 2.19485e+00
Epoch 5, Step 50, Loss 2.19484e+00
Epoch 5, Step 100, Loss 2.16891e+00
Epoch 5, Step 150, Loss 2.16754e+00
Epoch 5, Step 200, Loss 2.13477e+00
[16.7 s | 2024-09-11 15:50:36.518798] Saving checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000005.pt.
[16.7 s | 2024-09-11 15:50:36.519988] Removing checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000004.pt.
Epoch 6, Step 0, Loss 2.12859e+00
Epoch 6, Step 50, Loss 2.10682e+00
Epoch 6, Step 100, Loss 2.09931e+00
Epoch 6, Step 150, Loss 2.08149e+00
Epoch 6, Step 200, Loss 2.04833e+00
[19.4 s | 2024-09-11 15:50:39.239762] Saving checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000006.pt.
[19.4 s | 2024-09-11 15:50:39.240944] Removing checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000005.pt.
Epoch 7, Step 0, Loss 2.02058e+00
Epoch 7, Step 50, Loss 1.97293e+00
Epoch 7, Step 100, Loss 1.94745e+00
Epoch 7, Step 150, Loss 1.90756e+00
Epoch 7, Step 200, Loss 1.89235e+00
[22.1 s | 2024-09-11 15:50:41.981751] Saving checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000007.pt.
[22.1 s | 2024-09-11 15:50:41.983563] Removing checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000006.pt.
Epoch 8, Step 0, Loss 1.82919e+00
Epoch 8, Step 50, Loss 1.80327e+00
Epoch 8, Step 100, Loss 1.74424e+00
Epoch 8, Step 150, Loss 1.68607e+00
Epoch 8, Step 200, Loss 1.67496e+00
[25.0 s | 2024-09-11 15:50:44.813117] Saving checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000008.pt.
[25.0 s | 2024-09-11 15:50:44.814475] Removing checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000007.pt.
Epoch 9, Step 0, Loss 1.62010e+00
Epoch 9, Step 50, Loss 1.56824e+00
Epoch 9, Step 100, Loss 1.50516e+00
Epoch 9, Step 150, Loss 1.48588e+00
Epoch 9, Step 200, Loss 1.44233e+00
[27.8 s | 2024-09-11 15:50:47.612270] Saving checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000009.pt.
[27.8 s | 2024-09-11 15:50:47.613580] Removing checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000008.pt.
Epoch 10, Step 0, Loss 1.37147e+00
Epoch 10, Step 50, Loss 1.34652e+00
Epoch 10, Step 100, Loss 1.31548e+00
Epoch 10, Step 150, Loss 1.31214e+00
Epoch 10, Step 200, Loss 1.31763e+00
[30.5 s | 2024-09-11 15:50:50.342514] Saving checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000010.pt.
[30.5 s | 2024-09-11 15:50:50.343644] Removing checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000009.pt.
Epoch 11, Step 0, Loss 1.18105e+00
Epoch 11, Step 50, Loss 1.18585e+00
Epoch 11, Step 100, Loss 1.10869e+00
Epoch 11, Step 150, Loss 1.08874e+00
Epoch 11, Step 200, Loss 1.06454e+00
[33.4 s | 2024-09-11 15:50:53.222490] Saving checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000011.pt.
[33.4 s | 2024-09-11 15:50:53.223652] Removing checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000010.pt.
Epoch 12, Step 0, Loss 1.13357e+00
Epoch 12, Step 50, Loss 9.96835e-01
Epoch 12, Step 100, Loss 1.06371e+00
Epoch 12, Step 150, Loss 9.63902e-01
Epoch 12, Step 200, Loss 9.63633e-01
[36.1 s | 2024-09-11 15:50:55.968639] Saving checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000012.pt.
[36.1 s | 2024-09-11 15:50:55.970050] Removing checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000011.pt.
Epoch 13, Step 0, Loss 9.34712e-01
Epoch 13, Step 50, Loss 8.95310e-01
Epoch 13, Step 100, Loss 9.12703e-01
Epoch 13, Step 150, Loss 9.39363e-01
Epoch 13, Step 200, Loss 9.21194e-01
[38.9 s | 2024-09-11 15:50:58.722370] Saving checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000013.pt.
[38.9 s | 2024-09-11 15:50:58.723519] Removing checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000012.pt.
Epoch 14, Step 0, Loss 8.49049e-01
Epoch 14, Step 50, Loss 9.19110e-01
Epoch 14, Step 100, Loss 8.95127e-01
Epoch 14, Step 150, Loss 8.37601e-01
Epoch 14, Step 200, Loss 9.13763e-01
[41.6 s | 2024-09-11 15:51:01.492518] Saving checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000014.pt.
[41.6 s | 2024-09-11 15:51:01.493670] Removing checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000013.pt.
Epoch 15, Step 0, Loss 8.37812e-01
Epoch 15, Step 50, Loss 9.73370e-01
Epoch 15, Step 100, Loss 7.91447e-01
Epoch 15, Step 150, Loss 8.27363e-01
Epoch 15, Step 200, Loss 8.46579e-01
[44.4 s | 2024-09-11 15:51:04.212638] Saving checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000015.pt.
[44.4 s | 2024-09-11 15:51:04.213888] Removing checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000014.pt.
Epoch 16, Step 0, Loss 8.59434e-01
Epoch 16, Step 50, Loss 9.20763e-01
Epoch 16, Step 100, Loss 7.62155e-01
Epoch 16, Step 150, Loss 7.71248e-01
Epoch 16, Step 200, Loss 8.11831e-01
[47.1 s | 2024-09-11 15:51:06.972560] Saving checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000016.pt.
[47.1 s | 2024-09-11 15:51:06.973763] Removing checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000015.pt.
Epoch 17, Step 0, Loss 8.56807e-01
Epoch 17, Step 50, Loss 8.06021e-01
Epoch 17, Step 100, Loss 8.36283e-01
Epoch 17, Step 150, Loss 7.88259e-01
Epoch 17, Step 200, Loss 8.26321e-01
[49.9 s | 2024-09-11 15:51:09.694183] Saving checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000017.pt.
[49.9 s | 2024-09-11 15:51:09.695488] Removing checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000016.pt.
Epoch 18, Step 0, Loss 7.45168e-01
Epoch 18, Step 50, Loss 7.74083e-01
Epoch 18, Step 100, Loss 8.39497e-01
Epoch 18, Step 150, Loss 7.77645e-01
Epoch 18, Step 200, Loss 8.34373e-01
[52.6 s | 2024-09-11 15:51:12.442971] Saving checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000018.pt.
[52.6 s | 2024-09-11 15:51:12.444134] Removing checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000017.pt.
Epoch 19, Step 0, Loss 7.31586e-01
Epoch 19, Step 50, Loss 7.85134e-01
Epoch 19, Step 100, Loss 7.31892e-01
Epoch 19, Step 150, Loss 7.79394e-01
Epoch 19, Step 200, Loss 7.37044e-01
[55.3 s | 2024-09-11 15:51:15.190555] Saving checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000019.pt.
[55.3 s | 2024-09-11 15:51:15.191727] Removing checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000018.pt.
wandb:
wandb:
wandb: Run history:
wandb: [sparkline plots for epoch, loss, loss_scale, lr, and resumes not recoverable]
wandb:
wandb: Run summary:
wandb: epoch 19
wandb: loss 0.73704
wandb: loss_scale 1.0
wandb: lr 1e-05
wandb: resumes 0
wandb:
wandb: 🚀 View run unique-sweep-1 at: https://wandb.ai/f-dangel-team/example-preemptable-sweep/runs/2zoz0rl8
wandb: ⭐️ View project at: https://wandb.ai/f-dangel-team/example-preemptable-sweep
wandb: Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: Find logs at: ./wandb/run-20240911_155016-2zoz0rl8/logs
wandb: WARNING The new W&B backend becomes opt-out in version 0.18.0; try it out with `wandb.require("core")`! See https://wandb.me/wandb-core for more information.
[60.5 s | 2024-09-11 15:51:20.373116] Removing checkpoint ~/wandb_preempt/example/checkpoints/2024-09-11/2zoz0rl8_00000019.pt.
2024-09-11 15:51:24,319 - wandb.wandb_agent - INFO - Cleaning up finished run: 2zoz0rl8
wandb: Terminating and syncing runs. Press ctrl-c to kill.
On the Weights & Biases web interface, we can see the successfully finished run:

SLURM Launcher
The last step is to launch multiple jobs on a SLURM cluster.
For that, we use the following launch script. If you are running this yourself, you will need to replace the wandb agent command with the one you copied when the sweep was created. Note that we still include the --count=1 argument to ensure that each SLURM job in the array completes a single run from the sweep.
We will explain the script in more detail below; for now our focus is to launch jobs.
Details of the SLURM script example/launch.sh
```bash
#!/bin/bash
#SBATCH --partition=a40
#SBATCH --nodes=1
#SBATCH --tasks-per-node=1
#SBATCH --gres=gpu:1
#SBATCH --cpus-per-gpu=4
#SBATCH --mem-per-gpu=16G
#SBATCH --qos=m5
#SBATCH --open-mode=append
#SBATCH --time=00:04:00
#SBATCH --array=0-9
#SBATCH --signal=B:SIGUSR1@120 # Send signal SIGUSR1 120 seconds before the job hits the time limit

echo "Job $SLURM_JOB_NAME ($SLURM_JOB_ID) begins on $(hostname), submitted from $SLURM_SUBMIT_HOST ($SLURM_CLUSTER_NAME)"
echo ""

# Wait for a task-specific time to avoid simultaneous API requests from multiple agents
if [ "$SLURM_ARRAY_TASK_COUNT" != "" ]; then
    sleep $((5 * (SLURM_ARRAY_TASK_ID - SLURM_ARRAY_TASK_MIN)))
fi

# NOTE that we need to use srun here, otherwise the Python process won't receive the SIGUSR1 signal
srun wandb agent --count=1 f-dangel-team/example-preemptable-sweep/4m89qo6r &
child="$!"

# Set up a handler to pass the SIGUSR1 to the python session launched by the agent
function term_handler()
{
    echo "$(date) ** Job $SLURM_JOB_NAME ($SLURM_JOB_ID) received SIGUSR1 **"
    # The Checkpointer will have written the PID of the Python process to a file
    # so we can send it the SIGUSR1 signal
    PID=$(cat "${SLURM_JOB_ID}.pid")
    echo "$(date) ** Sending kill signal to python process $PID **"

    # Send the signal multiple times because it may not be caught if the Python
    # process happens to be in the middle of writing a checkpoint. The while loop
    # exits when `kill` errors, which happens when the python process has exited.
    while kill -SIGUSR1 "$PID" 2>/dev/null
    do
        echo "$(date) Sent SIGUSR1 signal to python"
        sleep 10
    done
}

# Call this term_handler function when the job receives the SIGUSR1 or SIGTERM signal
# SIGUSR1 is sent by SLURM 120s before the time limit, thanks to the SBATCH --signal=...
# setting in the header.
# SIGTERM is sent shortly* before the job is killed (*with interval between the signal
# and being properly killed depending on SLURM cluster's `GraceTime` value)
trap term_handler SIGUSR1
trap term_handler SIGTERM # NOTE we trap SIGTERM but send SIGUSR1 to the Python process

# The srun command is running in the background, and we need to wait for it to finish.
# The wait command here is in the foreground and so it will be interrupted by the trap
# handler when we receive a SIGUSR1 or SIGTERM signal.
wait "$child"

# Clean up the pid file
rm "${SLURM_JOB_ID}.pid"
echo "$(date) Reached EOF"
```
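The forwarding loop in term_handler (keep sending SIGUSR1 until kill fails because the process is gone) translates directly to Python. This is a sketch for illustration only; forward_until_exit is a hypothetical name, and the launch script itself does this in bash.

```python
import os
import signal
import time


def forward_until_exit(pid: int, signum: int = signal.SIGUSR1, interval_s: float = 10.0) -> int:
    """Send signum to pid repeatedly until the process no longer exists.

    Mirrors the bash loop: while kill -SIGUSR1 "$PID"; do sleep 10; done
    Returns the number of signals that were delivered.
    """
    delivered = 0
    while True:
        try:
            os.kill(pid, signum)
        except ProcessLookupError:
            # Same as 'kill' failing in bash: the process has exited.
            return delivered
        delivered += 1
        # Re-send later in case the signal arrived while a checkpoint was being written.
        time.sleep(interval_s)
```

As with kill in bash, a process that has exited but not yet been reaped by its parent still counts as existing, so the loop ends only once the process has been fully cleaned up.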
Log into your SLURM cluster, navigate to the example directory, and submit the jobs to SLURM:
```bash
sbatch launch.sh
```
Use watch squeue --me to monitor the job queue. You will observe that the jobs launch and run for a short amount of time before receiving the pre-emption signal from SLURM. They then requeue themselves and resume from the latest checkpoint, repeating this until training has completely finished.
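Requeueing is handled by the Checkpointer. As a rough illustration only, assuming it goes through SLURM's scontrol requeue command (the actual wandb_preempt mechanism may differ), the job to requeue can be identified from SLURM's environment variables, the same job, array, and task IDs the Checkpointer prints at start-up. requeue_command is a hypothetical helper that only builds the command.

```python
from typing import Mapping


def requeue_command(env: Mapping[str, str]) -> list[str]:
    """Hypothetical helper: build an 'scontrol requeue' call for the current job.

    Array tasks are addressed as '<SLURM_ARRAY_JOB_ID>_<SLURM_ARRAY_TASK_ID>',
    a job ID form scontrol accepts; other jobs use SLURM_JOB_ID.
    """
    array_job_id = env.get("SLURM_ARRAY_JOB_ID")
    task_id = env.get("SLURM_ARRAY_TASK_ID")
    if array_job_id is not None and task_id is not None:
        job = f"{array_job_id}_{task_id}"
    elif "SLURM_JOB_ID" in env:
        job = env["SLURM_JOB_ID"]
    else:
        # Outside SLURM (e.g. the local run) there is nothing to requeue.
        raise RuntimeError("not running inside a SLURM job")
    return ["scontrol", "requeue", job]
```

Inside a job you could pass os.environ and run the result with subprocess.run(cmd, check=True); outside SLURM the helper raises, matching the "SLURM job ID: None" line in the local run.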
On the Weights & Biases website, you will see the runs transitioning between the states 'Running', 'Preempted', and 'Finished'. Here is an example view:

After full completion, the 'Workspace' tab on Weights & Biases looks as follows (your curves will look slightly different because of the stochastic nature of the sweep and of the compute environment):

The resume panel shows that different runs were pre-empted a different number of times.
In total, we now have 1 (local) + 10 (SLURM) = 11 (total) finished runs.
Conclusion
And that is pretty much it. Feel free to stop reading at this point.
SLURM Launcher Details
The launch script consists of three parts:
- SLURM configuration: The first block (the lines starting with #SBATCH) specifies the SLURM resources and the task array we are about to submit. The important settings are --time (how long each job may run before SLURM stops it), --array (how many jobs are submitted), and --signal (how long before the time limit pre-emption starts). These values are chosen for demonstration purposes; you will definitely want to tweak them for your use case.
  Note that the --time request does not need to cover the total time the model takes to train, since the checkpointer automatically requeues the job when the time limit is about to be reached while the training script is still running.
  The --signal argument is used to tell the Python script that the SLURM job is about to end (using the signal SIGUSR1). The B: prefix makes SLURM send this signal only to the batch shell, which is why the script forwards it to the Python process itself. The time in the --signal argument should be set to (at least) the time in seconds between calls to checkpointer.step() in your training script. The example script checkpoints once per epoch, so this lead time must be at least as long as it takes to complete one training epoch and save the checkpoint.
  The other configuration flags are resource-specific, and some, like --partition=a40, depend on your cluster. In our script, we request NVIDIA A40 GPUs because they support mixed-precision training with bfloat16, which is used by many modern training pipelines.
- Printing details and wandb agent launch: This part prints information about the job, waits a task-specific delay so that the agents do not send API requests simultaneously, then executes wandb agent on our sweep and puts it into the background so that the launch script can listen for signals from SLURM.
- Installing a trap handler: This function processes the signal sent by SLURM and passes it on to the Python process, which then initiates pre-emption and requeueing. The same handler is also installed for SIGTERM, which SLURM sends shortly before killing the job.
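A quick way to sanity-check the --time and --signal values is to compute the training window they leave. The sketch below parses the HH:MM:SS form of --time (other SLURM time formats are not handled), subtracts the --signal lead time, and applies the rule of thumb from the text that the lead time must cover one epoch plus a checkpoint save. The helper names are hypothetical, and the 30-second epoch and 5-second checkpoint times are made-up numbers for illustration.

```python
import math


def parse_hms(time_limit: str) -> int:
    """Convert a SLURM time limit of the form 'HH:MM:SS' to seconds.

    Other SLURM formats (e.g. 'D-HH:MM:SS' or plain minutes) are not handled.
    """
    hours, minutes, seconds = (int(part) for part in time_limit.split(":"))
    return 3600 * hours + 60 * minutes + seconds


def epochs_before_signal(time_limit: str, signal_lead_s: int, epoch_s: float) -> int:
    """Full epochs that fit before SLURM sends the warning signal.

    Ignores start-up overhead such as the staggered sleep and data loading.
    """
    window_s = parse_hms(time_limit) - signal_lead_s
    return max(0, math.floor(window_s / epoch_s))


def lead_time_ok(signal_lead_s: int, epoch_s: float, checkpoint_s: float = 0.0) -> bool:
    """The lead time should cover one epoch plus saving a checkpoint."""
    return signal_lead_s >= epoch_s + checkpoint_s


if __name__ == "__main__":
    # Values from launch.sh (--time=00:04:00, --signal=B:SIGUSR1@120), made-up timings.
    print(epochs_before_signal("00:04:00", 120, 30.0), lead_time_ok(120, 30.0, 5.0))
```

With these assumed timings, each job trains about four epochs before the signal arrives, and the 120-second lead time comfortably covers one more epoch plus a checkpoint.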