A Concise, Easy-to-Understand PyTorch Reimplementation of ResNet50

ResNet50 Network Architecture

ResNet50's network structure is fairly simple and involves no complicated components; it can be reproduced in roughly 50 lines of code. Still, every time I want to use it I forget some detail, such as how the Bottleneck block is built or how many blocks each stage of ResNet50 contains, so I am writing this post as a record to save myself from redoing the work later.

The ResNet50 structure is shown below. In the paper the input is 3x224x224. It first passes through a 7x7 convolution with stride 2 and padding 3 followed by BN + ReLU, then a 3x3 max pooling with stride 2 and padding 1. Next come four stages containing 3, 4, 6 and 3 Bottleneck blocks respectively. Finally there is a 7x7 average pooling with stride 1 and padding 0, a Flatten, and a fully connected layer with 2048 inputs and 1000 outputs; applying Softmax yields the network's class-probability predictions.
[Figure: ResNet50 network architecture]
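
To keep the stage-by-stage resolutions straight, here is a minimal arithmetic sketch (the out_size helper is purely illustrative; it just applies the standard convolution output-size formula floor((h + 2p - k) / s) + 1 to the hyperparameters listed above):

def out_size(h, k, s, p):
    # conv/pool output-size formula: floor((h + 2p - k) / s) + 1
    return (h + 2 * p - k) // s + 1

h = 224
h = out_size(h, k=7, s=2, p=3)      # stem 7x7/2 conv      -> 112
h = out_size(h, k=3, s=2, p=1)      # 3x3/2 max pooling    -> 56
# stage 1 keeps the resolution; stages 2-4 each halve it with a stride-2 block
for _ in range(3):
    h = out_size(h, k=3, s=2, p=1)  # -> 28, 14, 7
h = out_size(h, k=7, s=1, p=0)      # 7x7 average pooling  -> 1
print(h)  # 1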

The Bottleneck Convolution Block

The Bottleneck block is the core component of ResNet50; each stage consists of several Bottlenecks. The first Bottleneck of a stage has different input and output channel counts, so its shortcut must be mapped with a 1x1 convolution + BN before the addition, while the remaining Bottlenecks add the shortcut directly. The two Bottleneck variants, with and without the 1x1 projection, are shown below:
[Figure: Bottleneck blocks with and without the 1x1 projection shortcut]
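
The need for the projection can be checked directly: the residual addition fails when the shortcut and the block output have different channel counts, and a 1x1 conv + BN (with matching stride) maps the shortcut to the right shape. A minimal sketch using stage-1 shapes (the tensors and the "project" module here are illustrative, not part of the code below):

import torch
import torch.nn as nn

x = torch.rand(1, 64, 56, 56)    # shortcut: stage-1 input, 64 channels
y = torch.rand(1, 256, 56, 56)   # Bottleneck branch output, 256 channels
# "x + y" would raise a shape-mismatch error, so project x first:
project = nn.Sequential(
    nn.Conv2d(64, 256, kernel_size=1, stride=1, bias=False),
    nn.BatchNorm2d(256))
print((project(x) + y).shape)    # torch.Size([1, 256, 56, 56])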

PyTorch Reimplementation Code

# ResNet50.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class Conv(nn.Module):
    # Conv2d + BN + optional ReLU; padding defaults to "same" for odd kernel sizes
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, groups=1, activation=True):
        super(Conv, self).__init__()
        padding = kernel_size // 2 if padding is None else padding
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
                              padding, groups=groups, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.ReLU(inplace=True) if activation else nn.Identity()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

class Bottleneck(nn.Module):
    # ResNet50 bottleneck: 1x1 reduce -> 3x3 (carries the stride) -> 1x1 expand
    def __init__(self, in_channels, out_channels, down_sample=False, groups=1):
        super(Bottleneck, self).__init__()
        stride = 2 if down_sample else 1
        mid_channels = out_channels // 4
        # Project the shortcut with a 1x1 conv + BN when the channel count changes;
        # in ResNet50 this coincides with every down-sampling block, so the stride matches
        self.shortcut = Conv(in_channels, out_channels, kernel_size=1, stride=stride, activation=False) \
            if in_channels != out_channels else nn.Identity()
        self.conv = nn.Sequential(*[
            Conv(in_channels, mid_channels, kernel_size=1, stride=1),
            Conv(mid_channels, mid_channels, kernel_size=3, stride=stride, groups=groups),
            Conv(mid_channels, out_channels, kernel_size=1, stride=1, activation=False)
        ])

    def forward(self, x):
        # residual addition, then the block's final ReLU
        y = self.conv(x) + self.shortcut(x)
        return F.relu(y, inplace=True)

class ResNet50(nn.Module):
    def __init__(self, num_classes):
        super(ResNet50, self).__init__()
        # Stem: 7x7/2 conv (padding 3 via Conv's default) + 3x3/2 max pooling
        self.stem = nn.Sequential(*[
            Conv(3, 64, kernel_size=7, stride=2),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        ])
        # Four stages with 3, 4, 6 and 3 Bottleneck blocks; stages 2-4 each halve the resolution
        self.stages = nn.Sequential(*[
            self._make_stage(64, 256, down_sample=False, num_blocks=3),
            self._make_stage(256, 512, down_sample=True, num_blocks=4),
            self._make_stage(512, 1024, down_sample=True, num_blocks=6),
            self._make_stage(1024, 2048, down_sample=True, num_blocks=3),
        ])
        # Head: 7x7 average pooling (the feature map is 7x7 for 224x224 inputs) + flatten + FC
        self.head = nn.Sequential(*[
            nn.AvgPool2d(kernel_size=7, stride=1, padding=0),
            nn.Flatten(start_dim=1, end_dim=-1),
            nn.Linear(2048, num_classes)
        ])

    def _make_stage(self, in_channels, out_channels, down_sample, num_blocks):
        layers = [Bottleneck(in_channels, out_channels, down_sample=down_sample)]
        for _ in range(1, num_blocks):
            layers.append(Bottleneck(out_channels, out_channels, down_sample=False))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.head(self.stages(self.stem(x)))

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    inputs = torch.rand((8, 3, 224, 224)).to(device)
    model = ResNet50(num_classes=1000).to(device).train()
    outputs = model(inputs)  # raw logits; apply softmax to get class probabilities
    print(outputs.shape)  # torch.Size([8, 1000])
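
As a quick sanity check, the parameter count should line up with torchvision's reference resnet50 (roughly 25.6M parameters), since the architecture above is the same apart from naming. A sketch, assuming torchvision is installed:

from torchvision.models import resnet50

mine = ResNet50(num_classes=1000)
ref = resnet50(num_classes=1000)  # randomly initialized reference model
print(sum(p.numel() for p in mine.parameters()))  # ~25.6M
print(sum(p.numel() for p in ref.parameters()))   # should match the line above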