深度学习之图像分类(三十)Hire-MLP网络详解
一晃都学习了三十个网络了,时间过得真快。本次学习华为提出的 Hire-MLP,依然是通过旋转特征图,将不同位置的特征对齐到同一个通道上从而实现 MLP-based Model 中的局部感受野依赖。

1. 前言
本此学习华为诺亚&北大&悉尼大学联合提出的 HireMLP。关于 MLP-Mixer 的两大主要问题我们已经在前面的学习中阐述了很多遍:对于 Token 进行全局感受野卷积,容易过拟合且对图像尺寸敏感,无法作为 Backbone 用于下游任务。Hire-MLP 也提出:如何在 MLP-based 模型中结合局部感受野和全局感受野,同时对图像输入分辨率不敏感是值得探索的地方。为了对输入图像尺寸不敏感,要不使用固定尺寸小卷积核(例如 $3 \times 3$ 的DW Conv);或者就是移动特征图,使得不同 token 对齐到同一个通道上,然后使用 $1 \times 1$ 卷积来实现局部感受野,即不同 token 之间的信息融合。Hire-MLP 也是一样的思路,通过移动特征图(本工作称之为区域重排 Region Rearrangement),从而实现局部信息的整合。本工作的原始论文为 Hire-MLP: Vision MLP via Hierarchical Rearrangement。2021.8.30 挂上 arXiv,代码并未开源。最终在 ImageNet 上达到了 83.4% 的 Top-1 精度,这与 SOTA 的 Swin Transformer 等相当。虽然 Hire-MLP 结构上对下游任务友好,但是本文并没有将其应用于下游任务,从而不清楚其真实性能。

从结果可见:
- Hire-MLP-S 取得了 81.8% 的精度,而计算量仅为 4.2G Flops,优于其他 MLP 方案。相比 AS-MLP、CycleMLP,所提 Hire-MLP 性能更佳。
- Hire-MLP-B 与 Hire-MLP-L 分别取得了 83.1% 与 83.4% 的精度,而计算量分别为 8.1G 与 13.5G。
- 相比 DeiT、Swin 以及 PVT,所提方案具有更快的推理速度;
- 相比 RegNetY,所提方案具有更高的精度,同时具有相似的模型大小和复杂度。
- 相比AS-MLP,Hire-MLP好像并没有什么优势,性能相当,速度反而AS-MLP更快 。(AS-MLP 源代码的 Shift 操作是用 cupy 库自己编程实现的,可能会有影响)
2. Hire-MLP
这次讲解我先讲局部 Hire-MLP Block 结构,再描述网络的整体结构。
2.1 Hire-MLP Block
单个 Hire-MLP Block 依然是分为 Token-mixing MLP 和 Channel mixing MLP,其中作者主要的贡献点在于替换 MLP-mixer 的 Token-mixing MLP 为 Hire-Module。所以整个 Hire-MLP Block 可以描述为: \(\begin{aligned} &Y=\text { Hire-Module }(\operatorname{LN}(X))+X \\ &Z=\text { Channel-MLP }(\operatorname{LN}(Y))+Y \end{aligned}\) Channel MLP 就是最普通的两层全连接,中间使用 GELU 激活函数,第一层全连接结点个数一般为输入结点个数的 2,3 或者 4 倍。或者可以直接称之为通道方向的 $1 \times 1$ 卷积。
Hire-Module 的内部结构如下图所示,其中包含三条支路,分别是对于 H 方向的重排,W 方向的重排,以及通道方向的映射:

重排分为两类:Cross-Region 以及 Inner-Region。
2.1.1 Inner-Region
Height-direction 的 Inner-Region 其实就是对特征图进行 H 维度的分组,然后将其切分开之后堆叠到通道维度。这里 $H$ 为原始特征图的高度,$h$ 为每一小组内特征图的高度。即对于一个 $H \times W \times C$ 的特征图可以分成 $g = H / h$ 组,每组的特征图大小为 $h \times W \times C$,然后对特征图进行重排得到 $g \times W \times (hC)$,此后做 $hC->hC$ 的映射。映射结束后,再还原到 $H \times W \times C$ 的特征图即可。可视化流程如下所示:

具体的映射其实是依赖两个全连接和一个 GELU 激活函数构成的,论文中作者将第一个全连接层的阶段设定为 $C / 2$ 用于降维和减少计算量。其实也可以只依赖两个 $1 \times 1$ 卷积加以实现。
Inner-Region 重排其实可以只需要依赖 einops 库的 Rearrange 即可方便完成:
import torch
from torch import nn
from einops.layers.torch import Rearrange, Reduce
class InnerRegionW(nn.Module):
def __init__(self, w):
super().__init__()
self.w = w
self.region = nn.Sequential(
Rearrange('b c h (w group) -> b (c w) h group', w = self.w) # 重排
)
def forward(self, x):
return self.region(x)
class InnerRegionRestoreW(nn.Module):
def __init__(self, w):
super().__init__()
self.w = w
self.region = nn.Sequential(
Rearrange('b (c w) h group -> b c h (w group)', w = self.w) # 恢复
)
def forward(self, x):
return self.region(x)
model1 = InnerRegionW(w = 2)
model2 = InnerRegionRestoreW(w = 2)
images = torch.randn(1, 1, 4, 4)
print(images)
print("==============================")
with torch.no_grad():
output1 = model1(images)
output2 = model2(output1)
print("=========== output1 ==============")
print(output1)
print("=========== output2 ==============")
print(output2)
2.1.2 Cross-Region
单个方向的 Inner-Region 产生的是特定的一维线性感受野,即将 $h \times w$ 拆分成 $1 \times w$ 和 $h \times 1$ 的形式,由于输入是 $224 \times 224$ 的方图,所以一般配置中 $h = w$。并且 Inner-Region 的空间位置相对固定,即和 Swin 分窗一样,窗口固定了,则窗口处理始终对应的都是固定的区域。为了使得不同窗内的 Token 有交互,需要移动窗,或者说需要旋转特征图。这就引发了 Cross-Region。请联系 Swin 一起思考,则容易理解多了。

Height-direction 的 Inner-Region 其实就是对特征图进行 H 维度的转动,转动的步长设定为 $s$ ,通常 $s$ 取 1 或者 2。Cross-Region 操作被加在 Inner-Region 前后。实际上并不是每个 Inner-Region 前后都添加 Cross-Region,因为每一个都添加一样的 Cross-Region 其实等于没有添加。所以原文中作者是隔一个添加一个,隔一个添加一个。To get a global receptive field, the cross-region rearrangement operations are inserted before the inner-region rearrangement operation every two blocks. 但是作者将这个称为全局感受野我其实是不太认同的,毕竟 $s$ 太小了,只不过边界地区挨到一起来罢了,这其实也是局部感受野而已…
Cross-Region 重排其实可以只需要依赖 torch.roll 即可方便完成:
import torch
from torch import nn
from einops.layers.torch import Rearrange, Reduce
class CrossRegion(nn.Module):
def __init__(self, step = 1, dim = 1):
super().__init__()
self.step = step
self.dim = dim
def forward(self, x):
return torch.roll(x, self.step, self.dim)
model1 = CrossRegion(step = 2, dim = 3)
model2 = CrossRegion(step = -2, dim = 3)
model3 = CrossRegion(step = 1, dim = 2)
model4 = CrossRegion(step = -1, dim = 2)
images = torch.randn(1, 1, 5, 5)
print(images)
print("==============================")
with torch.no_grad():
output1 = model1(images)
output2 = model2(images)
output3 = model3(images)
output4 = model4(images)
print("=========== output1 ==============")
print(output1)
print("=========== output2 ==============")
print(output2)
print("=========== output3 ==============")
print(output3)
print("=========== output4 ==============")
print(output4)
2.1.3 特征融合
在 Hire-Module 中有三个并行的支路,第三条之路只需要通过一个 $C -> C$ 的单层全连接层映射。最后三条支路的特征图直接加和即可获得最终的融合特征图。其实用一下 Split-Attention 或者如 sMLPNet 说的通道拼接后经过 $1 \times 1$ 卷积降维可能还会涨点,不过计算量会有所增加。所以最终不难分析得到 Hire-Module 是一个十字形的感受野。整个 Hire-MLP Block 的推理逻辑如下所示:
def forward(self, x):
x_h = self.inner_regionH(self.cross_regionH(x)) # 重排
x_w = self.inner_regionW(self.cross_regionW(x)) # 重排
x_h = self.proj_h(x_h) # 映射
x_w = self.proj_w(x_w) # 映射
x_c = self.proj_c(x) # 映射
x_h = self.cross_region_restoreH(self.inner_region_restoreH(x_h)) # 恢复
x_w = self.cross_region_restoreW(self.inner_region_restoreW(x_w)) # 恢复
out = x_c + x_h + x_w # 特征融合
return out
2.1.4 HireMLP 和 ViP,AS-MLP 的区别?
HireMLP 和 ViP 有区别吗?在我看来没有本质区别,ViP 也是分组重排操作。ViP 更复杂地使用了 Split-Attention 和残差;HireMLP 则是将中间的映射改为两层全连接。HireMLP 作者说 Inspired by the shortcut in ResNet and ViP, an extra branch without spatial communication is alse added ...,但是看起来似乎不仅仅是像他们一样添加一个 extra branch without spatial communication 这么简单,而是整个模块其实思想都是惊人的一致。至于说指标报告方面 HireMLP 高了那么 0.2%,会不会是因为两层全连接导致的?尚不可而知。不过相比 ViP 的重排,能肉眼可见的改进就是这个地方了。

HireMLP 和 AS-MLP 有区别吗?在我看来没有本质区别,AS-MLP 也是十字形感受野,不过是位移的实现方式上有所不同。整体而言都是类似的。相比AS-MLP,Hire-MLP好像并没有什么优势,性能相当,速度反而AS-MLP更快。所以说:在特征图移动上玩花似乎已经走到头了,玩不出什么花来了。

2.2 整体网络结构
Hire-MLP 的 Patch Embedding 很有特色,使用卷积核大小为 $7 \times 7$ ,步长为 4 的卷积。相比而言 Swin 使用卷积核大小为 $4 \times 4$,步长为 4 的卷积。在近期的我自己的小实验中也发现:Patch Embedding 时具有重叠会更好,这样可以避免边界效应并在小数据集上提升性能。Hire-MLP 中间采用多阶段金字塔模型,总共分为 4 个阶段,每个阶段交替重复使用 Hire-MLP Block。下采样使用卷积核大小为 $3 \times 3$,步长为 2 的卷积,这样做也有重叠。最后经过全局池化后连接一个全连接分类器即可,其网络结构图如下所示:

作者一共提出来了四种配置:

大家注意这四种配置,其中 Hire-MLP Block 中 $h,w,s$ 的取值一样的。$h,w$ 指 Inner-Region 中分小组内 H 或者 W 维度的大小,$s$ 指 Cross-Region 处旋转的步长。注意到 $H$ 不一定整除 $h$,让我们以 Base 为例,作者在训练 ImageNet 的时候以 $224 \times 224$ 的图片作为输入,经过第一个 Patch Embedding 之后,假设作者 kernel size = 7, stride = 4,假设使用了 padding = 3,则此后特征图大小为 $56 \times 56 \times 64$。此后在 Stage 1 中 $h = w = 4$,即在 Inner-Region 中可以分为 $56 / 4 = 14$ 组,特征图可以重排为 $14 \times 56 \times 256$,然后做 $256 -> 32 -> 256$ 的映射 ($hC -> C/2 -> hC$),看上去没什么问题。经过下采样的 Patch Embedding,kernel size = 3,stride = 2,假设使用了 padding = 1,则此后特征图大小为 $28 \times 28 \times 128$。这里 28 根本不是 3 的倍数,所以需要进行 padding 操作。作者通过对比发现 Circular padding 是最好的,但是 padding 类别其实对结果的影响并不大。

3. 消融实验
作者一共进行了多组消融实验:
- Inner-Region 中 $h, w$ 的影响:作者发现浅层使用大一点点的 $h,w$,深层使用小一点的 $h, w$ 效果更好,最终使用 $h = w = [4,3,3,2]$

- Cross-Region 中 $s$ 的影响:作者发现浅层使用大一点点的 $s$,深层使用小一点的 $s$ 效果更好,最终使用 $s = [2,2,1,1]$,所以其实 Cross-Region 并没有实现大感受野信息交换…

- 不同 padding 策略的影响:最终发现padding 类别其实对结果的影响并不大,Circular padding 效果略好。

- Hire Module 中不同模块的作用效果:可见去除 Inner-Region 或者 Cross-Region 性能都会下降,其中 Inner-Region 有关局部感受野,去除了只剩下 Cross-Region 其实等价于朴素的只有通道方向的映射,所以性能影响更大。但是只有通道方向的映射,真的作者能做到 79.81% 吗?我持保留意见!

- Cross-Region 的 shift 方式:作者比较了直接旋转的 Shifted manner 与 ShuffleNet 那样的分组的方式,最终发现 Shifted manner 会更好,这也符合直觉,因为图像中随机膈几列选一列的做法其实局部性并不保证了。可见局部性还是比较重要的。但是真的性能只差一丝丝吗?

4. 总结与反思
Hire-MLP 其实依然是使用特征图移动,使得不同空间位置的 token 对齐到统一通道,然后使用通道方向的 $1 \times 1$ 卷积实现的局部依赖性引入和对图像分辨率不敏感,这种思想在过去的诸多工作例如 AS-MLP 以及 ViP,S2MLPv2 等中均可看到,所以并不是什么新鲜的贡献。此外作者说 Hire-MLP 对下游任务友好,但是并没有进行实验让人有点点失望。本工作暂未开源,所报告性能我持怀疑态度,因为作者消融实验中去除 Inner-Region 只保留 Cross-Region 其实可看作仅有 Channel 方向 $1 \times 1$ 卷积的网络,也能达到 79.81%?比 MLP-Mixer 的 76+% 还高?
5. 代码
我自己实现的非官方 pytorch 代码见 此处,欢迎与大家自己复现的进行交流。
import torch
from torch import nn
from functools import partial
from einops.layers.torch import Rearrange, Reduce
from .utils import pair
class PreNormResidual(nn.Module):
def __init__(self, dim, fn, norm = nn.LayerNorm):
super().__init__()
self.fn = fn
self.norm = norm(dim)
def forward(self, x):
return self.fn(self.norm(x)) + x
class PatchEmbedding(nn.Module):
def __init__(self, dim_in, dim_out, kernel_size, stride, padding, norm_layer=False):
super().__init__()
self.reduction = nn.Sequential(
nn.Conv2d(dim_in, dim_out, kernel_size=kernel_size, stride=stride, padding = padding),
nn.Identity() if (not norm_layer) else nn.Sequential(
Rearrange('b c h w -> b h w c'),
nn.LayerNorm(dim_out),
Rearrange('b h w c -> b c h w'),
)
)
def forward(self, x):
return self.reduction(x)
class FeedForward(nn.Module):
def __init__(self, dim_in, hidden_dim, dim_out):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(dim_in, hidden_dim, kernel_size = 1),
nn.GELU(),
nn.Conv2d(hidden_dim, dim_out, kernel_size = 1),
)
def forward(self, x):
return self.net(x)
class CrossRegion(nn.Module):
def __init__(self, step = 1, dim = 1):
super().__init__()
self.step = step
self.dim = dim
def forward(self, x):
return torch.roll(x, self.step, self.dim)
class InnerRegionW(nn.Module):
def __init__(self, w):
super().__init__()
self.w = w
self.region = nn.Sequential(
Rearrange('b c h (w group) -> b (c w) h group', w = self.w)
)
def forward(self, x):
return self.region(x)
class InnerRegionH(nn.Module):
def __init__(self, h):
super().__init__()
self.h = h
self.region = nn.Sequential(
Rearrange('b c (h group) w -> b (c h) group w', h = self.h)
)
def forward(self, x):
return self.region(x)
class InnerRegionRestoreW(nn.Module):
def __init__(self, w):
super().__init__()
self.w = w
self.region = nn.Sequential(
Rearrange('b (c w) h group -> b c h (w group)', w = self.w)
)
def forward(self, x):
return self.region(x)
class InnerRegionRestoreH(nn.Module):
def __init__(self, h):
super().__init__()
self.h = h
self.region = nn.Sequential(
Rearrange('b (c h) group w -> b c (h group) w', h = self.h)
)
def forward(self, x):
return self.region(x)
class HireMLPBlock(nn.Module):
def __init__(self, h, w, d_model, cross_region_step = 1, cross_region_id = 0, cross_region_interval = 2, padding_type = 'circular'):
super().__init__()
assert (padding_type in ['constant', 'reflect', 'replicate', 'circular'])
self.padding_type = padding_type
self.w = w
self.h = h
# cross region every cross_region_interval HireMLPBlock
self.cross_region = (cross_region_id % cross_region_interval == 0)
if self.cross_region:
self.cross_regionW = CrossRegion(step = cross_region_step, dim = 3)
self.cross_regionH = CrossRegion(step = cross_region_step, dim = 2)
self.cross_region_restoreW = CrossRegion(step = -cross_region_step, dim = 3)
self.cross_region_restoreH = CrossRegion(step = -cross_region_step, dim = 3)
else:
self.cross_regionW = nn.Identity()
self.cross_regionH = nn.Identity()
self.cross_region_restoreW = nn.Identity()
self.cross_region_restoreH = nn.Identity()
self.inner_regionW = InnerRegionW(w)
self.inner_regionH = InnerRegionH(h)
self.inner_region_restoreW = InnerRegionRestoreW(w)
self.inner_region_restoreH = InnerRegionRestoreH(h)
self.proj_h = FeedForward(h * d_model, d_model // 2, h * d_model)
self.proj_w = FeedForward(w * d_model, d_model // 2, w * d_model)
self.proj_c = nn.Conv2d(d_model, d_model, kernel_size = 1)
def forward(self, x):
x = x.permute(0, 3, 1, 2)
B, C, H, W = x.shape
padding_num_w = W % self.w
padding_num_h = H % self.h
x = nn.functional.pad(x, (0, self.w - padding_num_w, 0, self.h - padding_num_h), self.padding_type)
x_h = self.inner_regionH(self.cross_regionH(x))
x_w = self.inner_regionW(self.cross_regionW(x))
x_h = self.proj_h(x_h)
x_w = self.proj_w(x_w)
x_c = self.proj_c(x)
x_h = self.cross_region_restoreH(self.inner_region_restoreH(x_h))
x_w = self.cross_region_restoreW(self.inner_region_restoreW(x_w))
out = x_c + x_h + x_w
out = out[:,:,0:H,0:W]
out = out.permute(0, 2, 3, 1)
return out
class HireMLPStage(nn.Module):
def __init__(self, h, w, d_model_in, d_model_out, depth, cross_region_step, cross_region_interval, expansion_factor = 2, dropout = 0., pooling = False, padding_type = 'circular'):
super().__init__()
self.pooling = pooling
self.patch_merge = nn.Sequential(
Rearrange('b h w c -> b c h w'),
PatchEmbedding(d_model_in, d_model_out, kernel_size = 3, stride = 2, padding=1, norm_layer=False),
Rearrange('b c h w -> b h w c'),
)
self.model = nn.Sequential(
*[nn.Sequential(
PreNormResidual(d_model_in, nn.Sequential(
HireMLPBlock(
h, w, d_model_in, cross_region_step = cross_region_step, cross_region_id = i_depth + 1, cross_region_interval = cross_region_interval, padding_type = padding_type
)
), norm = nn.LayerNorm),
PreNormResidual(d_model_in, nn.Sequential(
nn.Linear(d_model_in, d_model_in * expansion_factor),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(d_model_in * expansion_factor, d_model_in),
nn.Dropout(dropout),
), norm = nn.LayerNorm),
) for i_depth in range(depth)]
)
def forward(self, x):
x = self.model(x)
if self.pooling:
x = self.patch_merge(x)
return x
class HireMLP(nn.Module):
def __init__(
self,
patch_size=4,
in_channels=3,
num_classes=1000,
d_model=[64, 128, 320, 512],
h = [4,3,3,2],
w = [4,3,3,2],
cross_region_step = [2,2,1,1],
cross_region_interval = 2,
depth=[4,6,24,3],
expansion_factor = 2,
patcher_norm = False,
padding_type = 'circular',
):
patch_size = pair(patch_size)
super().__init__()
self.patcher = PatchEmbedding(dim_in = in_channels, dim_out = d_model[0], kernel_size = 7, stride = patch_size, padding = 3, norm_layer=patcher_norm)
self.layers = nn.ModuleList()
for i_layer in range(len(depth)):
i_depth = depth[i_layer]
i_stage = HireMLPStage(h[i_layer], w[i_layer], d_model[i_layer], d_model_out = d_model[i_layer + 1] if (i_layer + 1 < len(depth)) else d_model[-1],
depth = i_depth, cross_region_step = cross_region_step[i_layer], cross_region_interval = cross_region_interval,
expansion_factor = expansion_factor, pooling = ((i_layer + 1) < len(depth)), padding_type = padding_type)
self.layers.append(i_stage)
self.mlp_head = nn.Sequential(
nn.LayerNorm(d_model[-1]),
Reduce('b h w c -> b c', 'mean'),
nn.Linear(d_model[-1], num_classes)
)
def forward(self, x):
embedding = self.patcher(x)
embedding = embedding.permute(0, 2, 3, 1)
for layer in self.layers:
embedding = layer(embedding)
out = self.mlp_head(embedding)
return out