# Source code for torchkit.utils.module_stats

import torch


# Reference: https://stackoverflow.com/a/62508086
# Reference: https://stackoverflow.com/a/62508086
def get_total_params(
    model: torch.nn.Module,
    trainable: bool = True,
    print_table: bool = False,
) -> int:
    """Get the total number of parameters in a PyTorch model.

    Example usage::

        class SimpleMLP(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = nn.Linear(3, 16)
                self.fc2 = nn.Linear(16, 2)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                out = F.relu(self.fc1(x))
                return self.fc2(out)

        net = SimpleMLP()
        num_params = torch_utils.get_total_params(net, print_table=True)
        # prints the following:
        # +------------+------------+
        # |  Modules   | Parameters |
        # +------------+------------+
        # | fc1.weight |     48     |
        # |  fc1.bias  |     16     |
        # | fc2.weight |     32     |
        # |  fc2.bias  |     2      |
        # +------------+------------+
        # Total Trainable Params: 98

    Args:
        model (torch.nn.Module): The pytorch model.
        trainable (bool, optional): Only consider trainable parameters.
            Defaults to True.
        print_table (bool, optional): Print the parameters in a pretty
            table. Defaults to False.

    Returns:
        int: Either all model parameters or only the trainable ones.
    """
    rows = []
    total_params = 0
    for name, parameter in model.named_parameters():
        # Skip frozen parameters only when counting trainable ones.
        if trainable and not parameter.requires_grad:
            continue
        numel = parameter.numel()
        rows.append((name, numel))
        total_params += numel

    if print_table:
        # Import lazily so the third-party `prettytable` dependency is only
        # required when the caller actually asks for the table.
        from prettytable import PrettyTable

        table = PrettyTable(["Modules", "Parameters"])
        for row in rows:
            table.add_row(list(row))
        print(table)
        # Label the summary according to what was actually counted; the old
        # message claimed "Trainable" even when trainable=False.
        label = "Trainable" if trainable else "Model"
        print("Total {} Params: {:,}".format(label, total_params))
    return total_params