池化层 (Pooling) 降低维度, 缩减模型大小,提高计算速度. 另外一个作用可以缓解卷积层对位置的敏感性.
池化层主要有两种:
-
最大池化
-
平均池化
-
池化层计算 {#title-0} ===================
最大池化:
- max(0, 1, 3, 4)
- max(1, 2, 4, 5)
- max(3, 4, 6, 7)
- max(4, 5, 7, 8)
平均池化:
-
mean(0, 1, 3, 4)
-
mean(1, 2, 4, 5)
-
mean(3, 4, 6, 7)
-
mean(4, 5, 7, 8)
-
Stride {#title-1} ====================
最大池化:
- max(0, 1, 4, 5)
- max(2, 3, 6, 7)
- max(8, 9, 12, 13)
- max(10, 11, 14, 15)
平均池化:
-
mean(0, 1, 4, 5)
-
mean(2, 3, 6, 7)
-
mean(8, 9, 12, 13)
-
mean(10, 11, 14, 15)
-
Padding {#title-2} =====================
最大池化:
- max(0, 0, 0, 0)
- max(0, 0, 0, 1)
- max(0, 0, 1, 2)
- max(0, 0, 2, 0)
- ... 以此类推
平均池化:
-
mean(0, 0, 0, 0)
-
mean(0, 0, 0, 1)
-
mean(0, 0, 1, 2)
-
mean(0, 0, 2, 0)
-
... 以此类推
-
多通道池化计算 {#title-3} =====================
在处理多通道输入数据时,池化层对每个输入通道分别池化,而不是像卷积层那样将各个通道的输入相加。这意味着池化层的输出和输入的通道数是相等。
- PyTorch 池化 API 使用 {#title-4} ===============================
import torch
import torch.nn as nn
# 1. API 基本使用
def test01():
inputs = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]).float()
inputs = inputs.unsqueeze(0).unsqueeze(0)
# 1. 最大池化
# 输入形状: (N, C, H, W)
polling = nn.MaxPool2d(kernel_size=2, stride=1, padding=0)
output = polling(inputs)
print(output)
# 2. 平均池化
polling = nn.AvgPool2d(kernel_size=2, stride=1, padding=0)
output = polling(inputs)
print(output)
# 2. stride 步长
def test02():
inputs = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]).float()
inputs = inputs.unsqueeze(0).unsqueeze(0)
# 1. 最大池化
polling = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
output = polling(inputs)
print(output)
# 2. 平均池化
polling = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
output = polling(inputs)
print(output)
# 3. padding 填充
def test03():
inputs = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]).float()
inputs = inputs.unsqueeze(0).unsqueeze(0)
# 1. 最大池化
polling = nn.MaxPool2d(kernel_size=2, stride=1, padding=1)
output = polling(inputs)
print(output)
# 2. 平均池化
polling = nn.AvgPool2d(kernel_size=2, stride=1, padding=1)
output = polling(inputs)
print(output)
# 4. 多通道池化
def test04():
inputs = torch.tensor([[[0, 1, 2], [3, 4, 5], [6, 7, 8]],
[[10, 20, 30], [40, 50, 60], [70, 80, 90]],
[[11, 22, 33], [44, 55, 66], [77, 88, 99]]]).float()
inputs = inputs.unsqueeze(0)
# 最大池化
polling = nn.MaxPool2d(kernel_size=2, stride=1, padding=0)
output = polling(inputs)
print(output)
if __name__ == '__main__':
test04()