import os.path as osp
from typing import Type, Union, cast
import numpy as np
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
# Shorthand alias for torch's tensor type, used throughout this module.
Tensor = torch.Tensor
# Images may be given as torch Tensors (channel-first, CHW) or numpy
# arrays (channel-last, HWC) — see the `log_image` docstring below.
ImageType = Union[Tensor, np.ndarray]
class Logger:
    """A Tensorboard-based logger."""

    def __init__(self, log_dir: str, force_write: bool = False) -> None:
        """Constructor.

        Args:
            log_dir: The directory in which to store Tensorboard logs.
            force_write: Whether to force write to an already existing log dir.
                Set to `True` if resuming training.

        Raises:
            ValueError: If `log_dir` already exists and `force_write` is False.
        """
        # Refuse to clobber an existing run's logs unless explicitly resuming.
        if osp.exists(log_dir) and not force_write:
            raise ValueError(
                "You might be overwriting a directory that already "
                "has train_logs. Please provide a new experiment name "
                "or set --resume to True when launching train script."
            )
        self._writer = SummaryWriter(log_dir)

    def close(self) -> None:
        """Close the underlying summary writer."""
        self._writer.close()

    def flush(self) -> None:
        """Flush pending events to disk."""
        self._writer.flush()

    def log_scalar(
        self,
        scalar: Union[torch.Tensor, float],
        global_step: int,
        name: str,
        prefix: str = "",
    ) -> None:
        """Log a scalar value.

        Args:
            scalar: A scalar `torch.Tensor` or float.
            global_step: The training iteration step.
            name: The name of the logged scalar.
            prefix: A prefix to prepend to the logged scalar.

        Raises:
            ValueError: If `scalar` is a tensor that is not scalar-valued, or
                is otherwise not a scalar.
        """
        if isinstance(scalar, torch.Tensor):
            # Accept 0-D tensors and 1-D tensors holding exactly one element;
            # this is equivalent to the original ndim/shape checks.
            if scalar.ndim > 1 or scalar.numel() != 1:
                raise ValueError("Tensor must be scalar-valued.")
            scalar = scalar.item()
        # Raise (rather than assert) so validation survives `python -O`.
        if not np.isscalar(scalar):
            raise ValueError("Not a scalar.")
        tag = "/".join([prefix, name]) if prefix else name
        self._writer.add_scalar(tag, scalar, global_step)

    def log_image(
        self,
        image: ImageType,
        global_step: int,
        name: str,
        prefix: str = "",
        nrow: int = 5,
    ) -> None:
        """Log an image or batch of images.

        Args:
            image: A numpy ndarray or a torch Tensor. If the image is 4D (i.e.
                batched), it will be converted to a 3D image using make_grid.
                The numpy array should be in channel-last format while the torch
                Tensor should be in channel-first format.
            global_step: The training iteration step.
            name: The name of the logged image(s).
            prefix: A prefix to prepend to the logged image(s).
            nrow: The number of images displayed in each row of the grid if the
                input image is 4D.

        Raises:
            ValueError: If `image` is not 3D or 4D.
        """
        tag = "/".join([prefix, name]) if prefix else name
        # Raise (rather than assert) so validation survives `python -O`.
        if image.ndim not in (3, 4):
            raise ValueError("Must be an image or batch of images.")
        if image.ndim == 4:
            if isinstance(image, np.ndarray):
                # (N, H, W, C) -> (N, C, H, W).
                image = torch.from_numpy(image).permute(0, 3, 1, 2)
            # Collapse the batch into a single grid image.
            image = torchvision.utils.make_grid(image, nrow=nrow)
        elif isinstance(image, np.ndarray):
            # (H, W, C) -> (C, H, W).
            image = torch.from_numpy(image).permute(2, 0, 1)
        self._writer.add_image(tag, image, global_step, dataformats="CHW")

    def log_video(
        self,
        video: Union[torch.Tensor, np.ndarray],
        global_step: int,
        name: str,
        prefix: str = "",
        fps: int = 4,
    ) -> None:
        """Log a sequence of images or a batch of sequence of images.

        Args:
            video: A torch Tensor or numpy ndarray. The numpy array should be in
                channel-last format while the torch Tensor should be in
                channel-first format. Should be either a single sequence of
                images of shape (T, CHW/HWC) or a batch of sequences of shape
                (B, T, CHW/HWC). The batch of sequences will get converted to
                one grid sequence of images.
            global_step: The training iteration step.
            name: The name of the logged video(s).
            prefix: A prefix to prepend to the logged video(s).
            fps: The frames per second.

        Raises:
            ValueError: If `video` is not 4D or 5D.
            TypeError: If the channel dimension is not where the input type
                requires it to be.
        """
        # Build the tag the same way as the other log_* methods so that an
        # empty prefix does not produce a leading "/" in the tag.
        parts = [prefix, "image", name] if prefix else ["image", name]
        tag = "/".join(parts)
        if video.ndim not in (4, 5):
            raise ValueError("Must be a video or batch of videos.")
        if video.ndim == 4:
            if isinstance(video, np.ndarray):
                if video.shape[-1] != 3:
                    raise TypeError("Numpy array should have THWC format.")
                # (T, H, W, C) -> (T, C, H, W).
                video = torch.from_numpy(video).permute(0, 3, 1, 2)
            elif isinstance(video, torch.Tensor):
                if video.shape[1] != 3:
                    raise TypeError("Torch tensor should have TCHW format.")
            # `add_video` expects a batch dimension: (T, C, H, W) ->
            # (1, T, C, H, W). This must happen for numpy inputs too.
            video = video.unsqueeze(0)
        else:
            if isinstance(video, np.ndarray):
                if video.shape[-1] != 3:
                    raise TypeError("Numpy array should have BTHWC format.")
                # (B, T, H, W, C) -> (B, T, C, H, W).
                video = torch.from_numpy(video).permute(0, 1, 4, 2, 3)
            elif isinstance(video, torch.Tensor):
                if video.shape[2] != 3:
                    raise TypeError("Torch tensor should have BTCHW format.")
        self._writer.add_video(tag, video, global_step, fps=fps)

    def log_learning_rate(
        self,
        optimizer: torch.optim.Optimizer,
        global_step: int,
        prefix: str = "",
    ) -> None:
        """Log the learning rate of every parameter group.

        Args:
            optimizer: An optimizer instance (not a class — the `isinstance`
                check below requires an instance).
            global_step: The training iteration step.
            prefix: A prefix to prepend to the logged scalar.

        Raises:
            TypeError: If `optimizer` is not a `torch.optim.Optimizer`.
        """
        if not isinstance(optimizer, torch.optim.Optimizer):
            raise TypeError("Optimizer must be an instance of torch.optim.Optimizer.")
        # NOTE(review): every param group is logged under the same tag, so
        # with multiple groups later values overwrite earlier ones — confirm
        # whether per-group tags are wanted.
        for param_group in optimizer.param_groups:
            self.log_scalar(param_group["lr"], global_step, "learning_rate", prefix)