论文复现:FNet——使用傅里叶变换替代自注意力层_副本
论文复现赛第五期:FNet——使用傅里叶变换替代自注意力层以实现token信息混合。承载论文:FNet: Mixing Tokens with Fourier Transforms
转载自AI Studio
标题项目链接https://aistudio.baidu.com/aistudio/projectdetail/3158413
基于PaddlePaddle复现FNet
论文:FNet: Mixing Tokens with Fourier Transforms
复现具体流程:https://github.com/HJHGJGHHG/Paddle-FNet
一、Motivation
经典 Transformer 中 Self-Attention O ( n 2 ) O(n^2) O(n2) 的时空复杂度是一大缺陷,而我们知道 Self-Attention 的本质在于融合各个 token 间的信息,那么有没有别的更加高效的方式也能完成这一目标呢?显然存在,诸如 MLP-Mixer 与 Synthesizer 等等。那么这篇文章给出了一个新思路:使用 傅里叶变换 完成这个任务。
We show that Transformer encoder architectures can be sped up, with limited accuracy costs, by replacing the self-attention sublayers with simple linear transformations that “mix” input tokens.
…replacing the self-attention sublayer in a Transformer encoder with a standard, unparameterized Fourier Transform achieves 92-97% of the accuracy of BERT counterparts on the GLUE benchmark, but trains 80% faster on GPUs and 70% faster on TPUs at standard 512 input lengths.
二、Fourier Transform
傅里叶变换可以把原来的信号变换到 频域 中表示,这样可以让很多在 时域 中难以捕捉的特征无处藏身。
模型架构:
很显然,相较于经典 Transformer,这里只是使用傅里叶层替换了自注意力层。傅里叶层计算公式如下:
y = R ( F s e q ( F h i d d e n ( x ) ) ) y=\mathfrak{R}(\mathcal{F}_{seq}(\mathcal{F}_{hidden}(x))) y=R(Fseq(Fhidden(x)))
总体来说是先沿着特征向量的维度进行傅里叶变换,而后沿着序列的维度再进行一次,最后只取实数部分。先是把 Embedding 变换到频域,再到前向层做乘法,考虑到傅里叶变换的卷积特性:
F [ f ∗ g ] = F [ f ] ⋅ F [ g ] \mathcal{F}[f* g]=\mathcal{F}[f]\cdot \mathcal{F}[g] F[f∗g]=F[f]⋅F[g]
即频域上的乘法相当于时域上的卷积,所以这种变换本质上是对输入进行了一个大卷积核的卷积。
又考虑到傅里叶变换具有对偶性,所以叠加 N 个 block 之后,信息不断进行傅里叶变换与逆变换,所以作者描述傅里叶层使输入在频域与时域间来回变换,或者也可以理解为在反复进行卷积与乘法,这样便充分融合了各个 token 的信息。
三、项目环境
- 硬件:
- x86 CPU
- NVIDIA TESLA V100
- 框架:
- PaddlePaddle==2.2.1
- 其他库依赖
- paddlenlp==2.1.1
- tqdm==4.27.0
- numpy==1.16.4
- scikit-learn==0.22.1
四、复现精度与指标对比
***官方pytorch复现版本***相关指标:(原论文是 JAX 写的,很多细节没开源,所以选择 pytorch 版本作为对比)
数据集 | FNet-base | FNET-large |
---|---|---|
SST2 (ACC) | 89.45% | 90.48% |
QQP (ACC+F1)/2 | 86.57% | 87.5% |
基于 paddle 复现:
数据集 | FNet-base | FNet-large |
---|---|---|
SST2 (ACC) | 90.13% | 91.06% |
QQP (ACC+F1)/2 | 85.71% | 87.18% |
五、快速开始
在终端运行以下指令以在 SST2 与 QQP上进行 Fine-Tune。
# 1. SST2 FNet-Base
!python train_sst2.py --num_warmup_steps=800 --model_name_or_path='fnet-base' --logging_file="sst2_log_base.txt"
# 2. SST2 FNet-Large
!python train_sst2.py --num_warmup_steps=1000 --model_name_or_path='fnet-large' --logging_file="sst2_log_large.txt"
# 3. QQP FNet-Base
!python train_qqp.py --num_warmup_steps=3000 --model_name_or_path='fnet-base' --logging_file="qqp_log_base.txt"
# 4. QQP FNet-Large
!python train_qqp.py --num_warmup_steps=3000 --model_name_or_path='fnet-large' --logging_file="qqp_log_large.txt"
六、代码结构
├── modeling.py # 模型文件
├── outputs # 训练结束模型保存位置
│ ├── model_config.json
│ ├── model_state.pdparams
│ ├── spiece.model
│ └── tokenizer_config.json
├── paddle_metric.py # metric 文件
├── tokenizer.py # tokenizer 文件
├── train_qqp.py
├── train_sst2.py
├── utils.py
七、复现细节与心得
1. 傅里叶层实现
调用 paddle 中的多维傅里叶变换算子可以很轻松地实现傅里叶层:
class FNetBasicFourierTransform(Layer):
def __init__(self):
super().__init__()
self.fourier_transform = paddle.fft.fftn
def forward(self, hidden_states):
outputs = self.fourier_transform(hidden_states).real()
return outputs
class FNetFourierTransform(Layer):
def __init__(self, hidden_size, layer_norm_eps):
super().__init__()
self.fourier_transform = FNetBasicFourierTransform()
self.output = FNetBasicOutput(hidden_size, layer_norm_eps)
def forward(self, hidden_states):
self_outputs = self.fourier_transform(hidden_states)
fourier_output = self.output(self_outputs, hidden_states)
return fourier_output
值得注意的是,paddle.fft.fftn() 使用 partial 封装会可能导致出错堆栈信息复杂,这里选择直接调用。
2. 权重转换
模型权重转换时注意 paddle 的线性层权重需转置,而每个人的线性层名称可能都不同,包括 pooler,
mlp,linear,projection 等等。所以应先仔细阅读源码,确保权重转换正确。FNet 的转换代码如下,可根据具体模型修改:
def convert_pytorch_checkpoint_to_paddle(
pytorch_checkpoint_path="pytorch_model.bin",
paddle_dump_path="model_state.pdparams",
version="old", ):
hf_to_paddle = {
"embeddings.LayerNorm": "embeddings.layer_norm",
".LayerNorm.": ".layer_norm.",
"encoder.layer": "encoder.layers",
}
do_not_transpose = []
if version == "old":
hf_to_paddle.update({
"predictions.bias": "predictions.decoder_bias",
".gamma": ".weight",
".beta": ".bias",
})
do_not_transpose = do_not_transpose + ["predictions.decoder.weight"]
pytorch_state_dict = torch.load(
pytorch_checkpoint_path, map_location="cpu")
paddle_state_dict = OrderedDict()
for k, v in pytorch_state_dict.items():
is_transpose = False
if k[-7:] == ".weight":
# embeddings.weight and LayerNorm.weight do not transpose
if all(d not in k for d in do_not_transpose):
if ("embeddings." not in k and ".LayerNorm." not in k) or "embeddings.projection" in k or "seq_relationship" in k or "classifier" in k:
if v.ndim == 2:
v = v.transpose(0, 1)
is_transpose = True
oldk = k
for hf_name, pd_name in hf_to_paddle.items():
k = k.replace(hf_name, pd_name)
print(f"Converting: {oldk} => {k} | is_transpose {is_transpose}")
paddle_state_dict[k] = v.data.numpy()
paddle.save(paddle_state_dict, paddle_dump_path)
return paddle_state_dict
如果训练对齐过程中发现复现模型不收敛,请先检查权重是否正确转换!!(我就是 Layer Norm 层没有正确转换,结果加载模型后 weight 全为1,在这里卡了好久…)
3. 对齐过程中的数据
除去训练对齐,其他几个部分都可以使用 fake data,更加简便。
def gen_fake_data():
fake_data = np.random.randint(1, 30522, size=(4, 64)).astype(np.int64) # 注意 index 别超了,30522 改成自己的词表大小
fake_label = np.array([0, 1, 1, 0]).astype(np.int64)
np.save("fake_data.npy", fake_data)
np.save("fake_label.npy", fake_label)
4. 心得
本次论文复现,由于超参等等很多细节没有开源,所以复现难度比较大。(实名抨击 Google 店大欺客!还是用 TPU 训练的,这不是软广是啥…) 前期除了对文献本身研读以外,同时在网上查阅了很多资料,以及其他框架实现的代码。(结果最后发现还是 HF 的最清楚)论文复现过程中, 所学到的领域知识,令我获益匪浅。
纸上得来终觉浅,绝知此事要躬行。古人诚不欺我!
更多推荐
所有评论(0)