torchkit.utils.module_stats
- torchkit.utils.module_stats.get_total_params(model, trainable=True, print_table=False)
Get the total number of parameters in a PyTorch model.
Example usage:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    from torchkit.utils import module_stats

    class SimpleMLP(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(3, 16)
            self.fc2 = nn.Linear(16, 2)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            out = F.relu(self.fc1(x))
            return self.fc2(out)

    net = SimpleMLP()
    num_params = module_stats.get_total_params(net, print_table=True)

    # prints the following:
    +------------+------------+
    |  Modules   | Parameters |
    +------------+------------+
    | fc1.weight |     48     |
    |  fc1.bias  |     16     |
    | fc2.weight |     32     |
    |  fc2.bias  |     2      |
    +------------+------------+
    Total Trainable Params: 98
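The per-module counts in the table follow directly from the layer shapes: a fully connected layer with bias stores an (out_features × in_features) weight matrix plus one bias value per output unit. As a rough check, the sketch below reproduces the numbers in plain Python without PyTorch. The helper name `linear_param_counts` is hypothetical and is not part of torchkit.

```python
def linear_param_counts(in_features: int, out_features: int) -> dict:
    """Parameter counts for a fully connected layer with bias (hypothetical helper)."""
    return {
        "weight": in_features * out_features,  # weight matrix has shape (out, in)
        "bias": out_features,                  # one bias term per output unit
    }


# Layer shapes taken from the SimpleMLP example: (in_features, out_features).
layers = {"fc1": (3, 16), "fc2": (16, 2)}

counts = {
    f"{name}.{kind}": n
    for name, (fan_in, fan_out) in layers.items()
    for kind, n in linear_param_counts(fan_in, fan_out).items()
}
total = sum(counts.values())

print(counts)
print("Total:", total)  # 48 + 16 + 32 + 2 = 98
```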
- Parameters
model (torch.nn.Module) – The PyTorch model.
trainable (bool, optional) – Only consider trainable parameters. Defaults to True.
print_table (bool, optional) – Print the parameters in a pretty table. Defaults to False.
- Returns
The number of parameters in the model: only the trainable ones if trainable is True, otherwise all of them.
- Return type
int
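To illustrate the effect of the trainable flag, here is a minimal sketch of how such a count could be computed with standard PyTorch calls. It assumes that "trainable" means `requires_grad=True`, which the text above does not state explicitly. The function `count_params` is a hypothetical stand-in and not torchkit's actual implementation.

```python
import torch.nn as nn


def count_params(model: nn.Module, trainable: bool = True) -> int:
    """Hypothetical stand-in for illustration, not torchkit's implementation.

    Assumption: a parameter is "trainable" when requires_grad is True.
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable
    )


# Same layer shapes as the SimpleMLP example.
net = nn.Sequential(nn.Linear(3, 16), nn.ReLU(), nn.Linear(16, 2))
net[0].requires_grad_(False)  # freeze the first layer (48 + 16 parameters)

all_params = count_params(net, trainable=False)
trainable_params = count_params(net, trainable=True)
print(all_params, trainable_params)  # 98 34
```

With the first layer frozen, only the 32 weights and 2 biases of the second layer remain trainable, so the two counts differ.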