How does adaptive pooling in PyTorch work?

In general, pooling reduces dimensions. If you want to increase dimensions, you might want to look at interpolation.

Anyway, let's talk about adaptive pooling in general. The PyTorch source code is the authoritative reference. Some have claimed that adaptive pooling is the same as standard pooling with the stride and kernel size calculated from the input and output sizes. Specifically, the following parameters are used:

  1. Stride = (input_size//output_size)
  2. Kernel size = input_size - (output_size-1)*stride
  3. Padding = 0
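
As a quick sanity check, the sketch below plugs the parameters above into the usual pooling output-size formula, floor((input_size + 2*padding - kernel_size) / stride) + 1. The helper names are made up for illustration and are not part of PyTorch, and the sketch assumes output_size <= input_size so that the stride is at least 1.

```python
def equivalent_pool_params(input_size, output_size):
    """Stride and kernel size from the list above (padding is 0)."""
    stride = input_size // output_size
    kernel_size = input_size - (output_size - 1) * stride
    return stride, kernel_size


def pooled_length(input_size, kernel_size, stride, padding=0):
    """Output length of a standard pooling layer (floor mode, no dilation)."""
    return (input_size + 2 * padding - kernel_size) // stride + 1


# The derived parameters should always give back the requested output size.
for input_size, output_size in [(5, 3), (7, 2), (10, 4)]:
    stride, kernel_size = equivalent_pool_params(input_size, output_size)
    length = pooled_length(input_size, kernel_size, stride)
    print(input_size, output_size, stride, kernel_size, length)
```

This only confirms that the output has the right length; it says nothing about which input elements each output is pooled from.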

These parameters are obtained by inverting the pooling output-size formula. While they DO produce an output of the desired size, that output is not necessarily the same as that of adaptive pooling. Here is a test snippet:

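import torch
import torch.nn as nn

in_length = 5
out_length = 3

# Input of shape (batch=1, channels=1, length=5): [0, 1, 2, 3, 4]
x = torch.arange(0, in_length).view(1, 1, -1).float()
print(x)

# Standard pooling with parameters derived from the input/output sizes
stride = in_length // out_length
avg_pool = nn.AvgPool1d(
    kernel_size=in_length - (out_length - 1) * stride,
    stride=stride,
    padding=0,
)
adaptive_pool = nn.AdaptiveAvgPool1d(out_length)

print(avg_pool.stride, avg_pool.kernel_size)

y_avg = avg_pool(x)
y_ada = adaptive_pool(x)

print(y_avg)
print(y_ada)
# Sum of absolute differences between the two outputs
print("Error: ", (y_avg - y_ada).abs().sum().item())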

Output:

tensor([[[0., 1., 2., 3., 4.]]])
(1,) (3,)
tensor([[[1., 2., 3.]]])
tensor([[[0.5000, 2.0000, 3.5000]]])
Error:  1.0

Average pooling pools from elements (0, 1, 2), (1, 2, 3) and (2, 3, 4).

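Adaptive pooling pools from elements (0, 1), (1, 2, 3) and (3, 4). (With this input the middle output would be 2 whether it came from (1, 2, 3) or from element 2 alone, so change the input values a bit to confirm that it is not pooling from (2) only.)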
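
The windows above are consistent with a floor/ceil rule for the window boundaries: output i covers inputs from floor(i * in / out) up to, but not including, ceil((i + 1) * in / out). The following is a minimal pure-Python sketch under that assumption. The function names are hypothetical, and this is meant to reproduce the result above, not to stand in for the actual PyTorch implementation.

```python
def adaptive_windows(input_size, output_size):
    """Half-open [start, end) index ranges, one per output element."""
    windows = []
    for i in range(output_size):
        start = (i * input_size) // output_size          # floor(i * in / out)
        end = -((-(i + 1) * input_size) // output_size)  # ceil((i + 1) * in / out)
        windows.append((start, end))
    return windows


def adaptive_avg_pool_1d(values, output_size):
    """Average each window of a plain Python list."""
    return [
        sum(values[start:end]) / (end - start)
        for start, end in adaptive_windows(len(values), output_size)
    ]


print(adaptive_windows(5, 3))                               # [(0, 2), (1, 4), (3, 5)]
print(adaptive_avg_pool_1d([0.0, 1.0, 2.0, 3.0, 4.0], 3))  # [0.5, 2.0, 3.5]
```

Under this rule the windows vary in size (2, 3 and 2 elements here), which is something a single fixed kernel size cannot express.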

  • You can tell that adaptive pooling tries to reduce the overlap between pooling windows.
  • The difference can be mitigated by using padding with count_include_pad=True, but in general I don't think the two can be made exactly the same in 2D or higher for all input/output sizes. I would imagine that matching them requires different padding on the left and right, which pooling layers do not support at the moment.
  • From a practical perspective it should not matter much.
  • Check the code for the actual implementation.
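
To get a feel for when the two approaches coincide, the sketch below compares the windows of the fixed stride/kernel recipe with adaptive windows for small 1D sizes. It assumes adaptive pooling uses floor/ceil window boundaries, and all function names are hypothetical. When the input size is an exact multiple of the output size, both recipes give the same non-overlapping windows. Otherwise they may or may not agree.

```python
def fixed_windows(input_size, output_size):
    """Windows of standard pooling with the derived stride and kernel size."""
    stride = input_size // output_size
    kernel_size = input_size - (output_size - 1) * stride
    return [(i * stride, i * stride + kernel_size) for i in range(output_size)]


def adaptive_windows(input_size, output_size):
    """Assumed adaptive rule: [floor(i*in/out), ceil((i+1)*in/out))."""
    return [
        ((i * input_size) // output_size,
         -((-(i + 1) * input_size) // output_size))
        for i in range(output_size)
    ]


def windows_agree(input_size, output_size):
    return fixed_windows(input_size, output_size) == adaptive_windows(input_size, output_size)


# List every (input_size, output_size) pair up to 8 where the windows differ.
mismatches = [
    (n, m)
    for n in range(1, 9)
    for m in range(1, n + 1)
    if not windows_agree(n, m)
]
print(mismatches)
```

The pair (5, 3) from the test snippet shows up in the list, while (6, 3) and (5, 2) do not.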
