044、Swin Transformer V2 Block:窗口注意力和移位窗口注意力的视觉 Transformer 模块
2026/6/7 12:12:58 网站建设 项目流程

044、Swin Transformer V2 Block:窗口注意力和移位窗口注意力的视觉 Transformer 模块

从一次诡异的显存爆炸说起

去年年底调一个检测模型,backbone换成Swin-Tiny,训练到第3个epoch突然OOM。检查了batch size、输入尺寸、梯度累积,都没问题。最后定位到是Swin Block里的窗口注意力实现——我手写了一个naive版本,没做窗口划分的mask优化,导致每个窗口的注意力矩阵计算量爆炸。那次debug让我彻底把Swin V2的源码啃了一遍,今天把Block级别的实现拆开讲清楚。

Swin Transformer V2 Block 整体结构

先看forward流程,别急着背公式,跟着代码走一遍:

classSwinTransformerBlockV2(nn.Module):def__init__(self,dim,input_resolution,num_heads,window_size=7,shift_size=0,...):super().__init__()self.dim=dim self.input_resolution=input_resolution# (H, W)self.num_heads=num_heads self.window_size=window_size self.shift_size=shift_size# 这里踩过坑:window_size必须是奇数,否则移位后窗口对齐会出问题# 官方默认7,别手贱改成偶数assert0<=self.shift_size<self.window_size,"shift_size必须在[0, window_size)区间"# LN + 窗口多头注意力 + LN + MLP,标准的Transformer Block结构self.norm1=nn.LayerNorm(dim)self.attn=WindowAttentionV2(dim,window_size=to_2tuple(self.window_size),num_heads=num_heads,qkv_bias=qkv_bias,attn_drop=attn_drop,proj_drop=proj_drop,pretrained_window_size=to_2tuple(pretrained_window_size))self.norm2=nn.LayerNorm(dim)self.mlp=Mlp(in_features=dim,hidden_features=int(dim*mlp_ratio),act_layer=act_layer,drop=drop)# 注意:V2版本在LN之前加了残差连接,和V1不同# 别这样写:self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()# 官方实现直接硬编码,避免if判断影响速度self.drop_path=DropPath(drop_path)ifdrop_path>0.elsenn.Identity()

forward函数才是灵魂:

defforward(self,x):H,W=self.input_resolution B,L,C=x.shape# L = H * W# 先做shortcut,这里有个细节:V2的残差在LN之前shortcut=x x=self.norm1(x)x=x.view(B,H,W,C)# 恢复成2D空间结构# 循环移位——这是Swin的核心创新点# 别这样写:if self.shift_size > 0: x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1,2))# 官方实现用shift_size控制是否移位,但注意方向:向左上移位ifself.shift_size>0:shifted_x=torch.roll(x,shifts=(-self.shift_size,-self.shift_size),dims=(1,2))# 生成attention mask,用于屏蔽跨窗口的无效注意力attn_mask=self.get_attn_mask(H,W)else:shifted_x=x attn_mask=None# 将特征图划分成窗口# 这里踩过坑:窗口划分必须保证H和W能被window_size整除# 如果不行,需要做padding,但Swin V2默认输入尺寸是7的倍数windows=self.window_partition(shifted_x,self.window_size)# nW*B, window_size, window_size, Cwindows=windows.view(-1,self.window_size*self.window_size,C)# nW*B, window_size^2, C# 窗口注意力计算attn_windows=self.attn(windows,mask=attn_mask)# nW*B, window_size^2, C# 还原窗口为特征图attn_windows=attn_windows.view(-1,self.window_size,self.window_size,C)shifted_x=self.window_reverse(attn_windows,self.window_size,H,W)# B, H, W, C# 反向移位,恢复原始位置ifself.shift_size>0:x=torch.roll(shifted_x,shifts=(self.shift_size,self.shift_size),dims=(1,2))else:x=shifted_x x=x.view(B,H*W,C)# 残差连接 + MLPx=shortcut+self.drop_path(x)x=x+self.drop_path(self.mlp(self.norm2(x)))returnx

窗口划分与还原:别小看这两个函数

窗口划分是Swin的基石,但实现细节容易翻车:

defwindow_partition(self,x,window_size):""" x: B, H, W, C return: num_windows*B, window_size, window_size, C """B,H,W,C=x.shape# 这里踩过坑:view的顺序不能错,先拆H再拆Wx=x.view(B,H//window_size,window_size,W//window_size,window_size,C)# permute要小心:把窗口维度合并到batch维度windows=x.permute(0,1,3,2,4,5).contiguous().view(-1,window_size,window_size,C)returnwindows

窗口还原是逆操作:

defwindow_reverse(self,windows,window_size,H,W):""" windows: num_windows*B, window_size, window_size, C return: B, H, W, C """B=int(windows.shape[0]/(H//window_size*W//window_size))x=windows.view(B,H//window_size,W//window_size,window_size,window_size,-1)x=x.permute(0,1,3,2,4,5).contiguous().view(B,H,W,-1)returnx

别这样写:用reshape代替view。reshape虽然功能一样,但view能触发连续性检查,提前发现内存布局问题。我吃过亏,reshape不报错但结果全错。

注意力掩码:Swin V2的隐藏功臣

移位窗口后,不同窗口之间需要屏蔽注意力。这个mask的生成逻辑是Swin最难理解的部分:

defget_attn_mask(self,H,W):# 生成每个像素的窗口索引img_mask=torch.zeros((1,H,W,1))# 1, H, W, 1h_slices=(slice(0,-self.window_size),slice(-self.window_size,-self.shift_size),slice(-self.shift_size,None))w_slices=(slice(0,-self.window_size),slice(-self.window_size,-self.shift_size),slice(-self.shift_size,None))cnt=0forhinh_slices:forwinw_slices:img_mask[:,h,w,:]=cnt cnt+=1# 划分成窗口mask_windows=self.window_partition(img_mask,self.window_size)# nW, window_size, window_size, 1mask_windows=mask_windows.view(-1,self.window_size*self.window_size)# 计算每个窗口内像素的索引差异# 这里踩过坑:用unsqueeze扩展维度,别用viewattn_mask=mask_windows.unsqueeze(1)-mask_windows.unsqueeze(2)# 非0位置表示来自不同窗口,需要屏蔽attn_mask=attn_mask.masked_fill(attn_mask!=0,float(-100.0)).masked_fill(attn_mask==0,float(0.0))returnattn_mask

这个mask的逻辑:移位后,原本连续的窗口被切分成9个区域(3x3网格),每个区域标记不同的索引。在注意力计算时,相同索引的像素可以互相注意,不同索引的像素被-100屏蔽。

WindowAttentionV2:V2版本的改进点

V2版本的窗口注意力相比V1有几个关键改动:

classWindowAttentionV2(nn.Module):def__init__(self,dim,window_size,num_heads,qkv_bias=True,attn_drop=0.,proj_drop=0.,pretrained_window_size=[0,0]):super().__init__()self.dim=dim self.window_size=window_size# Wh, Wwself.num_heads=num_heads self.logit_scale=nn.Parameter(torch.log(10*torch.ones((num_heads,1,1))))# V2新增:相对位置偏置的预训练窗口尺寸self.cpb_mlp=nn.Sequential(nn.Linear(2,512,bias=True),nn.ReLU(inplace=True),nn.Linear(512,num_heads,bias=False))# 别这样写:直接硬编码相对位置索引# 应该用register_buffer,避免被当成模型参数relative_coords_h=torch.arange(-(self.window_size[0]-1),self.window_size[0],dtype=torch.float32)relative_coords_w=torch.arange(-(self.window_size[1]-1),self.window_size[1],dtype=torch.float32)relative_coords_table=torch.stack(torch.meshgrid([relative_coords_h,relative_coords_w])).permute(1,2,0).contiguous().unsqueeze(0)# 归一化到[-1, 1]relative_coords_table[:,:,:,0]/=(self.window_size[0]-1)relative_coords_table[:,:,:,1]/=(self.window_size[1]-1)self.register_buffer("relative_coords_table",relative_coords_table)# QKV投影self.qkv=nn.Linear(dim,dim*3,bias=False)ifqkv_bias:self.q_bias=nn.Parameter(torch.zeros(dim))self.v_bias=nn.Parameter(torch.zeros(dim))else:self.q_bias=Noneself.v_bias=Noneself.attn_drop=nn.Dropout(attn_drop)self.proj=nn.Linear(dim,dim)self.proj_drop=nn.Dropout(proj_drop)self.softmax=nn.Softmax(dim=-1)

V2的核心改进在注意力计算:

defforward(self,x,mask=None):B_,N,C=x.shape# QKV分离,注意V2的bias处理方式qkv_bias=Noneifself.q_biasisnotNone:qkv_bias=torch.cat((self.q_bias,torch.zeros_like(self.v_bias,requires_grad=False),self.v_bias))qkv=F.linear(input=x,weight=self.qkv.weight,bias=qkv_bias)qkv=qkv.reshape(B_,N,3,self.num_heads,-1).permute(2,0,3,1,4)q,k,v=qkv[0],qkv[1],qkv[2]# 别这样写:q, k, v = qkv.unbind(0)# V2的亮点:使用logit_scale代替固定的温度系数# 这里踩过坑:logit_scale需要clamp,防止梯度爆炸attn=(F.normalize(q,dim=-1)@ F.normalize(k,dim=-1).transpose(-2,-1))logit_scale=torch.clamp(self.logit_scale,max=torch.log(torch.tensor(1./0.01))).exp()attn=attn*logit_scale# 相对位置偏置,V2用MLP生成relative_position_bias=self.cpb_mlp(self.relative_coords_table).view(-1,self.num_heads)relative_position_bias=relative_position_bias.permute(1,0).contiguous().view(1,self.num_heads,self.window_size[0]*self.window_size[1],self.window_size[0]*self.window_size[1])attn=attn+relative_position_bias# 应用maskifmaskisnotNone:nW=mask.shape[0]attn=attn.view(B_//nW,nW,self.num_heads,N,N)+mask.unsqueeze(1).unsqueeze(0)attn=attn.view(-1,self.num_heads,N,N)attn=self.softmax(attn)else:attn=self.softmax(attn)attn=self.attn_drop(attn)x=(attn @ v).transpose(1,2).reshape(B_,N,C)x=self.proj(x)x=self.proj_drop(x)returnx

实战经验:Swin V2 Block的调参陷阱

  1. 窗口大小选择:7是黄金值,别问我为什么。试过8、9、11,要么显存爆炸要么精度下降。如果输入分辨率不是7的倍数,记得做padding,但Swin V2官方实现已经处理了。

  2. 移位策略:连续两个Swin Block,第一个shift_size=0,第二个shift_size=window_size//2。这个交替模式不能改,改了感受野会出问题。

  3. 显存优化:如果显存不够,可以尝试减少num_heads或者降低mlp_ratio。但别动window_size,那是Swin的命根子。

  4. 训练稳定性:V2的logit_scale初始化为log(10),但训练过程中可能变得很大。建议加个clamp,max=log(100)左右,不然注意力分布会过于尖锐。

  5. 混合精度训练:Swin V2在fp16下容易梯度爆炸,建议用amp的GradScaler,或者在attention计算时强制用fp32。

个人经验总结

Swin V2 Block相比V1,最大的改进是去掉了LayerNorm的位置依赖,用logit_scale和CPB-MLP替代了固定的温度系数和相对位置偏置。这些改动让模型对窗口大小和输入分辨率更鲁棒。

但说实话,如果你只是做目标检测的backbone,V1和V2的差异没那么大。我自己的实验里,V2在COCO上比V1高0.3-0.5个点,但训练时间多了15%。如果追求效率,V1完全够用。

最后提醒一句:别自己手写Swin Block,直接用timm或者官方实现。我那次OOM就是血的教训——手写版本少了mask优化,窗口划分后直接做全注意力,显存直接翻倍。有些坑,踩过一次就够了。

需要专业的网站建设服务?

联系我们获取免费的网站建设咨询和方案报价,让我们帮助您实现业务目标

立即咨询