深度学习之图像分类(三十)-- Hire-MLP网络详解

木卯 于 2021-10-20 发布

深度学习之图像分类(三十)Hire-MLP网络详解

一晃都学习了三十个网络了,时间过得真快。本次学习华为提出的 Hire-MLP,依然是通过旋转特征图,将不同位置的特征对齐到同一个通道上从而实现 MLP-based Model 中的局部感受野依赖。

img0

1. 前言

本此学习华为诺亚&北大&悉尼大学联合提出的 HireMLP。关于 MLP-Mixer 的两大主要问题我们已经在前面的学习中阐述了很多遍:对于 Token 进行全局感受野卷积,容易过拟合且对图像尺寸敏感,无法作为 Backbone 用于下游任务。Hire-MLP 也提出:如何在 MLP-based 模型中结合局部感受野和全局感受野,同时对图像输入分辨率不敏感是值得探索的地方。为了对输入图像尺寸不敏感,要不使用固定尺寸小卷积核(例如 $3 \times 3$ 的DW Conv);或者就是移动特征图,使得不同 token 对齐到同一个通道上,然后使用 $1 \times 1$ 卷积来实现局部感受野,即不同 token 之间的信息融合。Hire-MLP 也是一样的思路,通过移动特征图(本工作称之为区域重排 Region Rearrangement),从而实现局部信息的整合。本工作的原始论文为 Hire-MLP: Vision MLP via Hierarchical Rearrangement。2021.8.30 挂上 arXiv,代码并未开源。最终在 ImageNet 上达到了 83.4% 的 Top-1 精度,这与 SOTA 的 Swin Transformer 等相当。虽然 Hire-MLP 结构上对下游任务友好,但是本文并没有将其应用于下游任务,从而不清楚其真实性能。

img1

从结果可见:

2. Hire-MLP

这次讲解我先讲局部 Hire-MLP Block 结构,再描述网络的整体结构。

2.1 Hire-MLP Block

单个 Hire-MLP Block 依然是分为 Token-mixing MLP 和 Channel mixing MLP,其中作者主要的贡献点在于替换 MLP-mixer 的 Token-mixing MLP 为 Hire-Module。所以整个 Hire-MLP Block 可以描述为: \(\begin{aligned} &Y=\text { Hire-Module }(\operatorname{LN}(X))+X \\ &Z=\text { Channel-MLP }(\operatorname{LN}(Y))+Y \end{aligned}\) Channel MLP 就是最普通的两层全连接,中间使用 GELU 激活函数,第一层全连接结点个数一般为输入结点个数的 2,3 或者 4 倍。或者可以直接称之为通道方向的 $1 \times 1$ 卷积。

Hire-Module 的内部结构如下图所示,其中包含三条支路,分别是对于 H 方向的重排,W 方向的重排,以及通道方向的映射:

img4

重排分为两类:Cross-Region 以及 Inner-Region

2.1.1 Inner-Region

Height-direction 的 Inner-Region 其实就是对特征图进行 H 维度的分组,然后将其切分开之后堆叠到通道维度。这里 $H$ 为原始特征图的高度,$h$ 为每一小组内特征图的高度。即对于一个 $H \times W \times C$ 的特征图可以分成 $g = H / h$ 组,每组的特征图大小为 $h \times W \times C$,然后对特征图进行重排得到 $g \times W \times (hC)$,此后做 $hC->hC$ 的映射。映射结束后,再还原到 $H \times W \times C$ 的特征图即可。可视化流程如下所示:

img5

具体的映射其实是依赖两个全连接和一个 GELU 激活函数构成的,论文中作者将第一个全连接层的阶段设定为 $C / 2$ 用于降维和减少计算量。其实也可以只依赖两个 $1 \times 1$ 卷积加以实现。

Inner-Region 重排其实可以只需要依赖 einops 库的 Rearrange 即可方便完成:

import torch
from torch import nn
from einops.layers.torch import Rearrange, Reduce

class InnerRegionW(nn.Module):
    def __init__(self, w):
        super().__init__()
        self.w = w
        self.region = nn.Sequential(
            Rearrange('b c h (w group) -> b (c w) h group', w = self.w)		# 重排
        )

    def forward(self, x):
        return self.region(x)

class InnerRegionRestoreW(nn.Module):
    def __init__(self, w):
        super().__init__()
        self.w = w
        self.region = nn.Sequential(
            Rearrange('b (c w) h group -> b c h (w group)', w = self.w)		# 恢复
        )

    def forward(self, x):
        return self.region(x)

model1 = InnerRegionW(w = 2)
model2 = InnerRegionRestoreW(w = 2)
images = torch.randn(1, 1, 4, 4)

print(images)
print("==============================")

with torch.no_grad():
    output1 = model1(images)
    output2 = model2(output1)
print("=========== output1 ==============")
print(output1)
print("=========== output2 ==============")
print(output2)

2.1.2 Cross-Region

单个方向的 Inner-Region 产生的是特定的一维线性感受野,即将 $h \times w$ 拆分成 $1 \times w$ 和 $h \times 1$ 的形式,由于输入是 $224 \times 224$ 的方图,所以一般配置中 $h = w$。并且 Inner-Region 的空间位置相对固定,即和 Swin 分窗一样,窗口固定了,则窗口处理始终对应的都是固定的区域。为了使得不同窗内的 Token 有交互,需要移动窗,或者说需要旋转特征图。这就引发了 Cross-Region。请联系 Swin 一起思考,则容易理解多了

img6

Height-direction 的 Inner-Region 其实就是对特征图进行 H 维度的转动,转动的步长设定为 $s$ ,通常 $s$ 取 1 或者 2。Cross-Region 操作被加在 Inner-Region 前后。实际上并不是每个 Inner-Region 前后都添加 Cross-Region,因为每一个都添加一样的 Cross-Region 其实等于没有添加。所以原文中作者是隔一个添加一个,隔一个添加一个。To get a global receptive field, the cross-region rearrangement operations are inserted before the inner-region rearrangement operation every two blocks. 但是作者将这个称为全局感受野我其实是不太认同的,毕竟 $s$ 太小了,只不过边界地区挨到一起来罢了,这其实也是局部感受野而已

Cross-Region 重排其实可以只需要依赖 torch.roll 即可方便完成:

import torch
from torch import nn
from einops.layers.torch import Rearrange, Reduce

class CrossRegion(nn.Module):
    def __init__(self, step = 1, dim = 1):
        super().__init__()
        self.step = step
        self.dim = dim

    def forward(self, x):
        return torch.roll(x, self.step, self.dim)

model1 = CrossRegion(step = 2, dim = 3)
model2 = CrossRegion(step = -2, dim = 3)

model3 = CrossRegion(step = 1, dim = 2)
model4 = CrossRegion(step = -1, dim = 2)

images = torch.randn(1, 1, 5, 5)

print(images)
print("==============================")

with torch.no_grad():
    output1 = model1(images)
    output2 = model2(images)
    output3 = model3(images)
    output4 = model4(images)
print("=========== output1 ==============")
print(output1)
print("=========== output2 ==============")
print(output2)
print("=========== output3 ==============")
print(output3)
print("=========== output4 ==============")
print(output4)

2.1.3 特征融合

在 Hire-Module 中有三个并行的支路,第三条之路只需要通过一个 $C -> C$ 的单层全连接层映射。最后三条支路的特征图直接加和即可获得最终的融合特征图。其实用一下 Split-Attention 或者如 sMLPNet 说的通道拼接后经过 $1 \times 1$ 卷积降维可能还会涨点,不过计算量会有所增加。所以最终不难分析得到 Hire-Module 是一个十字形的感受野。整个 Hire-MLP Block 的推理逻辑如下所示:

def forward(self, x):
    x_h = self.inner_regionH(self.cross_regionH(x))						# 重排
    x_w = self.inner_regionW(self.cross_regionW(x))						# 重排
        
    x_h = self.proj_h(x_h)												# 映射
    x_w = self.proj_w(x_w)												# 映射
    x_c = self.proj_c(x)												# 映射

    x_h = self.cross_region_restoreH(self.inner_region_restoreH(x_h))	# 恢复
    x_w = self.cross_region_restoreW(self.inner_region_restoreW(x_w))	# 恢复

    out = x_c + x_h + x_w												# 特征融合
    return out

2.1.4 HireMLP 和 ViP,AS-MLP 的区别?

HireMLP 和 ViP 有区别吗?在我看来没有本质区别,ViP 也是分组重排操作。ViP 更复杂地使用了 Split-Attention 和残差;HireMLP 则是将中间的映射改为两层全连接。HireMLP 作者说 Inspired by the shortcut in ResNet and ViP, an extra branch without spatial communication is alse added ...,但是看起来似乎不仅仅是像他们一样添加一个 extra branch without spatial communication 这么简单,而是整个模块其实思想都是惊人的一致。至于说指标报告方面 HireMLP 高了那么 0.2%,会不会是因为两层全连接导致的?尚不可而知。不过相比 ViP 的重排,能肉眼可见的改进就是这个地方了。

img7

HireMLP 和 AS-MLP 有区别吗?在我看来没有本质区别,AS-MLP 也是十字形感受野,不过是位移的实现方式上有所不同。整体而言都是类似的。相比AS-MLP,Hire-MLP好像并没有什么优势,性能相当,速度反而AS-MLP更快。所以说:在特征图移动上玩花似乎已经走到头了,玩不出什么花来了。

img8

2.2 整体网络结构

Hire-MLP 的 Patch Embedding 很有特色,使用卷积核大小为 $7 \times 7$ ,步长为 4 的卷积。相比而言 Swin 使用卷积核大小为 $4 \times 4$,步长为 4 的卷积。在近期的我自己的小实验中也发现:Patch Embedding 时具有重叠会更好,这样可以避免边界效应并在小数据集上提升性能。Hire-MLP 中间采用多阶段金字塔模型,总共分为 4 个阶段,每个阶段交替重复使用 Hire-MLP Block。下采样使用卷积核大小为 $3 \times 3$,步长为 2 的卷积,这样做也有重叠。最后经过全局池化后连接一个全连接分类器即可,其网络结构图如下所示:

img2

作者一共提出来了四种配置:

img3

大家注意这四种配置,其中 Hire-MLP Block 中 $h,w,s$ 的取值一样的。$h,w$ 指 Inner-Region 中分小组内 H 或者 W 维度的大小,$s$ 指 Cross-Region 处旋转的步长。注意到 $H$ 不一定整除 $h$,让我们以 Base 为例,作者在训练 ImageNet 的时候以 $224 \times 224$ 的图片作为输入,经过第一个 Patch Embedding 之后,假设作者 kernel size = 7, stride = 4,假设使用了 padding = 3,则此后特征图大小为 $56 \times 56 \times 64$。此后在 Stage 1 中 $h = w = 4$,即在 Inner-Region 中可以分为 $56 / 4 = 14$ 组,特征图可以重排为 $14 \times 56 \times 256$,然后做 $256 -> 32 -> 256$ 的映射 ($hC -> C/2 -> hC$),看上去没什么问题。经过下采样的 Patch Embedding,kernel size = 3,stride = 2,假设使用了 padding = 1,则此后特征图大小为 $28 \times 28 \times 128$。这里 28 根本不是 3 的倍数,所以需要进行 padding 操作。作者通过对比发现 Circular padding 是最好的,但是 padding 类别其实对结果的影响并不大。

img9

3. 消融实验

作者一共进行了多组消融实验:

img10

img1·

img9

img12

img13

4. 总结与反思

Hire-MLP 其实依然是使用特征图移动,使得不同空间位置的 token 对齐到统一通道,然后使用通道方向的 $1 \times 1$ 卷积实现的局部依赖性引入和对图像分辨率不敏感,这种思想在过去的诸多工作例如 AS-MLP 以及 ViP,S2MLPv2 等中均可看到,所以并不是什么新鲜的贡献。此外作者说 Hire-MLP 对下游任务友好,但是并没有进行实验让人有点点失望。本工作暂未开源,所报告性能我持怀疑态度,因为作者消融实验中去除 Inner-Region 只保留 Cross-Region 其实可看作仅有 Channel 方向 $1 \times 1$ 卷积的网络,也能达到 79.81%?比 MLP-Mixer 的 76+% 还高

5. 代码

我自己实现的非官方 pytorch 代码见 此处,欢迎与大家自己复现的进行交流。

import torch
from torch import nn
from functools import partial
from einops.layers.torch import Rearrange, Reduce
from .utils import pair


class PreNormResidual(nn.Module):
    def __init__(self, dim, fn, norm = nn.LayerNorm):
        super().__init__()
        self.fn = fn
        self.norm = norm(dim)

    def forward(self, x):
        return self.fn(self.norm(x)) + x

class PatchEmbedding(nn.Module):
    def __init__(self, dim_in, dim_out, kernel_size, stride, padding, norm_layer=False):
        super().__init__()
        self.reduction = nn.Sequential(
                            nn.Conv2d(dim_in, dim_out, kernel_size=kernel_size, stride=stride, padding = padding),

                            nn.Identity() if (not norm_layer) else nn.Sequential(
                                Rearrange('b c h w -> b h w c'),
                                nn.LayerNorm(dim_out),
                                Rearrange('b h w c -> b c h w'),
                            )
                        )

    def forward(self, x):
        return self.reduction(x)

class FeedForward(nn.Module):
    def __init__(self, dim_in, hidden_dim, dim_out):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim_in, hidden_dim, kernel_size = 1),
            nn.GELU(),
            nn.Conv2d(hidden_dim, dim_out, kernel_size = 1),
        )
    def forward(self, x):
        return self.net(x)

class CrossRegion(nn.Module):
    def __init__(self, step = 1, dim = 1):
        super().__init__()
        self.step = step
        self.dim = dim

    def forward(self, x):
        return torch.roll(x, self.step, self.dim)

class InnerRegionW(nn.Module):
    def __init__(self, w):
        super().__init__()
        self.w = w
        self.region = nn.Sequential(
            Rearrange('b c h (w group) -> b (c w) h group', w = self.w)
        )

    def forward(self, x):
        return self.region(x)

class InnerRegionH(nn.Module):
    def __init__(self, h):
        super().__init__()
        self.h = h
        self.region = nn.Sequential(
            Rearrange('b c (h group) w -> b (c h) group w', h = self.h)
        )

    def forward(self, x):
        return self.region(x)

class InnerRegionRestoreW(nn.Module):
    def __init__(self, w):
        super().__init__()
        self.w = w
        self.region = nn.Sequential(
            Rearrange('b (c w) h group -> b c h (w group)', w = self.w)
        )

    def forward(self, x):
        return self.region(x)

class InnerRegionRestoreH(nn.Module):
    def __init__(self, h):
        super().__init__()
        self.h = h
        self.region = nn.Sequential(
            Rearrange('b (c h) group w -> b c (h group) w', h = self.h)
        )

    def forward(self, x):
        return self.region(x)

class HireMLPBlock(nn.Module):
    def __init__(self, h, w, d_model, cross_region_step = 1, cross_region_id = 0, cross_region_interval = 2, padding_type = 'circular'):
        super().__init__()

        assert (padding_type in ['constant', 'reflect', 'replicate', 'circular'])
        self.padding_type = padding_type
        self.w = w
        self.h = h

        # cross region every cross_region_interval HireMLPBlock
        self.cross_region = (cross_region_id % cross_region_interval == 0)

        if self.cross_region:
            self.cross_regionW = CrossRegion(step = cross_region_step, dim = 3)
            self.cross_regionH = CrossRegion(step = cross_region_step, dim = 2)
            self.cross_region_restoreW = CrossRegion(step = -cross_region_step, dim = 3)
            self.cross_region_restoreH = CrossRegion(step = -cross_region_step, dim = 3)
        else:
            self.cross_regionW = nn.Identity()
            self.cross_regionH = nn.Identity()
            self.cross_region_restoreW = nn.Identity()
            self.cross_region_restoreH = nn.Identity()

        self.inner_regionW = InnerRegionW(w)
        self.inner_regionH = InnerRegionH(h)
        self.inner_region_restoreW = InnerRegionRestoreW(w)
        self.inner_region_restoreH = InnerRegionRestoreH(h)


        self.proj_h = FeedForward(h * d_model, d_model // 2, h * d_model)
        self.proj_w = FeedForward(w * d_model, d_model // 2, w * d_model)
        self.proj_c = nn.Conv2d(d_model, d_model, kernel_size = 1)
    
    def forward(self, x):
        x = x.permute(0, 3, 1, 2)

        B, C, H, W = x.shape
        padding_num_w = W % self.w
        padding_num_h = H % self.h
        x = nn.functional.pad(x, (0, self.w - padding_num_w, 0, self.h - padding_num_h), self.padding_type)

        x_h = self.inner_regionH(self.cross_regionH(x))
        x_w = self.inner_regionW(self.cross_regionW(x))
        
        x_h = self.proj_h(x_h)
        x_w = self.proj_w(x_w)
        x_c = self.proj_c(x)

        x_h = self.cross_region_restoreH(self.inner_region_restoreH(x_h))
        x_w = self.cross_region_restoreW(self.inner_region_restoreW(x_w))

        out = x_c + x_h + x_w

        out = out[:,:,0:H,0:W]
        out = out.permute(0, 2, 3, 1)
        return out

class HireMLPStage(nn.Module):
    def __init__(self, h, w, d_model_in, d_model_out, depth, cross_region_step, cross_region_interval, expansion_factor = 2, dropout = 0., pooling = False, padding_type = 'circular'):
        super().__init__()

        self.pooling = pooling
        self.patch_merge = nn.Sequential(
            Rearrange('b h w c -> b c h w'),
            PatchEmbedding(d_model_in, d_model_out, kernel_size = 3, stride = 2, padding=1, norm_layer=False),
            Rearrange('b c h w -> b h w c'),
        )

        self.model = nn.Sequential(
            *[nn.Sequential(
                PreNormResidual(d_model_in, nn.Sequential(
                    HireMLPBlock(
                        h, w, d_model_in, cross_region_step = cross_region_step, cross_region_id = i_depth + 1, cross_region_interval = cross_region_interval, padding_type = padding_type
                    )
                ), norm = nn.LayerNorm),
                PreNormResidual(d_model_in, nn.Sequential(
                    nn.Linear(d_model_in, d_model_in * expansion_factor),
                    nn.GELU(),
                    nn.Dropout(dropout),
                    nn.Linear(d_model_in * expansion_factor, d_model_in),
                    nn.Dropout(dropout),
                ), norm = nn.LayerNorm),
            ) for i_depth in range(depth)]
        )

    def forward(self, x):
        x = self.model(x)
        if self.pooling:
            x = self.patch_merge(x)
        return x

class HireMLP(nn.Module):
    def __init__(
        self,
        patch_size=4,
        in_channels=3,
        num_classes=1000,
        d_model=[64, 128, 320, 512],
        h = [4,3,3,2],
        w = [4,3,3,2],
        cross_region_step = [2,2,1,1],
        cross_region_interval = 2,
        depth=[4,6,24,3],
        expansion_factor = 2,
        patcher_norm = False,
        padding_type = 'circular',
    ):
        patch_size = pair(patch_size)
        super().__init__()
        self.patcher = PatchEmbedding(dim_in = in_channels, dim_out = d_model[0], kernel_size = 7, stride = patch_size, padding = 3, norm_layer=patcher_norm)
        

        self.layers = nn.ModuleList()
        for i_layer in range(len(depth)):
            i_depth = depth[i_layer]
            i_stage = HireMLPStage(h[i_layer], w[i_layer], d_model[i_layer], d_model_out = d_model[i_layer + 1] if (i_layer + 1 < len(depth)) else d_model[-1],
                depth = i_depth, cross_region_step = cross_region_step[i_layer], cross_region_interval = cross_region_interval,
                expansion_factor = expansion_factor, pooling = ((i_layer + 1) < len(depth)), padding_type = padding_type)
            self.layers.append(i_stage)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(d_model[-1]),
            Reduce('b h w c -> b c', 'mean'),
            nn.Linear(d_model[-1], num_classes)
        )

    def forward(self, x):
        embedding = self.patcher(x)
        embedding = embedding.permute(0, 2, 3, 1)
        for layer in self.layers:
            embedding = layer(embedding)
        out = self.mlp_head(embedding)
        return out