torchkit.checkpoint
- class torchkit.checkpoint.Checkpoint(**kwargs)[source]
Save and restore PyTorch objects implementing a state_dict method.
- Return type
None
- __init__(**kwargs)[source]
Constructor.
Accepts keyword arguments whose values are objects exposing a state_dict attribute, which allows them to be serialized to disk.
- Parameters
kwargs – Keyword arguments are set as attributes of this object, and are saved with the checkpoint. Values must have a state_dict attribute.
- Raises
ValueError – If objects in kwargs do not have a state_dict attribute.
- Return type
None
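The "checkpointable" contract above (keyword arguments whose values expose a state_dict attribute) can be sketched in plain Python. This is illustrative only: Counter and validate_checkpointables are hypothetical names, not part of torchkit, and the check merely mimics the ValueError behavior described for the constructor.

```python
class Counter:
    """A minimal checkpointable: implements state_dict/load_state_dict."""

    def __init__(self):
        self.step = 0

    def state_dict(self):
        # Return a serializable snapshot of this object's state.
        return {"step": self.step}

    def load_state_dict(self, state):
        # Restore state from a previously saved snapshot.
        self.step = state["step"]


def validate_checkpointables(**kwargs):
    """Hypothetical sketch of the ValueError check described above."""
    for name, obj in kwargs.items():
        if not hasattr(obj, "state_dict"):
            raise ValueError(f"{name!r} does not have a state_dict attribute.")
    return kwargs
```

Any object following this contract (including torch.nn.Module and torch.optim.Optimizer instances) would pass the check; plain values like integers would not.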
- class torchkit.checkpoint.CheckpointManager(directory, max_to_keep=10, **checkpointables)[source]
Periodically save PyTorch checkpointables (any object that implements a state_dict method) to disk and restore them to resume training.
Note: This is a re-implementation of [2].
Example usage:
```python
from torchkit.checkpoint import CheckpointManager

# Create a checkpoint manager instance.
checkpoint_manager = CheckpointManager(
    checkpoint_dir,
    model=model,
    optimizer=optimizer,
)

# Restore the latest checkpoint if it exists.
global_step = checkpoint_manager.restore_or_initialize()
for global_step in range(global_step, 1000):
    # Forward pass + loss computation.

    # Save a checkpoint every N iters.
    if not global_step % N:
        checkpoint_manager.save(global_step)
```
- Parameters
directory (str) –
max_to_keep (int) –
checkpointables (Any) –
- Return type
None
- __init__(directory, max_to_keep=10, **checkpointables)[source]
Constructor.
- Parameters
directory (str) – The directory in which checkpoints will be saved.
max_to_keep (int) – The maximum number of checkpoints to keep. Amongst all saved checkpoints, checkpoints will be deleted oldest first, until max_to_keep remain.
checkpointables (Any) – Keyword args with checkpointable PyTorch objects.
- Return type
None
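The oldest-first retention policy described for max_to_keep can be sketched as follows. prune_oldest is a hypothetical helper, not torchkit's actual implementation; it only illustrates which checkpoints survive.

```python
def prune_oldest(saved_steps, max_to_keep):
    """Return (kept, deleted) checkpoint steps, deleting oldest first.

    saved_steps is assumed to be ordered from oldest to newest save.
    """
    # Number of checkpoints that exceed the retention budget.
    n_delete = max(0, len(saved_steps) - max_to_keep)
    return list(saved_steps[n_delete:]), list(saved_steps[:n_delete])
```

For example, with max_to_keep=2 and four saved steps, the two oldest checkpoints are deleted and the two newest remain.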
- static list_checkpoints(directory)[source]
List all checkpoints in a checkpoint directory.
- Return type
List[Path]
- Parameters
directory (Union[pathlib.Path, str]) – The directory to search for checkpoints.
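A sketch of what enumerating a checkpoint directory might look like, assuming checkpoints are named "<global_step>.ckpt". The naming scheme and the function body are assumptions for illustration, not taken from the torchkit source.

```python
from pathlib import Path


def list_checkpoints(directory):
    """Return checkpoint paths sorted by ascending global step."""
    directory = Path(directory)
    # Sort numerically on the filename stem, assumed to be the global step.
    return sorted(directory.glob("*.ckpt"), key=lambda p: int(p.stem))
```

Sorting numerically on the parsed step (rather than lexicographically on the filename) keeps, e.g., step 1000 after step 200.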
- load_checkpoint_at(global_step)[source]
Load a checkpoint at a given global step.
- Return type
None
- Parameters
global_step (int) – The global step of the checkpoint to load.
- restore_or_initialize()[source]
Restore items in checkpoint from the latest checkpoint file.
- Return type
int
- Returns
The global iteration step. This is parsed from the latest checkpoint file if one is found, else 0 is returned.
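The restore-or-initialize behavior described above can be sketched like this: parse the global step out of the latest checkpoint filename if one exists, else return 0. As before, the "<global_step>.ckpt" naming scheme is an assumption for illustration, not torchkit's documented format.

```python
from pathlib import Path


def restore_or_initialize(directory):
    """Return the global step of the latest checkpoint, or 0 if none exist."""
    ckpts = sorted(Path(directory).glob("*.ckpt"), key=lambda p: int(p.stem))
    if not ckpts:
        # No checkpoint found: start training from scratch.
        return 0
    # The checkpoint with the highest step is the latest one.
    return int(ckpts[-1].stem)
```

A training loop can then resume from the returned step, as in the usage example above.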
- save(global_step)[source]
Create a new checkpoint.
- Parameters
global_step (int) – The iteration number which will be used to name the checkpoint.
- Return type
None
- property latest_checkpoint: Optional[pathlib.Path]
Get the last saved checkpoint.
- Return type
Optional[Path]