Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

论文:https://arxiv.org/abs/2103.14030

代码:https://github.com/microsoft/Swin-Transformer

论文中提出了一种新型的Transformer架构(Swin Transformer),其利用滑动窗口和分层结构使得Swin Transformer成为了机器视觉领域新的Backbone,在图像分类、目标检测、语义分割等多种机器视觉任务中达到了SOTA水平。

目前Transformer应用到图像领域主要有两大挑战:

  • 视觉实体变化大,在不同场景下视觉Transformer性能未必很好
  • 图像分辨率高,像素点多,Transformer基于全局自注意力的计算导致计算量较

本文借鉴了CNN中的inductive bias,其中滑窗操作包括不重叠的local window,和重叠的cross-window。将注意力计算限制在一个窗口中,一方面能引入CNN卷积操作的局部性,另一方面能节省计算量

假设一张图片共有Swin Transformer: Hierarchical Vision Transformer using Shifted Windows个patches(每个patches是原图4*4像素区域),每个窗口包括Swin Transformer: Hierarchical Vision Transformer using Shifted Windows个patches.

原始Transformer self-attention计算复杂度 =  Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

在Swin Transformer中采用的是window self-attention,其计算复杂度为窗口计算复杂度*窗口数量,窗口数量=  Swin Transformer: Hierarchical Vision Transformer using Shifted Windows ,窗口计算复杂度=  Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

Swin Transformer self-attention计算复杂度 =  Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

计算复杂度由patches数量的平方关系降低到线性关系。

层次设计则是类似CNN随着网络变深,感受野变大的特性,将window内的多个patch变成一个patch,类似下采样。

Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

整体结构

Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

 整个模型采取层次化的设计,一共包含4个Stage,每个stage都会缩小输入特征图的分辨率,像CNN一样逐层扩大感受野。

Patch Partition+Linear Embedding

Patch Partition 将像素分辨率图像转换为patches分辨率的图像,每个patch视为一个token,特征就是patch范围内的RGB值的展开,token_feature = 48;代码通过二维卷积层,将stride,kernelsize设置为patch_size大小。设定输出通道来确定嵌入向量的大小。最后将H,W维度展开,并移动到第一维度。

Linear Embedding 将token_feature转换为需要的维度(Swin_T/C=96) 。

class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size) # -> (img_size, img_size)
        patch_size = to_2tuple(patch_size) # -> (patch_size, patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        # 假设采取默认参数
        x = self.proj(x) # 出来的是(N, 96, 224/4, 224/4) 
        x = torch.flatten(x, 2) # 把HW维展开,(N, 96, 56*56)
        x = torch.transpose(x, 1, 2)  # 把通道维放到最后 (N, 56*56, 96)
        if self.norm is not None:
            x = self.norm(x)
        return x

Patch Merging

该模块的作用是在每个Stage开始前做降采样,用于缩小分辨率,调整通道数 进而形成层次化的设计,同时也能节省一定运算量。

每次降采样是两倍,因此在行方向和列方向上,间隔2选取元素。然后拼接在一起作为一整个张量,最后展开。此时通道维度会变成原先的4倍(因为H,W各缩小2倍),此时再通过一个全连接层再调整通道维度为原来的两倍。

class PatchMerging(nn.Module):
    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C) # (B, H*W, C)->(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

 Window Partition/Reverse

window partition函数是用于对张量划分窗口,指定窗口大小。将原本的张量从 N H W C, 划分成 num_windows*B, window_size, window_size, C。window reverse函数则是对应的逆过程。这两个函数分别用在windows attention前后。

Window Attention

Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

 

传统的Transformer都是基于全局来计算注意力的,因此计算复杂度十分高。而Swin Transformer则将注意力的计算限制在每个窗口内,进而减少了计算量。

Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

 主要区别是在原始计算Attention的公式中的Q,K时加入了相对位置编码。实验有证明相对位置编码的加入提升了模型性能,所以这里主要讲一下相对位置编码。

假设window_size = 2*2即每个窗口有4个patch ,如图1所示,在计算self-attention时,每个patch都要与所有的ptch计算QK值,如图6所示,当位置1的patch计算self-attention时,要计算位置1与位置(1,2,3,4)的QK值,即以位置1的patch为中心点,中心点位置坐标(0,0),其他位置计算与当前位置坐标的偏移量

Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

 首先我们利用torch.arangetorch.meshgrid函数生成对应的坐标,这里我们以windowsize=2为例子

coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.meshgrid([coords_h, coords_w]) # -> 2*(wh, ww)
"""
  (tensor([[0, 0],
           [1, 1]]), 
   tensor([[0, 1],
           [0, 1]]))
"""

 然后堆叠起来,展开为一个二维向量

coords = torch.stack(coords)  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
"""
tensor([[0, 0, 1, 1],
        [0, 1, 0, 1]])
"""

 利用广播机制,分别在第一维,第二维,插入一个维度,进行广播相减,得到 2, wh*ww, wh*ww的张量

relative_coords_first = coords_flatten[:, :, None]  # 2, wh*ww, 1
relative_coords_second = coords_flatten[:, None, :] # 2, 1, wh*ww
relative_coords = relative_coords_first - relative_coords_second # 最终得到 2, wh*ww, wh*ww 形状的张量

首先 Window Partition将原本的张量从 N H W C, 划分成 num_windows*B, window_size, window_size, C。然后经过self.qkv这个全连接层后,进行reshape,调整轴的顺序,得到形状为3, numWindows*B, num_heads, window_size*window_size, c//num_heads,并分配给q,k,v。再加上之前的相对位置编码,剩下就是跟transformer一样的softmax,dropout,与V矩阵乘,再经过一层全连接层和dropout。

Shifted Window Attention

前面的Window Attention是在每个窗口下计算注意力的,为了更好的和其他window进行信息交互,Swin Transformer还引入了shifted window操作。

Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

 论文中以 Swin Transformer: Hierarchical Vision Transformer using Shifted Windows向下取整的窗口重新对原图进行分割(这里M是Window Attention窗口的尺寸 ),并将之前没有联系的新窗口合并得到新的窗口划分方案,如图8所示,带来的问题就是窗口个数增加了,为了避免窗口增加导致的额外计算量并保证不重叠窗口间有关联,论文提出了cyclic shift方法,如下图所示:

Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

通过对特征图移位,并给Attention设置mask来间接实现的。能在保持原有的window个数下,最后的计算结果等价。 
特征图移位通过torch.roll实现,并且给每个子窗口进行编码:

Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

 Attention Mask

希望在计算Attention的时候,让具有相同index QK进行计算,而忽略不同index QK计算结果

Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

 但实际上代码使用张量相减来得到mask矩阵的:

if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

这里代码写的还是很巧妙的,patch与patch的自相关对应mask之间的比较,所以广播再相减就可以得到结果,如下:

Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

上一篇:gfp_mask


下一篇:BERT和GPT