torchkit.checkpoint

class torchkit.checkpoint.Checkpoint(**kwargs)[source]

Save and restore PyTorch objects implementing a state_dict method.

__init__(**kwargs)[source]

Constructor.

Accepts keyword arguments whose values are objects that have a state_dict attribute and can therefore be serialized to disk.

Parameters

kwargs – Keyword arguments are set as attributes of this object, and are saved with the checkpoint. Values must have a state_dict attribute.

Raises

ValueError – If objects in kwargs do not have a state_dict attribute.

Return type

None
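
For example, a minimal sketch (the model and optimizer here are stand-ins for illustration):

import torch
import torch.nn as nn

from torchkit.checkpoint import Checkpoint

# Any object exposing a state_dict can be checkpointed.
model = nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters())

# Keyword arguments become attributes of the checkpoint object.
checkpoint = Checkpoint(model=model, optimizer=optimizer)

# Values without a state_dict raise a ValueError, e.g.:
# Checkpoint(step=5)  # ValueError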

restore(save_path)[source]

Restore a state from a saved checkpoint.

Parameters

save_path (Union[str, Path]) – The filepath to the saved checkpoint.

Return type

bool

Returns

True if the restore was successful, either fully or partially (i.e., only some of the checkpointables could be restored), and False otherwise.
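
For example, a sketch of handling the return value (assumes the checkpoint object from above; the path is illustrative and is assumed to have been written by a previous save() call):

if checkpoint.restore("/tmp/ckpts/model.ckpt"):
    print("Checkpoint restored (possibly only partially).")
else:
    print("No usable checkpoint; starting from scratch.")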

save(save_path)[source]

Save a state to disk.

Modified from brentyi/fannypack.

Parameters

save_path (Path) – The path at which to save the checkpoint.

Return type

None
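
A round-trip sketch (the path is illustrative; checkpoint as constructed above):

from pathlib import Path

save_path = Path("/tmp/ckpts/model.ckpt")
save_path.parent.mkdir(parents=True, exist_ok=True)

# Serialize all registered state_dicts to disk, then read them back.
checkpoint.save(save_path)
ok = checkpoint.restore(save_path)  # True on full or partial success.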

class torchkit.checkpoint.CheckpointManager(directory, max_to_keep=10, **checkpointables)[source]

Periodically save PyTorch checkpointables (any object that implements a state_dict method) to disk and restore them to resume training.

Note: This is a re-implementation of [2].

Example usage:

from torchkit.checkpoint import CheckpointManager

# Create a checkpoint manager instance.
checkpoint_manager = CheckpointManager(
    checkpoint_dir,
    model=model,
    optimizer=optimizer,
)

# Restore the latest checkpoint if it exists and resume from there.
global_step = checkpoint_manager.restore_or_initialize()
for global_step in range(global_step, 1000):
    # forward pass + loss computation

    # Save a checkpoint every N iters.
    if not global_step % N:
        checkpoint_manager.save(global_step)

__init__(directory, max_to_keep=10, **checkpointables)[source]

Constructor.

Parameters
  • directory (str) – The directory in which checkpoints will be saved.

  • max_to_keep (int) – The maximum number of checkpoints to keep. Amongst all saved checkpoints, checkpoints will be deleted oldest first, until max_to_keep remain.

  • checkpointables (Any) – Keyword args with checkpointable PyTorch objects.

Return type

None

static list_checkpoints(directory)[source]

List all checkpoints in a checkpoint directory.

Parameters

directory (Union[pathlib.Path, str]) – The directory in which to search for checkpoints.

Return type

List[Path]
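
For example (the directory is illustrative):

from torchkit.checkpoint import CheckpointManager

for path in CheckpointManager.list_checkpoints("/tmp/ckpts"):
    print(path.name)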

load_checkpoint_at(global_step)[source]

Load a checkpoint at a given global step.

Parameters

global_step (int) – The global step of the checkpoint to load.

Return type

None
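
For example (a sketch; assumes a checkpoint was previously saved at global step 500 by the manager constructed in the example above):

checkpoint_manager.load_checkpoint_at(500)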

load_latest_checkpoint()[source]

Load the last saved checkpoint.

Return type

None

restore_or_initialize()[source]

Restore the registered checkpointables from the latest checkpoint file.

Return type

int

Returns

The global iteration step. This is parsed from the latest checkpoint file if one is found, else 0 is returned.

save(global_step)[source]

Create a new checkpoint.

Parameters

global_step (int) – The iteration number which will be used to name the checkpoint.

Return type

None

property latest_checkpoint: Optional[pathlib.Path]

The path to the last saved checkpoint, or None if no checkpoint has been saved yet.

Return type

Optional[Path]
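
For example (a sketch; manager as constructed in the example above):

ckpt = checkpoint_manager.latest_checkpoint
if ckpt is None:
    print("No checkpoints have been saved yet.")
else:
    print(f"Latest checkpoint: {ckpt}")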