训练 CV 模型新思路来了:用 NLP 大火的 Prompt 替代微调,性能全面提升
Prompt tuning,作为 NLP 领域中的一个"新宠",甚至曾被学者誉为 NLP 预训练新范式。那么,它能否借鉴到 CV 领域并产生同样的成绩呢?
现在,来自康奈尔大学和 Meta AI 等机构,通过 Prompt 来调整基于 Transformer 的视觉模型,结果发现:完全可以!
比起全面微调,Prompt 性能提升显著。无论模型的规模和训练数据怎么变,24 种情况中有 20 种都完全胜出。
与此同时,它还能大幅降低每项任务所需的存储成本。
只使用不到 1% 的模型参数
大家一贯使用的全面微调(full fine-tuning),需要为每个下游任务存储和部署单独的主干参数副本,成本太高,尤其是现在基于 Transformer 的模型越来越大,已经超过 CNN 架构。
所谓 Prompt,最初指的是在输入文本中预编语言指令,以便预培训的语言模型后续可以直接理解各种下游任务。它曾让 GPT-3 即使在少样本或零样本的情况下表现出很强的泛化能力。
最近一些成果则表明,Prompt 与完全微调的性能相当,参数存储量还减少了 1000 倍。NLP 中的高超性能让不少人开始在 CV 领域中探索 Prompt 的魔力,不过都只局限于跨模态任务中文本编码器的输入。
在本文中,作者将他们所提出的 Visual Prompt Tuning 方法,简称为 VPT。这是首次有人将 Prompt 应用到视觉模型主干(backbone),并做出成果。具体来说,比起全面微调,VPT 受最新大型 NLP 模型调整方法的启发,只在输入空间中引入少量可特定某任务训练的参数(不到模型参数的 1%),同时在训练下游任务期间冻结(freeze)预训练模型的主干。
在实操中,这些附加参数只用预先加入到每个 Transformer 层的输入序列中,并在微调期间与线性 head 一起学习。
他们一共探索出两种变体:
VPT-Deep 变体为 Transformer 编码器每层的输入预先设置一组可学习的参数;
VPT-Shallow 变体则仅将提示参数插入第一层的输入。
两者在下游任务的训练过程中,只有特定于任务的提示和线性头的参数会更新,而整个 Transformer 编码器被冻结。
接下来,是骡子是马?拉出来溜溜~
20/24 的优胜率
实验涉及两种在 ImageNet-21k 上预训练好的主干,一个来自 Vision Transformer,一个来自 Swin Transformer。
进行对比的微调方法有三大种,7 小种,包括:
(1)完全微调:更新所有主干和分类头(classification head)参数
(2)以分类头为重点的微调,包括 Linear、Partial-k 和 Mlp-k 三种;
(3)以及在微调过程中更新一个主干子集参数或向主干添加新的可训练参数的方法,分为 Sidetune、Bias 和 Adapter 三种。
实验的数据集有两组,一共涉及 24 个跨不同领域的下游识别任务,包括:
(1)由 5 个基准细粒度视觉分类任务组成的 FGVC;
(2)由 19 个不同视觉分类集合组成的 VTAB-1k,细分为使用标准相机拍摄的自然图像任务(Natural)、用专用设备(如卫星图像)捕获的图像任务(Specialized)以及需要几何理解的任务(Structured),比如物体计数。
测得每项任务上的平均准确度后,得出的主要结果如下:
VPT-Deep 在 24 个任务中有 20 个的表现都优于全面微调,同时使用的总模型参数显著减少(1.18× vs. 24.02×);
要知道,在 NLP 领域中 Prompt 再厉害,性能也不会超过全面微调。这说明 Prompt 很适用于视觉 Transformer 模型。和其他微调方法相比(b、c 组),VPT-Deep 的性能则全部胜出。
此外,选择不同主干参数规模和模型规模的 ViT(ViT-B、ViT-L 和 ViT-H)进行测试还发现,VPT 方法不会受影响,依然基本保持性能领先。
而在 Swin Transformer 中,全面微调法的平均准确度虽然更高,但也付出了巨大的参数代价。其他微调方法则全部不敌 VPT。
作者介绍
一作贾梦霖,康奈尔大学信息科学(Information Science)博士生,主要研究方向为视觉和文本信息的细粒度识别,截至目前共发表过 4 篇顶会。
共同一作为唐路明,也是康奈尔大学的一位计算机博士在读学生,本科毕业于清华大学数学与物理专业。
他的主要研究方向为机器学习和计算机视觉的交叉领域。
论文地址:
https://arxiv.org/abs/2203.12119
2022-05-06 01:24:42