Source code for torchkit.checkpoint

import logging
import os
import signal
from pathlib import Path
from typing import Any, List, Union

import torch

from .experiment import unique_id


def get_files(
    d: Path,
    pattern: str,
    sort_lexicographical: bool = False,
    sort_numerical: bool = False,
) -> List[Path]:
    """Return a list of files in a given directory.

    Args:
        d: The path to the directory.
        pattern: The wildcard to filter files with.
        sort_lexicographical: Sort the files lexicographically by filename stem.
        sort_numerical: Sort the files numerically by filename stem; assumes
            each stem parses as an integer (e.g. "500.ckpt").
    """
    files = d.glob(pattern)
    if sort_lexicographical:
        return sorted(files, key=lambda x: x.stem)
    if sort_numerical:
        return sorted(files, key=lambda x: int(x.stem))
    return list(files)
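
# Illustrative sketch (not part of the library): `get_files` with
# `sort_numerical=True` is how `CheckpointManager.list_checkpoints` (below)
# orders checkpoints named "<step>.ckpt" by global step. The directory here
# is hypothetical:
#
#   ckpts = get_files(Path("/tmp/exp"), "*.ckpt", sort_numerical=True)
#   # -> e.g. [Path("/tmp/exp/500.ckpt"), Path("/tmp/exp/1000.ckpt")]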


class Checkpoint:
    """Save and restore PyTorch objects implementing a `state_dict` method."""

    def __init__(self, **kwargs) -> None:
        """Constructor.

        Accepts keyword arguments whose values are objects that expose a
        `state_dict` method and can thus be serialized to disk.

        Args:
            kwargs: Keyword arguments are set as attributes of this object,
                and are saved with the checkpoint. Values must have a
                `state_dict` attribute.

        Raises:
            ValueError: If objects in `kwargs` do not have a `state_dict`
                attribute.
        """
        for k, v in sorted(kwargs.items()):
            if not hasattr(v, "state_dict"):
                raise ValueError(f"{k} does not have a state_dict attribute.")
            setattr(self, k, v)

    def save(self, save_path: Path) -> None:
        """Save a state to disk.

        Modified from brentyi/fannypack.

        Args:
            save_path: The name of the checkpoint to save.
        """
        # Ignore ctrl+c while saving.
        try:
            orig_handler = signal.getsignal(signal.SIGINT)
            signal.signal(signal.SIGINT, lambda _sig, _frame: None)
        except ValueError:
            # Signal throws a ValueError if we're not in the main thread.
            orig_handler = None

        # Create a snapshot of the current state.
        save_dict = dict()
        for k, v in self.__dict__.items():
            save_dict[k] = v.state_dict()

        # Save to a temporary file in the target directory, then atomically
        # rename it over the final path.
        # `rename` is POSIX-compliant and thus is an atomic operation.
        # Ref: https://docs.python.org/3/library/os.html#os.rename
        tmp_path = save_path.parent / f"tmp-{unique_id()}.ckpt"
        torch.save(save_dict, tmp_path)
        os.rename(tmp_path, save_path)

        # Restore SIGINT handler.
        if orig_handler is not None:
            signal.signal(signal.SIGINT, orig_handler)

    def restore(self, save_path: Union[str, Path]) -> bool:
        """Restore a state from a saved checkpoint.

        Args:
            save_path: The filepath to the saved checkpoint.

        Returns:
            True if restoring was fully or partially successful (i.e., at least
            some of the checkpointables were restored), False otherwise.
        """
        try:
            state = torch.load(Path(save_path), map_location="cpu")
            for name, state_dict in state.items():
                if not hasattr(self, name):
                    logging.warning(
                        f"Skipping {name}: present in the saved checkpoint but "
                        "not in the checkpoint to reload."
                    )
                    continue
                getattr(self, name).load_state_dict(state_dict)
            return True
        except Exception as e:
            logging.error(e)
            return False
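
# Illustrative sketch (not part of the library): a save/restore round trip
# using `Checkpoint` directly. The model, optimizer, and path are hypothetical;
# any keyword arguments exposing `state_dict`/`load_state_dict` work the same:
#
#   ckpt = Checkpoint(model=model, optimizer=optimizer)
#   ckpt.save(Path("/tmp/exp/0.ckpt"))    # atomic write via os.rename
#   ok = ckpt.restore("/tmp/exp/0.ckpt")  # returns False if loading fails
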
# TODO(kevin): Add saving of best checkpoint based on specified metric.
class CheckpointManager:
    """Periodically save PyTorch checkpointables (any object that implements a
    `state_dict` method) to disk and restore them to resume training.

    Note: This is a re-implementation of TensorFlow's `tf.train.CheckpointManager`_.

    Example usage::

        from torchkit.checkpoint import CheckpointManager

        # Create a checkpoint manager instance.
        checkpoint_manager = CheckpointManager(
            checkpoint_dir,
            model=model,
            optimizer=optimizer,
        )

        # Restore the latest checkpoint if it exists.
        global_step = checkpoint_manager.restore_or_initialize()

        for global_step in range(1000):
            # Forward pass + loss computation.

            # Save a checkpoint every N iterations.
            if not global_step % N:
                checkpoint_manager.save(global_step)

    .. _tf.train.CheckpointManager: https://www.tensorflow.org/api_docs/python/tf/train/CheckpointManager/
    """

    def __init__(
        self,
        directory: str,
        max_to_keep: int = 10,
        **checkpointables: Any,
    ) -> None:
        """Constructor.

        Args:
            directory: The directory in which checkpoints will be saved.
            max_to_keep: The maximum number of checkpoints to keep. Amongst all
                saved checkpoints, checkpoints will be deleted oldest first,
                until `max_to_keep` remain.
            checkpointables: Keyword args with checkpointable PyTorch objects.
        """
        assert max_to_keep > 0, "max_to_keep should be a positive integer."

        self.directory = Path(directory).absolute()
        self.max_to_keep = max_to_keep
        self.checkpoint = Checkpoint(**checkpointables)

        # Create checkpoint directory if it doesn't already exist.
        self.directory.mkdir(parents=True, exist_ok=True)

    def restore_or_initialize(self) -> int:
        """Restore items in checkpoint from the latest checkpoint file.

        Returns:
            The global iteration step. This is parsed from the latest
            checkpoint file if one is found, else 0 is returned.
        """
        ckpts = CheckpointManager.list_checkpoints(self.directory)
        if not ckpts:
            return 0
        last_ckpt = ckpts[-1]
        status = self.checkpoint.restore(last_ckpt)
        if not status:
            logging.info("Could not restore latest checkpoint file.")
            return 0
        return int(last_ckpt.stem)

    def save(self, global_step: int) -> None:
        """Create a new checkpoint.

        Args:
            global_step: The iteration number which will be used to name the
                checkpoint.
        """
        save_path = self.directory / f"{global_step}.ckpt"
        self.checkpoint.save(save_path)
        self._trim_checkpoints()

    def _trim_checkpoints(self) -> None:
        """Trim older checkpoints until `max_to_keep` remain."""
        # Get a list of checkpoints in reverse global_step order (newest first).
        ckpts = CheckpointManager.list_checkpoints(self.directory)[::-1]
        # Remove the oldest checkpoints until `max_to_keep` remain.
        while len(ckpts) - self.max_to_keep > 0:
            ckpts.pop().unlink()

    def load_latest_checkpoint(self) -> None:
        """Load the last saved checkpoint."""
        self.checkpoint.restore(self.latest_checkpoint)

    def load_checkpoint_at(self, global_step: int) -> None:
        """Load the checkpoint saved at a given global step."""
        ckpt_name = self.directory / f"{global_step}.ckpt"
        if ckpt_name not in CheckpointManager.list_checkpoints(self.directory):
            raise ValueError(f"No checkpoint found at step {global_step}.")
        self.checkpoint.restore(ckpt_name)

    @property
    def latest_checkpoint(self) -> Path:
        """The path to the last saved checkpoint.

        Raises:
            ValueError: If no checkpoints have been saved yet.
        """
        ckpts = CheckpointManager.list_checkpoints(self.directory)
        if not ckpts:
            raise ValueError(f"No checkpoints found in {self.directory}.")
        return ckpts[-1]

    @staticmethod
    def list_checkpoints(directory: Union[Path, str]) -> List[Path]:
        """List all checkpoints in a checkpoint directory."""
        return get_files(Path(directory), "*.ckpt", sort_numerical=True)
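

# Illustrative usage sketch (not part of the original module): wiring a toy
# model and optimizer into `CheckpointManager`, mirroring the class docstring
# above. The directory path, model, and step counts below are hypothetical.
if __name__ == "__main__":
    import torch.nn as nn
    import torch.optim as optim

    model = nn.Linear(4, 2)
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    manager = CheckpointManager(
        "/tmp/torchkit_checkpoint_demo",
        max_to_keep=3,
        model=model,
        optimizer=optimizer,
    )

    # Resume from the latest checkpoint if one exists, otherwise start at 0.
    start_step = manager.restore_or_initialize()
    for step in range(start_step, start_step + 100):
        # ... forward pass, loss computation, optimizer.step() go here ...
        if step % 25 == 0:
            manager.save(step)

    # Reload the most recent checkpoint, e.g. before evaluation.
    manager.load_latest_checkpoint()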