Source code for torchkit.layers

from typing import Union

import torch
import torch.nn as nn
import torch.nn.functional as F

ConvType = Union[torch.nn.modules.conv.Conv2d, torch.nn.modules.conv.Conv3d]
Tensor = torch.Tensor


def _conv(
    dim: int,
    in_channels: int,
    out_channels: int,
    kernel_size: int = 3,
    stride: int = 1,
    dilation: int = 1,
    bias: bool = True,
) -> ConvType:
    """Build a 2D or 3D convolution padded for a `same`-style output.

    Each side is padded by half of the dilated filter's extent (rounded
    down). The output spatial shape therefore equals the input shape when
    `stride` is 1 and the dilated filter extent is odd.

    Args:
        dim: The dimension of the convolution: 2 is conv2d, 3 is conv3d.
        in_channels: The number of input feature maps.
        out_channels: The number of output feature maps.
        kernel_size: The filter size.
        stride: The filter stride.
        dilation: The filter dilation factor.
        bias: Whether to add a bias.

    Returns:
        An `nn.Conv2d` when `dim` is 2, otherwise an `nn.Conv3d`.
    """
    assert dim in [2, 3], "[!] Only 2D and 3D convolution supported."

    # A dilated filter spans `dilation * (kernel_size - 1) + 1` input
    # positions; padding by half of (span - 1) centres it on each output.
    span = dilation * (kernel_size - 1) + 1
    pad = (span - 1) // 2

    if dim == 2:
        layer_cls = nn.Conv2d
    else:
        layer_cls = nn.Conv3d

    return layer_cls(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=pad,
        dilation=dilation,
        bias=bias,
    )


def conv2d(*args, **kwargs) -> torch.nn.modules.conv.Conv2d:
    """Create a `same` 2D convolution, i.e. output shape equals input shape.

    Every argument is forwarded unchanged to `_conv` with `dim` fixed to 2.

    Args:
        in_channels: The number of input feature maps.
        out_channels: The number of output feature maps.
        kernel_size: The filter size.
        stride: The filter stride.
        dilation: The filter dilation factor.
        bias: Whether to add a bias.
    """
    layer = _conv(2, *args, **kwargs)
    return layer
def conv3d(*args, **kwargs) -> torch.nn.modules.conv.Conv3d:
    """Create a `same` 3D convolution, i.e. output shape equals input shape.

    Every argument is forwarded unchanged to `_conv` with `dim` fixed to 3.

    Args:
        in_channels: The number of input feature maps.
        out_channels: The number of output feature maps.
        kernel_size: The filter size.
        stride: The filter stride.
        dilation: The filter dilation factor.
        bias: Whether to add a bias.
    """
    layer = _conv(3, *args, **kwargs)
    return layer
class Flatten(nn.Module):
    """Flattens convolutional feature maps for fully-connected layers.

    This is a convenience module meant to be plugged into a
    `torch.nn.Sequential` model.

    Example usage::

        import torch.nn as nn
        from torchkit import layers

        # Assume an input of shape (3, 28, 28).
        net = nn.Sequential(
            layers.conv2d(3, 8, kernel_size=3),
            nn.ReLU(),
            layers.conv2d(8, 16, kernel_size=3),
            nn.ReLU(),
            layers.Flatten(),
            nn.Linear(28*28*16, 256),
            nn.ReLU(),
            nn.Linear(256, 2),
        )
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
        """Collapse every dimension after the first into a single one.

        Args:
            x: A tensor whose first dimension is kept as-is.

        Returns:
            A tensor of shape `(x.shape[0], -1)`.
        """
        # `reshape` rather than `view`: `view` raises on non-contiguous inputs
        # (e.g. after a transpose/permute), whereas `reshape` returns a view
        # when possible and copies only when it has to.
        return x.reshape(x.shape[0], -1)
class SpatialSoftArgmax(nn.Module):
    """Spatial softmax as defined in `1`_.

    Concretely, the spatial softmax of each feature map is used to compute a
    weighted mean of the pixel locations, effectively performing a soft
    arg-max over the spatial dimensions of every feature map.

    .. _1: https://arxiv.org/abs/1504.00702
    """

    def __init__(self, normalize: bool = False):
        """Constructor.

        Args:
            normalize: Whether to use normalized image coordinates, i.e.
                coordinates in the range `[-1, 1]`.
        """
        super().__init__()

        self.normalize = normalize

    def _coord_grid(
        self,
        h: int,
        w: int,
        device: torch.device,
    ) -> Tensor:
        """Return the x and y coordinate of every pixel of an `(h, w)` map.

        Args:
            h: The height (number of rows) of the feature map.
            w: The width (number of columns) of the feature map.
            device: The device on which to create the grid.

        Returns:
            A tensor of shape `(2, h, w)`. Index 0 holds the x (column)
            coordinate of each pixel and index 1 the y (row) coordinate.
        """
        if self.normalize:
            xs = torch.linspace(-1, 1, w, device=device)
            ys = torch.linspace(-1, 1, h, device=device)
        else:
            xs = torch.arange(0, w, device=device)
            ys = torch.arange(0, h, device=device)

        # The grid must be laid out as (h, w) so that flattening it lines up
        # element-for-element with the row-major flattening of the feature
        # maps in `forward`. A (w, h) layout pairs pixels with the wrong
        # coordinates whenever h != w and swaps x and y when h == w.
        xc = xs.view(1, w).expand(h, w)
        yc = ys.view(h, 1).expand(h, w)
        return torch.stack([xc, yc])

    def forward(self, x: Tensor) -> Tensor:
        """Compute the expected (x, y) location of every feature map.

        Args:
            x: A tensor of shape `(B, C, H, W)`.

        Returns:
            A tensor of shape `(B, C*2)` laid out per channel as
            `[x_0, y_0, x_1, y_1, ...]`, where x is the expected column
            coordinate and y the expected row coordinate.
        """
        assert x.ndim == 4, "Expecting a tensor of shape (B, C, H, W)."

        # Compute a spatial softmax over the input:
        # Given an input of shape (B, C, H, W), reshape it to (B*C, H*W) then
        # apply the softmax operator over the last dimension. `reshape` is
        # used so that non-contiguous inputs are accepted as well.
        _, c, h, w = x.shape
        softmax = F.softmax(x.reshape(-1, h * w), dim=-1)

        # Per-pixel coordinates, each of shape (H, W).
        xc, yc = self._coord_grid(h, w, x.device)

        # Element-wise multiply the x and y coordinates with the softmax, then
        # sum over the h*w dimension. This effectively computes the weighted
        # mean x and y locations.
        x_mean = (softmax * xc.flatten()).sum(dim=1, keepdim=True)
        y_mean = (softmax * yc.flatten()).sum(dim=1, keepdim=True)

        # Concatenate and reshape the result to (B, C*2) where for every
        # feature we have the expected x and y pixel locations.
        return torch.cat([x_mean, y_mean], dim=1).view(-1, c * 2)
class _GlobalMaxPool(nn.Module):
    """Global max pooling layer."""

    def __init__(self, dim):
        """Constructor.

        Args:
            dim: The number of pooled dimensions: 1, 2 or 3.

        Raises:
            ValueError: If `dim` is not 1, 2 or 3.
        """
        super().__init__()

        if dim == 1:
            self._pool = F.max_pool1d
        elif dim == 2:
            self._pool = F.max_pool2d
        elif dim == 3:
            self._pool = F.max_pool3d
        else:
            # Interpolate the offending value; the message used to be emitted
            # with a literal, unformatted "{}" placeholder.
            raise ValueError(f"{dim}D is not supported.")

    def forward(self, x: Tensor) -> Tensor:
        """Take the maximum over every dimension after the first two.

        Args:
            x: The input tensor. Dimensions from index 2 onwards are used as
                the pooling window.

        Returns:
            The pooled tensor with the pooled dimensions removed.
        """
        # A window as large as the trailing dimensions reduces each of them
        # to size 1.
        out = self._pool(x, kernel_size=x.size()[2:])
        # Drop the now size-1 trailing dimensions without mutating in place.
        return out.flatten(start_dim=1)


class GlobalMaxPool1d(_GlobalMaxPool):
    """Global max pooling operation for temporal or 1D data."""

    def __init__(self):
        super().__init__(dim=1)


class GlobalMaxPool2d(_GlobalMaxPool):
    """Global max pooling operation for spatial or 2D data."""

    def __init__(self):
        super().__init__(dim=2)


class GlobalMaxPool3d(_GlobalMaxPool):
    """Global max pooling operation for 3D data."""

    def __init__(self):
        super().__init__(dim=3)
class _GlobalAvgPool(nn.Module):
    """Global average pooling layer."""

    def __init__(self, dim):
        """Constructor.

        Args:
            dim: The number of pooled dimensions: 1, 2 or 3.

        Raises:
            ValueError: If `dim` is not 1, 2 or 3.
        """
        super().__init__()

        if dim == 1:
            self._pool = F.avg_pool1d
        elif dim == 2:
            self._pool = F.avg_pool2d
        elif dim == 3:
            self._pool = F.avg_pool3d
        else:
            # Interpolate the offending value; the message used to be emitted
            # with a literal, unformatted "{}" placeholder.
            raise ValueError(f"{dim}D is not supported.")

    def forward(self, x: Tensor) -> Tensor:
        """Average over every dimension after the first two.

        Args:
            x: The input tensor. Dimensions from index 2 onwards are used as
                the pooling window.

        Returns:
            The pooled tensor with the pooled dimensions removed.
        """
        # A window as large as the trailing dimensions reduces each of them
        # to size 1.
        out = self._pool(x, kernel_size=x.size()[2:])
        # Drop the now size-1 trailing dimensions without mutating in place.
        return out.flatten(start_dim=1)


class GlobalAvgPool1d(_GlobalAvgPool):
    """Global average pooling operation for temporal or 1D data."""

    def __init__(self):
        super().__init__(dim=1)


class GlobalAvgPool2d(_GlobalAvgPool):
    """Global average pooling operation for spatial or 2D data."""

    def __init__(self):
        super().__init__(dim=2)


class GlobalAvgPool3d(_GlobalAvgPool):
    """Global average pooling operation for 3D data."""

    def __init__(self):
        super().__init__(dim=3)
class CausalConv1d(nn.Conv1d):
    """A causal a.k.a. masked 1D convolution.

    The input is zero-padded by `(kernel_size - 1) * dilation` and the output
    is trimmed so that the value at output index `i` only depends on inputs
    at positions `<= i * stride`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        bias: bool = True,
    ):
        """Constructor.

        Args:
            in_channels: The number of input channels.
            out_channels: The number of output channels.
            kernel_size: The filter size.
            stride: The filter stride.
            dilation: The filter dilation factor.
            bias: Whether to add the bias term or not.
        """
        super().__init__(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=(kernel_size - 1) * dilation,
            dilation=dilation,
            bias=bias,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Apply the convolution and keep only the causal outputs.

        Args:
            x: The input tensor; its last dimension is the one convolved over.

        Returns:
            The first `(L - 1) // stride + 1` outputs along the last
            dimension, where `L` is the input length.
        """
        res = super().forward(x)
        # `nn.Conv1d` pads both ends, so the trailing outputs overlap the
        # right-hand padding and must be dropped. Output `i` ends at input
        # position `i * stride`, hence exactly `(L - 1) // stride + 1` outputs
        # are causal. Trimming a fixed `padding` count is only right for
        # stride 1; with larger strides it discards valid outputs.
        num_causal = (x.shape[-1] - 1) // self.stride[0] + 1
        return res[..., :num_causal]