API Documentation
wandb_preempt.Checkpointer
Checkpointer(run_id: str, model: Module, optimizer: Union[Optimizer, None], lr_scheduler: Optional[LRScheduler] = None, scaler: Optional[GradScaler] = None, savedir: str = 'checkpoints', verbose: bool = False)
Class for storing, loading, and removing checkpoints.
The checkpointer can be marked as pre-empted by sending a SIGUSR1 signal to the Python session.
How to use this class:
- Create an instance of this class: checkpointer = Checkpointer(...).
- At the end of each epoch, call checkpointer.step() to save a checkpoint. If the job received the SIGUSR1 or SIGTERM signal, the checkpointer will requeue the Slurm job at the end of its checkpointing step.
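The signal-based marking described above can be illustrated with the standard library alone. The following is a minimal sketch of the general mechanism, not the library's implementation: PreemptionFlag is a hypothetical name, and the code assumes a POSIX system (where SIGUSR1 exists) and that it runs in the main thread.

```python
import os
import signal


class PreemptionFlag:
    """Hypothetical helper: remembers whether SIGUSR1 or SIGTERM arrived."""

    def __init__(self):
        self.marked = False
        # Install the same handler for both signals named in the text.
        for sig in (signal.SIGUSR1, signal.SIGTERM):
            signal.signal(sig, self._handle)

    def _handle(self, signum, frame):
        # Only record the event; the training loop reacts at a safe point,
        # for example after the current epoch's checkpoint has been written.
        self.marked = True


flag = PreemptionFlag()
# Simulate the scheduler's warning by sending SIGUSR1 to this process.
os.kill(os.getpid(), signal.SIGUSR1)
```

Recording the signal and acting on it later keeps the handler trivial, which matches the documented behaviour of requeueing only at the end of a checkpointing step.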
Set up a checkpointer.
Parameters:
- run_id (str) – A unique identifier for this run.
- model (Module) – The model that is trained and checkpointed.
- optimizer (Union[Optimizer, None]) – The optimizer that is used for training and should be checkpointed. Use None to explicitly ignore the optimizer. This can be useful if your optimizer does not implement .state_dict and .load_state_dict.
- lr_scheduler (Optional[LRScheduler], default: None) – The learning rate scheduler that is used for training. If None, no learning rate scheduler is assumed. Default: None.
- scaler (Optional[GradScaler], default: None) – The gradient scaler that is used when training in mixed precision. If None, no gradient scaler is assumed. Default: None.
- savedir (str, default: 'checkpoints') – Directory to store checkpoints in. Default: 'checkpoints'.
- verbose (bool, default: False) – Whether to print messages about saving and loading checkpoints. Default: False.
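The note on passing None for the optimizer can be turned into a small check. This is a sketch only: optimizer_or_none and the two stand-in classes are hypothetical and not part of wandb_preempt. It assumes an optimizer is checkpointable exactly when it exposes callable state_dict and load_state_dict attributes.

```python
def optimizer_or_none(optimizer):
    """Hypothetical helper: return `optimizer` if its state can be checkpointed.

    Falls back to None, the value the documentation reserves for explicitly
    ignoring the optimizer.
    """
    has_api = all(
        callable(getattr(optimizer, name, None))
        for name in ("state_dict", "load_state_dict")
    )
    return optimizer if has_api else None


class StatelessOptimizer:
    """Stand-in for an optimizer without state_dict/load_state_dict."""

    def step(self):
        pass


class StatefulOptimizer(StatelessOptimizer):
    """Stand-in for an optimizer that exposes both state methods."""

    def state_dict(self):
        return {}

    def load_state_dict(self, state):
        pass
```

The result could then be passed as the optimizer argument, for example Checkpointer(run_id, model, optimizer_or_none(opt)).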
Source code in wandb_preempt/checkpointer.py
load_latest_checkpoint
Load the latest checkpoint and set random number generator states.
Updates the model, optimizer, lr scheduler, and gradient scaler states passed at initialization.
Parameters:
- weights_only (bool, default: True) – Whether to only unpickle objects that are safe to unpickle. If True, the only types that will be loaded are tensors, primitive types, dictionaries and types added via torch.serialization.add_safe_globals(). See torch.load for more information. Default: True.
- **kwargs – Additional keyword arguments to pass to the torch.load function.
Returns:
- loaded_step (Union[int, None]) – The index of the checkpoint that was loaded, or None if no checkpoint was found.
- extra_info (Dict) – Extra information that was passed by the user to the step function when the checkpoint was saved, or an empty dictionary if there is no extra information.
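The two return values are typically used to decide where training should continue. The sketch below shows one way to do that; resume_point is a hypothetical helper, and it assumes that checkpoints are indexed by the epoch after which they were saved, counting from zero. Any object with a load_latest_checkpoint method returning the documented pair works, including a Checkpointer instance.

```python
def resume_point(checkpointer):
    """Hypothetical helper: return (first_epoch_to_run, extra_info).

    Assumes checkpoints are indexed by the epoch after which they were
    saved, counting from zero, so training continues at loaded_step + 1.
    """
    # Restores model/optimizer/scheduler/scaler state as a side effect.
    loaded_step, extra_info = checkpointer.load_latest_checkpoint()
    # None means no checkpoint was found, so training starts from scratch.
    start_epoch = 0 if loaded_step is None else loaded_step + 1
    return start_epoch, extra_info
```

If your checkpoints are indexed differently (for example by optimizer step rather than epoch), the offset needs to be adapted accordingly.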
step
Perform a checkpointing step.
Save the checkpoint. If we were pre-empted we requeue the job and exit the training script after saving.
Parameters:
- extra_info (Optional[Dict], default: None) – Additional information to save in the checkpoint. This dictionary is returned when loading the latest checkpoint with Checkpointer.load_latest_checkpoint. By default, an empty dictionary is saved.
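A training loop that checkpoints once per epoch might look like the sketch below. The function train and the callable train_one_epoch are hypothetical; only the step(extra_info=...) call comes from the documented interface. The keys stored in extra_info are illustrative.

```python
def train(checkpointer, num_epochs, train_one_epoch, start_epoch=0):
    """Run the remaining epochs and checkpoint after each one.

    `train_one_epoch` is a hypothetical callable that trains for one epoch
    and returns that epoch's loss as a float.
    """
    for epoch in range(start_epoch, num_epochs):
        loss = train_one_epoch(epoch)
        # Store small metadata next to the checkpoint. According to the
        # documentation, a pre-empted job is requeued and the script exits
        # after this save, so nothing below this call should be essential.
        checkpointer.step(extra_info={"epoch": epoch, "loss": loss})
```

Keeping the step call as the last statement of the loop body means that a requeue never interrupts a partially finished epoch.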
remove_checkpoints
Remove checkpoints.
Parameters:
- keep_latest (bool, default: False) – Whether to keep the latest checkpoint. Default: False.
Raises:
- RuntimeError – If a non-.pt file is found in the checkpoint directory.
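Because a stray file triggers the RuntimeError, it can help to inspect the directory before cleaning up. The sketch below is an assumption-laden illustration: non_checkpoint_files is a hypothetical helper, and it assumes you know the exact directory in which the checkpoints for your run are stored (the layout inside savedir is not specified here).

```python
from pathlib import Path


def non_checkpoint_files(checkpoint_dir):
    """Hypothetical helper: names of files in `checkpoint_dir` not ending in .pt."""
    directory = Path(checkpoint_dir)
    if not directory.is_dir():
        # Nothing to report if the directory has not been created yet.
        return []
    return sorted(
        p.name for p in directory.iterdir() if p.is_file() and p.suffix != ".pt"
    )
```

An empty result suggests that checkpointer.remove_checkpoints() should not fail for this reason; a non-empty result lists the files to move elsewhere first.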