我的编程空间,编程开发者的网络收藏夹
学习永远不晚

Swin-Transformer(原理 + 代码)详解

短信预约 -IT技能 免费直播动态提醒
省份

北京

  • 北京
  • 上海
  • 天津
  • 重庆
  • 河北
  • 山东
  • 辽宁
  • 黑龙江
  • 吉林
  • 甘肃
  • 青海
  • 河南
  • 江苏
  • 湖北
  • 湖南
  • 江西
  • 浙江
  • 广东
  • 云南
  • 福建
  • 海南
  • 山西
  • 四川
  • 陕西
  • 贵州
  • 安徽
  • 广西
  • 内蒙
  • 西藏
  • 新疆
  • 宁夏
  • 兵团
手机号立即预约

请填写图片验证码后获取短信验证码

看不清楚,换张图片

免费获取短信验证码

Swin-Transformer(原理 + 代码)详解

参考博文

图解Swin Transformer
Swin-Transformer网络结构详解
【机器学习】详解 Swin Transformer (SwinT)
论文下载

(二)代码的下载与配置

2.1、需要的安装包

官方源码下载
学习的话,请下载Image Classification的代码,配置相对简单,其他的配置会很麻烦。如下图所示:
在这里插入图片描述

Install :
pytorch安装:感觉pytorch > 1.4版本都没问题的。
2、pip install timm==0.3.2(最新版本也行)
1、pip install Apex

  • win 10系统下安装NVIDIA apex

这个我认为windows安装可能会很啃。
1、首先在github下载源码https://github.com/NVIDIA/apex到本地文件夹
2、打开cmd命令窗口,切换到apex所在的文件夹
3、使用命令:python setup.py install 即可完成安装

注意事项: 可能会出现的问题:
setuptools有ModuleNotFoundError→更新setuptools
pip install --upgrade setuptools

  • linux系统下安装NVIDIA apex
git clone https://github.com/NVIDIA/apexcd apexpip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./

2.2、代码运行配置

注意:不要用ImageNet数据集:显卡可能会受不了,就是为了学习swin代码对吧,可以自己找一个小的ImageNet的数据集。

2.2.1、代码配置

首先运行main.py文件,如下图:
在这里插入图片描述
再点击main.py配置:
在这里插入图片描述
最后在下图Parameters处填入:

--cfg configs/swin_tiny_patch4_window7_224.yaml --data-path imagenet --local_rank 0 --batch-size 2

在这里插入图片描述

2.2.2、本人运行报错修改

报错1如下:
Swin transformer TypeError: __init__() got an unexpected keyword argument ‘t_mul‘
报错2如下:
from timm.data.transforms import _pil_interp无法导入_pil_interp

pip install timm==0.3.2(最新版本也行)

但是我安装最新版本后:from timm.data.transforms import _pil_interp无法导入_pil_interp,然后我查看了timm.data 中的transforms.py文件,完全就没有定义_pil_interp。完整的timm.data 中的transforms.py文件我在下面也把这个文件_pil_interp代码复制在下面,可以自行补充_pil_interp。

def _pil_interp(method):    if method == 'bicubic':        return Image.BICUBIC    elif method == 'lanczos':        return Image.LANCZOS    elif method == 'hamming':        return Image.HAMMINGif has_interpolation_mode:    _torch_interpolation_to_str = {        InterpolationMode.NEAREST: 'nearest',        InterpolationMode.BILINEAR: 'bilinear',        InterpolationMode.BICUBIC: 'bicubic',        InterpolationMode.BOX: 'box',        InterpolationMode.HAMMING: 'hamming',        InterpolationMode.LANCZOS: 'lanczos',    }    _str_to_torch_interpolation = {b: a for a, b in _torch_interpolation_to_str.items()}else:    _pil_interpolation_to_torch = {}    _torch_interpolation_to_str = {}

完整的timm.data 中的transforms.py文件:界面如下图
在这里插入图片描述

报错3:如下图所示:

在这里插入图片描述
解决办法如下:

解决办法,删除Swin-Transformer/lr_scheduler.py的第24行‘t_mul=1.,’

在这里插入图片描述

(三)原理概括

下面PPT是对Swin-Transformer做了一个大概的概括,具体细节可以参考第四节代码部分。

在这里插入图片描述在这里插入图片描述在这里插入图片描述在这里插入图片描述在这里插入图片描述在这里插入图片描述在这里插入图片描述在这里插入图片描述在这里插入图片描述
在这里插入图片描述
在这里插入图片描述在这里插入图片描述在这里插入图片描述

(四)代码详解

注意:本人代码部分是按照Debug顺序进行编写的。并不是按照一个一个模块去分开讲解的,所以大家看起来可能会很按难受。这里推荐一篇博客图解Swin Transformer,是按照每个结构分开单独编写,容易理解,思路清晰。

准备工作

首先我们打开main.py和swin_transformer.py文件。

然后在swin_transformer.py中找到class SwinTransformer(nn.Module):类,在其def forward_features(self, x):下第一行插入断点,那么下面我们就开始一步一步debug吧。

    def forward_features(self, x):        print(x.shape)   # [2, 3, 224, 224], batch_size = 2        x = self.patch_embed(x)  # 详解在3.1节        print(x.shape)        if self.ape:    # self.ape = False不用考虑            x = x + self.absolute_pos_embed        x = self.pos_drop(x)  # 就是一个Droupout层        print(x.shape)  # [2, 3136, 96]        for layer in self.layers:            x = layer(x)            print(x.shape)        x = self.norm(x)  # B L C        print(x.shape)        x = self.avgpool(x.transpose(1, 2))  # B C 1        print(x.shape)        x = torch.flatten(x, 1)        print(x.shape)        return x

3.1、PatchEmbed

在输入开始的时候,做了一个Patch Embedding,将图片切成一个个图块,并嵌入到Embedding。
在每个Stage里,由Patch Merging多个Block组成。
其中Patch Merging模块主要在每个Stage一开始降低图片分辨率。
在这里插入图片描述

而Block具体结构如右图所示,主要是LayerNorm,MLP,Window Attention 和 Shifted Window Attention组成 (为了方便讲解,我会省略掉一些参数)

class PatchEmbed(nn.Module):    r""" Image to Patch Embedding    Args:        img_size (int): Image size.  Default: 224.        patch_size (int): Patch token size. Default: 4.        in_chans (int): Number of input image channels. Default: 3.        embed_dim (int): Number of linear projection output channels. Default: 96.        norm_layer (nn.Module, optional): Normalization layer. Default: None    """    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):        super().__init__()        img_size = to_2tuple(img_size)         patch_size = to_2tuple(patch_size)        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]        self.img_size = img_size        self.patch_size = patch_size        self.patches_resolution = patches_resolution        self.num_patches = patches_resolution[0] * patches_resolution[1]        self.in_chans = in_chans        self.embed_dim = embed_dim        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)        if norm_layer is not None:            self.norm = norm_layer(embed_dim)        else:            self.norm = None    def forward(self, x):        B, C, H, W = x.shape  # [2, 3, 224, 224]        # FIXME look at relaxing size constraints        assert H == self.img_size[0] and W == self.img_size[1], \            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."                # proj是先卷积,再flatten(2)把三四列变成一列(即56*56=3136)        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C        # x = torch.Size([2, 3136, 96])          # 56*56 = 3136个patch=tokens        # 每个patch或tokens的向量维度为96        print(x.shape) #4 3136 96 其中3136就是 224/4 * 224/4 相当于有这么长的序列,其中每个元素是96维向量        if self.norm is not None:            x = self.norm(x)        print(x.shape)        return x

其实只要看forward就行了。

x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C

这一行就是把原图[2, 3, 224, 224]转化为3136个patch,每个patch的维度等于96。
在这里插入图片描述

3.2、class SwinTransformerBlock()

class SwinTransformerBlock(nn.Module):    def forward(self, x):        H, W = self.input_resolution        B, L, C = x.shape        assert L == H * W, "input feature has wrong size"        shortcut = x   # x = [2,3136,96]        x = self.norm1(x)        x = x.view(B, H, W, C)  # (2, 56, 56, 96)        # cyclic shift        # 在第一次我们是W-MSA,没有滑动窗口,所以self.shift_size > 0 =False        if self.shift_size > 0:            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))        else:            shifted_x = x        # partition windows        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C        # W-MSA/SW-MSA        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C        # merge windows (把attention后的数据还原成原来输入的shape)        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C        # reverse cyclic shift        if self.shift_size > 0:            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))            print(x.shape)        else:            x = shifted_x        x = x.view(B, H * W, C)        print(x.shape)        # FFN        x = shortcut + self.drop_path(x)        print(x.shape)        x = x + self.drop_path(self.mlp(self.norm2(x)))        print(x.shape)        return x
第一部分:W-MSA部分和窗口的构建
x = x.view(B, H, W, C)  # (2, 56, 56, 96)# cyclic shift        # 在第一次我们是W-MSA,没有滑动窗口,所以self.shift_size > 0 =False        if self.shift_size > 0:            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))        else:            shifted_x = x        # partition windows        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
第一部分中的window_partition()类:(就是把序列转化为窗口)
def window_partition(x, window_size):    """    Args:        x: (B, H, W, C)        window_size (int): window size    Returns:        windows: (num_windows*B, window_size, window_size, C)    """    B, H, W, C = x.shape    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)  # (2,8,7,8,7,96):指把56*56的patch按照7*7的窗口划分    print(x.shape)  # (2,8,7,8,7,96)    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) # window的数量 H/7 * W/7 *batch    print(windows.shape)      # windows=(128, 7, 7, 96)      # 128 = batch_size * 8 * 8 = 128窗口的数量    # 7 = window_size 窗口的大小尺寸,说明每个窗口包含49个patch    return windows

在这里插入图片描述在这里插入图片描述

详解self.attn(x_windows, mask=self.attn_mask)

先定位到class WindowAttention(nn.Module):处,在其forward上打上断点,现在我们去看看吧。(下面我只复制了forward代码)
class WindowAttention(nn.Module):    def forward(self, x, mask=None):        """        Args:            x: input features with shape of (num_windows*B, N, C)            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None        """        B_, N, C = x.shape        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)        print(qkv.shape) # torch.Size([3, 128, 3, 49, 32])        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)        print(q.shape)  # torch.Size([128, 3, 49, 32])        print(k.shape)  # torch.Size([128, 3, 49, 32])        print(v.shape)  # torch.Size([128, 3, 49, 32])        q = q * self.scale  # q = [128, 3, 49, 32]        attn = (q @ k.transpose(-2, -1))        print(attn.shape) # torch.Size([128, 3, 49, 49])        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH        print(relative_position_bias.shape)         relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww        print(relative_position_bias.shape)        attn = attn + relative_position_bias.unsqueeze(0)        print(attn.shape)        if mask is not None:             nW = mask.shape[0]            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)            attn = attn.view(-1, self.num_heads, N, N)            attn = self.softmax(attn)        else:            attn = self.softmax(attn)        attn = self.attn_drop(attn)        print(attn.shape)        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)        print(x.shape)        x = self.proj(x) # 全连接层,用于整合新信息的        print(x.shape)        x = self.proj_drop(x)        print(x.shape)   # 还原成输入的形式[2,3136,96]        return x

qkv.shape = [3, 128, 3, 49, 32]
(1)3:是指Q、K、V三个
(2)128:是指128个windows
(3)3:是指Multi–Head = 3(多头注意力机制)
(4)49:是指每个窗口含有49个patchs,每个窗口的49个patchs之间要相互做self–attention
(5)32:是指经过多头后,每个head分配32个维度。

attn = (q @ k.transpose(-2, -1))
print(attn.shape) # torch.Size([128, 3, 49, 49])
就是正常的计算 α \alpha α相关性。

attn = attn + relative_position_bias.unsqueeze(0)
这是把attention( α \alpha α)和 位置编码进行相加,和ViT中在tokens加上位置编码。(具体如下图所示)
后面的代码和VIT中基本一样,详细请看本人上一篇博客ViT( Vision Transformer)详解
在这里插入图片描述
ViT的Attention公式如下:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V = s o f t m a x ( α d k ) V Attention(Q,K,V) = softmax(\frac {QK^T}{\sqrt{d_k}}) V= softmax(\frac {\boldsymbol{\alpha}}{\sqrt{d_k}}) V Attention(Q,K,V)=softmax(dk QKT)V=softmax(dk α)V:
Swin–Transformer公式如下:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k + B ) V = s o f t m a x ( α d k + B ) V Attention(Q,K,V) = softmax(\frac {QK^T}{\sqrt{d_k}} + B) V= softmax(\frac {\boldsymbol{\alpha}}{\sqrt{d_k}} + B) V Attention(Q,K,V)=softmax(dk QKT+B)V=softmax(dk α+B)V:
上面公式的主要区别是在原始计算Attention的公式中的Q,K时加入了相对位置编码B。后续实验有证明相对位置编码的加入提升了模型性能。

相对位置编码

由于论文中并没有详解讲解这个相对位置偏执,所以我自己根据阅读源码做了简单的总结。(主要借鉴了Swin-Transformer网络结构详解这篇博客)如下图,假设输入的feature map高宽都为2,那么首先我们可以构建出每个像素的绝对位置(左下方的矩阵),对于每个像素的绝对位置是使用行号和列号表示的。比如蓝色的像素对应的是第0行第0列所以绝对位置索引是 ( 0 , 0 ) (0,0) (0,0),接下来再看看相对位置索引。首先看下蓝色的像素,在蓝色像素使用q与所有像素k进行匹配过程中,是以蓝色像素为参考点。然后用蓝色像素的绝对位置索引与其他位置索引进行相减,就得到其他位置相对蓝色像素的相对位置索引。例如黄色像素的绝对位置索引是 ( 0 , 1 ) (0,1) (0,1),则它相对蓝色像素的相对位置索引为 ( 0 , 0 ) − ( 0 , 1 ) = ( 0 , − 1 ) (0,0) - (0,1) = (0,-1) (0,0)(0,1)=(0,1),这里是严格按照源码中来讲的,请不要杠。那么同理可以得到其他位置相对蓝色像素的相对位置索引矩阵。同样,也能得到相对黄色,红色以及绿色像素的相对位置索引矩阵。接下来将每个相对位置索引矩阵按行展平,并拼接在一起可以得到下面的4x4矩阵 。

个人理解:
四个绝对位置编码分别为:(0,0)(0,1)1,0(1,0)(1,1),每个位置对应的相对位置为(0,0),我们看一下 4×4 4\times4 4×4矩阵第二行,蓝色对应黄色:【用真实位置编码坐标相减】(0,1)-(0,0) = (0,1), 红的对和黄色(0,1)-(1,0) = (-1,1),绿色对黄色:(0,1)-(1,1)=(-1,0),直接得到第二行所有元素。

在这里插入图片描述

在这里插入图片描述
代码过程如下:

coords_h = torch.arange(2)coords_w = torch.arange(2)coords = torch.meshgrid([coords_h, coords_w])print(coords)coords = torch.stack(coords)  # 2, Wh, Wwprint("1 1 1 "* 10)print(coords)coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Wwprint("2 2 2  "* 10)print(coords_flatten)relative_coords_first = coords_flatten[:, :, None]  # 2, wh*ww, 1print("3 3 3 "*10)print(relative_coords_first)relative_coords_second = coords_flatten[:, None, :] # 2, 1, wh*wwprint("4 4 4 "*10)print(relative_coords_second)relative_coords = relative_coords_first - relative_coords_second # 最终得到 2, wh*ww, wh*ww 形状的张量print("5 5 5 "*10)print(relative_coords)relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2print("6 6 6 "*10)print(relative_coords)print(relative_coords.shape)

在这里插入图片描述在这里插入图片描述
请注意,我这里描述的一直是相对位置索引,并不是相对位置偏执参数。因为后面我们会根据相对位置索引去取对应的参数。比如说黄色像素是在蓝色像素的右边,所以相对蓝色像素的相对位置索引为( 0 , − 1 ) 。绿色像素是在红色像素的右边,所以相对红色像素的相对位置索引为( 0 , − 1 )。可以发现这两者的相对位置索引都是( 0 , − 1 ) ,所以他们使用的相对位置偏执参数都是一样的。其实讲到这基本已经讲完了,但在源码中作者为了方便把二维索引给转成了一维索引。具体这么转的呢,有人肯定想到,简单啊直接把行、列索引相加不就变一维了吗?比如上面的相对位置索引中有( 0 , − 1 ) 和( − 1 , 0 )在二维的相对位置索引中明显是代表不同的位置,但如果简单相加都等于-1那不就出问题了吗?接下来我们看看源码中是怎么做的。首先在原始的相对位置索引上加上M-1(M为窗口的大小,在本示例中M=2),加上之后索引中就不会有负数了。
在这里插入图片描述

relative_coords[:, :, 0] += 2 - 1print("7 7 7 "*10)print(relative_coords)relative_coords[:, :, 1] += 2 - 1print("8 8 8 "*10)print(relative_coords)

在这里插入图片描述在这里插入图片描述
接着将所有的行标都乘上2M-1。
在这里插入图片描述

relative_coords[:, :, 0] *= 2 * 2 - 1print("9 9 9 "*10)print(relative_coords)

在这里插入图片描述

最后将行标和列标进行相加。这样即保证了相对位置关系,而且不会出现上述 0 + ( − 1 ) = ( − 1 ) + 0 0 + (-1) = (-1) + 0 0+(1)=(1)+0的问题了,是不是很神奇。
在这里插入图片描述
代码过程如下:

relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Wwprint("10 10 10 "* 3)print(relative_position_index)

在这里插入图片描述

刚刚上面也说了,之前计算的是相对位置索引,并不是相对位置偏执参数。真正使用到的可训练参数 B B B是保存在relative position bias table表里的,这个表的长度是等于 ( 2 M − 1 ) × ( 2 M − 1 ) ( 2 M − 1 ) × ( 2 M − 1 ) (2M1)×(2M1)的。那么上述公式中的相对位置偏执参数B是根据上面的相对位置索引表根据查relative position bias table表得到的,如下图所示。
在这里插入图片描述
以上过程结束,代表Swin–Transformer–Block中的第一部分(W–MSA)结束。返回的x = [2, 3136, 49]。如下图所示:
在这里插入图片描述

那么接下来我们要继续执行Swin--Transformer--Block中的第二部分(SW--MSA)。

在这里插入图片描述

与前一部份的Block的不同之处在于SW-MSA,有个滑动窗口、偏移量等新的东西加入。相同的部分我们就不再代码讲述,下面我们只看不同的部分。

首先我们看到这一部分:

# cyclic shift        if self.shift_size > 0:   # self.shift_size = 3,偏移量为3.            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))        else:            shifted_x = x

再看以下Mask部分:可以看到,一直到运行到attn部分,都和前面的W-MSA参数的size是一样的。

...  ...  ...  ...        attn = (q @ k.transpose(-2, -1))        print(attn.shape)...  ...  ...  ...               attn = attn + relative_position_bias.unsqueeze(0)        print(attn.shape)  # torch.Size([128, 3, 49, 49])print(attn.shape)  # torch.Size([128, 3, 49, 49])# mask部分,mask = self.attn_mask        if mask is not None:            nW = mask.shape[0]            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)            attn = attn.view(-1, self.num_heads, N, N)            attn = self.softmax(attn)        else:            attn = self.softmax(attn)

Shifted Window Attention,前面的Window Attention是在每个窗口下计算注意力的,为了更好的和其他window进行信息交互,Swin Transformer还引入了shifted window操作。下面看一下self.shift_size 的定义吧
在这里插入图片描述
左边是没有重叠的Window Attention,而右边则是将窗口进行移位的Shift Window Attention。可以看到移位后的窗口包含了原本相邻窗口的元素。但这也引入了一个新问题,即window的个数翻倍了,由原本四个窗口变成了9个窗口。
在实际代码里,我们是通过对特征图移位,并给Attention设置mask来间接实现的。能在保持原有的window个数下,最后的计算结果等价。
在这里插入图片描述
特征图移位操作

代码里对特征图移位是通过torch.roll来实现的,
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
下面是示意图:在这里插入图片描述
如果需要reverse cyclic shift(就是还原操作)的话只需把参数shifts设置为对应的正数值。

Attention Mask
我认为这是Swin Transformer的精华,通过设置合理的mask,让Shifted Window Attention在与Window Attention相同的窗口个数下,达到等价的计算结果。
首先我们对Shift Window后的每个窗口都给上index,并且做一个roll操作(window_size=2, shift_size=-1)在这里插入图片描述
我们希望在计算Attention的时候,让具有相同index QK进行计算,而忽略不同index QK计算结果
最后正确的结果如下图所示:

在这里插入图片描述
而要想在原始四个窗口下得到正确的结果,我们就必须给Attention的结果加入一个mask(如上图最右边所示)相关代码如下:

slice(start,end)函数:方法可从已有数组中返回选定的元素,返回一个新数组,包含从start到end(不包含该元素)的数组元素

  • start参数:必须,规定从何处开始选取,如果为负数,规定从数组尾部算起的位置,-1是指最后一个元素。
  • end参数:可选(如果该参数没有指定,那么切分的数组包含从start倒数组结束的所有元素,如果这个参数为负数,那么规定是从数组尾部开始算起的元素)。
    在这里插入图片描述在这里插入图片描述
        if self.shift_size > 0:            # calculate attention mask for SW-MSA            H, W = self.input_resolution            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1            h_slices = (slice(0, -self.window_size),                        slice(-self.window_size, -self.shift_size),                        slice(-self.shift_size, None))            # h_slices = (slice(0, -7, None) ,slice(7, -3, None) ,slice(-3, None, None))            w_slices = (slice(0, -self.window_size),                        slice(-self.window_size, -self.shift_size),                        slice(-self.shift_size, None))            cnt = 0            for h in h_slices:                for w in w_slices:                    img_mask[:, h, w, :] = cnt                    cnt += 1            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))        else:            attn_mask = None

在这里插入图片描述

 mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1

各项细节如下图所示:
在这里插入图片描述

attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)

细节如下图所示:
在这里插入图片描述

attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

意思就是对于attn_mask不为0的部分填充为-100,有什么用呢?想一想softmax函数的计算公式,在这里插入图片描述
e − 100 ≈ 0 e^{-100}\approx 0 e1000 , 那么这样是不是等于忽略不同index(指下图中的0,1,2,···,8) QK计算结果。在这里插入图片描述

在这里插入图片描述

Downsample(下采样操作):Patch Merging

注意:这里的Patch Merging下采样操作用的可不是 1×1 1 \times 1 1×1卷积进行的下采样。
该模块的作用是在每个Stage开始前做降采样,用于缩小分辨率,调整通道数 进而形成层次化的设计,同时也能节省一定运算量。

在CNN中,则是在每个Stage开始前用stride=2的卷积/池化层来降低分辨率。

每次降采样是两倍,因此在行方向和列方向上,间隔2选取元素。然后拼接在一起作为一整个张量,最后展开。此时通道维度会变成原先的4倍(因为H,W各缩小2倍),此时再通过一个全连接层再调整通道维度为原来的两倍

class PatchMerging(nn.Module):    r""" Patch Merging Layer.    Args:        input_resolution (tuple[int]): Resolution of input feature.        dim (int): Number of input channels.        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm    """    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):        super().__init__()        self.input_resolution = input_resolution        self.dim = dim        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)        self.norm = norm_layer(4 * dim)    def forward(self, x):        """        x: B, H*W, C        """        H, W = self.input_resolution        B, L, C = x.shape        assert L == H * W, "input feature has wrong size"        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."        x = x.view(B, H, W, C)        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C        x = self.norm(x)        x = self.reduction(x)        return x

下面是一个示意图(输入张量N=1, H=W=8, C=1,不包含最后的全连接层调整)在这里插入图片描述

class BasicLayer(nn.Module):    def forward(self, x):        for blk in self.blocks:            if self.use_checkpoint:                x = checkpoint.checkpoint(blk, x)            else:                x = blk(x)        if self.downsample is not None:            x = self.downsample(x)        return x
整体结构SwinTransformer(),最后分类输出层的概述
class SwinTransformer(nn.Module):    r""" Swin Transformer        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -          https://arxiv.org/pdf/2103.14030    Args:        img_size (int | tuple(int)): Input image size. Default 224        patch_size (int | tuple(int)): Patch size. Default: 4        in_chans (int): Number of input image channels. Default: 3        num_classes (int): Number of classes for classification head. Default: 1000        embed_dim (int): Patch embedding dimension. Default: 96        depths (tuple(int)): Depth of each Swin Transformer layer.        num_heads (tuple(int)): Number of attention heads in different layers.        window_size (int): Window size. Default: 7        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None        drop_rate (float): Dropout rate. Default: 0        attn_drop_rate (float): Attention dropout rate. Default: 0        drop_path_rate (float): Stochastic depth rate. Default: 0.1        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False        patch_norm (bool): If True, add normalization after patch embedding. Default: True        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False    """    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,                 use_checkpoint=False, **kwargs):        super().__init__()        self.num_classes = num_classes        self.num_layers = len(depths)        self.embed_dim = embed_dim        self.ape = ape        self.patch_norm = patch_norm        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))        self.mlp_ratio = mlp_ratio        # split image into non-overlapping patches        self.patch_embed = PatchEmbed(            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,            norm_layer=norm_layer if self.patch_norm else None)        num_patches = self.patch_embed.num_patches        patches_resolution = self.patch_embed.patches_resolution        self.patches_resolution = patches_resolution        # absolute position embedding        if self.ape:            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))            trunc_normal_(self.absolute_pos_embed, std=.02)        self.pos_drop = nn.Dropout(p=drop_rate)        # stochastic depth        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule        # build layers        self.layers = nn.ModuleList()        for i_layer in range(self.num_layers):            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),   input_resolution=(patches_resolution[0] // (2 ** i_layer),                     patches_resolution[1] // (2 ** i_layer)),   depth=depths[i_layer],   num_heads=num_heads[i_layer],   window_size=window_size,   mlp_ratio=self.mlp_ratio,   qkv_bias=qkv_bias, qk_scale=qk_scale,   drop=drop_rate, attn_drop=attn_drop_rate,   drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],   norm_layer=norm_layer,   downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,   use_checkpoint=use_checkpoint)            self.layers.append(layer)        self.norm = norm_layer(self.num_features)        self.avgpool = nn.AdaptiveAvgPool1d(1)        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()        self.apply(self._init_weights)    def _init_weights(self, m):        if isinstance(m, nn.Linear):            trunc_normal_(m.weight, std=.02)            if isinstance(m, nn.Linear) and m.bias is not None:                nn.init.constant_(m.bias, 0)        elif isinstance(m, nn.LayerNorm):            nn.init.constant_(m.bias, 0)            nn.init.constant_(m.weight, 1.0)    @torch.jit.ignore    def no_weight_decay(self):        return {'absolute_pos_embed'}    @torch.jit.ignore    def no_weight_decay_keywords(self):        return {'relative_position_bias_table'}    def forward_features(self, x):        print(x.shape)   # [2, 3, 224, 224], batch_size = 2        x = self.patch_embed(x)        print(x.shape)        if self.ape:            x = x + self.absolute_pos_embed        x = self.pos_drop(x)        print(x.shape)        for layer in self.layers:            x = layer(x)            print(x.shape)        x = self.norm(x)  # B L C        print(x.shape)   # [2, 49, 768]        x = self.avgpool(x.transpose(1, 2))  # B C 1        print(x.shape)   # [2, 768, 1]        x = torch.flatten(x, 1)        print(x.shape)   # [2, 768]        return x             def forward(self, x):        x = self.forward_features(x)        x = self.head(x)  # [2,1000]做imagenet的1000分类        return x

(五)总结流程

整体流程如下

  • 先对特征图进行LayerNorm
  • 通过self.shift_size决定是否需要对特征图进行shift
  • 然后将特征图切成一个个窗口
  • 计算Attention,通过self.attn_mask来区分Window Attention还是Shift Window Attention
  • 将各个窗口合并回来
  • 如果之前有做shift操作,此时进行reverse shift,把之前的shift操作恢复.

Window Partition/Reverse
window partition函数是用于对张量划分窗口,指定窗口大小。将原本的张量从 N H W C, 划分成 num_windows × \times ×B, window_size, window_size, C,其中 num_windows = H × \times ×W / window_size,即窗口的个数。而window reverse函数则是对应的逆过程。这两个函数会在后面的Window Attention用到。

def window_partition(x, window_size):    B, H, W, C = x.shape    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)    return windowsdef window_reverse(windows, window_size, H, W):    B = int(windows.shape[0] / (H * W / window_size / window_size))    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)    return x
  • 做dropout和残差连接
  • 再通过一层LayerNorm+全连接层,以及dropout和残差连接

来源地址:https://blog.csdn.net/weixin_54546190/article/details/124422937

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

Swin-Transformer(原理 + 代码)详解

下载Word文档到电脑,方便收藏和打印~

下载Word文档

猜你喜欢

详解Swin Transformer核心实现,经典模型也能快速调优

Swin Transformer是一种基于Transformer结构的图像分类模型,其核心实现主要有以下几个方面:1. 分块式图片处理:Swin Transformer将输入图片分为多个非重叠的小块,每个小块称为一个局部窗格。然后通过局部窗
2023-09-20

详解B+树的原理及实现Python代码

B+树是自平衡树的高级形式,其中所有值都存在于叶级中。B+树所有叶子都处于同一水平,每个节点的子节点数量≥2。B+树与B树的区别是各节点在B树上不是相互连接,而在B+树上是相互连接的。B+树多级索引结构图B+树搜索规则1、从根节点开始
详解B+树的原理及实现Python代码
2024-01-24

MyBatisPlus代码生成器的原理及实现详解

这篇文章主要为大家详细介绍了MyBatisPlus中代码生成器的原理及实现,文中的示例代码讲解详细,对我们学习MyBatisPlus有一定帮助,需要的可以参考一下
2022-11-13

详解V8是如何执行一段JavaScript代码原理

这篇文章主要为大家介绍了详解V8是如何执行一段JavaScript代码原理详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
2023-05-15

t-SNE算法的原理和Python代码实现详解

T分布随机邻域嵌入(t-SNE),是一种用于可视化的无监督机器学习算法,使用非线性降维技术,根据数据点与特征的相似性,试图最小化高维和低维空间中这些条件概率(或相似性)之间的差异,以在低维空间中完美表示数据点。因此,t-SNE擅长在二维或三
t-SNE算法的原理和Python代码实现详解
2024-01-23

Spring静态代理和动态代理代码详解

本节要点:Java静态代理Jdk动态代理1 面向对象设计思想遇到的问题在传统OOP编程里以对象为核心,并通过对象之间的协作来形成一个完整的软件功能,由于对象可以继承,因此我们可以把具有相同功能或相同特征的属性抽象到一个层次分明的类结构体系中
2023-05-30

PHP代码转C语言:实现原理与技巧详解

PHP和C语言是两种常用的编程语言,它们在不同的领域具有各自的优势和特点。PHP是一种脚本语言,通常用于Web开发,而C语言是一种编译型语言,通常用于系统编程和嵌入式开发。在一些特定情况下,我们可能需要将PHP代码转换为C语言,以提高性能或
PHP代码转C语言:实现原理与技巧详解
2024-03-13

React Context源码实现原理详解

这篇文章主要为大家介绍了React Context源码实现原理示例详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
2022-11-13

Android刮刮卡实现原理与代码讲解

实现刮刮卡我们可以Get到哪些技能? * 圆形圆角图片的实现原理 * 双缓冲技术绘图 * Bitmap获取像素值数据 * 获取绘制文本的长宽 * 自定义View的掌握 * 获取屏幕密度 * TypeValue.applyDemension
2022-06-06

编程热搜

  • Python 学习之路 - Python
    一、安装Python34Windows在Python官网(https://www.python.org/downloads/)下载安装包并安装。Python的默认安装路径是:C:\Python34配置环境变量:【右键计算机】--》【属性】-
    Python 学习之路 - Python
  • chatgpt的中文全称是什么
    chatgpt的中文全称是生成型预训练变换模型。ChatGPT是什么ChatGPT是美国人工智能研究实验室OpenAI开发的一种全新聊天机器人模型,它能够通过学习和理解人类的语言来进行对话,还能根据聊天的上下文进行互动,并协助人类完成一系列
    chatgpt的中文全称是什么
  • C/C++中extern函数使用详解
  • C/C++可变参数的使用
    可变参数的使用方法远远不止以下几种,不过在C,C++中使用可变参数时要小心,在使用printf()等函数时传入的参数个数一定不能比前面的格式化字符串中的’%’符号个数少,否则会产生访问越界,运气不好的话还会导致程序崩溃
    C/C++可变参数的使用
  • css样式文件该放在哪里
  • php中数组下标必须是连续的吗
  • Python 3 教程
    Python 3 教程 Python 的 3.0 版本,常被称为 Python 3000,或简称 Py3k。相对于 Python 的早期版本,这是一个较大的升级。为了不带入过多的累赘,Python 3.0 在设计的时候没有考虑向下兼容。 Python
    Python 3 教程
  • Python pip包管理
    一、前言    在Python中, 安装第三方模块是通过 setuptools 这个工具完成的。 Python有两个封装了 setuptools的包管理工具: easy_install  和  pip , 目前官方推荐使用 pip。    
    Python pip包管理
  • ubuntu如何重新编译内核
  • 改善Java代码之慎用java动态编译

目录