一种残膜回收机防缠绕挑膜装置的制 一种秧草收获机用电力驱动行走机构

基于相似关系知识蒸馏的模型轻量化方法及相关装置与流程

2022-12-19 21:36:30 来源:中国专利 TAG:


1.本发明涉及机器学习技术领域,尤其涉及基于相似关系知识蒸馏的模型轻量化方法及相关装置。


背景技术:

2.在实际的应用中,目标检测模型参数量、计算量大很难部署在移动端平台上,所以对模型进行轻量化的研究是十分必要的。模型压缩和加速是两个不同的话题,有时候压缩并不一定能带来加速的效果,有时候又是相辅相成的。压缩重点在于减少网络参数量,加速则侧重在降低计算复杂度、提升并行能力等。模型压缩和加速可以从多个角度来优化。总体来看,主要分为两个方面:算法层压缩加速和硬件层加速。
3.在算法层压缩加速这个维度的模型轻量化的研究中涌现出众多的研究成果。近期研究工作可分为四大类:低秩分解、模型量化、网络剪枝,以及知识蒸馏。其中知识蒸馏是一种将知识从教师网络转移到目标神经网络的技术,通过预先培训教师网络,将模型在实现任务的过程中所学习的知识特征,通过蒸馏损失告知学生网络问题的解决方案与中间过程,这种信息监督方式,能够督促学生网络进行快速有效的学习。但是现有的知识蒸馏方法都是基于手工设定将教师网络与学生网络联系起来,但是基于手工设定的选择往往会建立起无效的知识传播途径,且总是在降低模型体量的同时,一定程度上降低了模型的精度,影响蒸馏的效果,因此,如何在保证模型精度的同时,进一步减少模型的体量,是如今迫切需要解决的问题。


技术实现要素:

4.为了解决上述技术问题,本发明提出基于相似关系知识蒸馏的模型轻量化方法及相关装置,采用的基础是深度学习技术的有效特征提取技术,可以将教师特征知识自适应地转移到学生层次的多个层次,从而提高轻量化模型的分类性能,在保证模型精度的同时,进一步减少模型的体量。
5.为了达到上述目的,本发明的技术方案如下:
6.基于相似关系知识蒸馏的模型轻量化方法,包括如下步骤:
7.将样本数据分别输入教师网络和学生网络,获得所述教师网络输出的获得教师特征以及学生网络输出的学生特征;
8.计算教师特征和学生特征之间的相似性,并基于所述相似性确定注意知识迁移函数;基于样本数据与学生特征间的损失确定子损失函数;
9.将所述子损失函数和加权后的注意知识迁移函数之和作为学生网络训练的整体损失函数。
10.优选地,所述教师网络采用resnet50,所述学生网络采用resnet18。
11.优选地,所述教师网络和学生网络均采用全局平均池化和通道池化对每层提取的特征进行处理。
12.优选地,所述计算教师特征和学生特征之间的相似性,并基于所述相似性确定注意知识迁移函数,具体包括如下步骤:
13.采用多头注意力机制计算教师特征和学生特征之间的相似性,
14.获取相似性最大的教师特征向量和相似性最大的学生特征向量;
15.通过所述相似性最高的教师特征和相似性最高的学生特征确定注意知识迁移函数。
16.优选地,所述子损失函数为交叉熵损失函数。
17.优选地,所述整体损失函数,公式为:
18.l
total
=l
cls
β

l
at
19.式中,l
cls
为子损失函数;l
at
为注意知识迁移函数,β'为控制蒸馏损失影响的权衡参数。
20.优选地,所述控制蒸馏损失影响的权衡参数β'的确认公式为:
[0021][0022]
式中,代表着初始权重的蒸馏损失,γ是常系数,ne代表着整个训练过程中的第n次循环, n是经验值代表着循环次数。
[0023]
基于上述内容,本发明还公开了一种基于相似关系知识蒸馏的模型轻量化装置,包括:处理模块、计算模块和确定模块,其中,
[0024]
所述处理模块,用于将样本数据分别输入教师网络和学生网络,获得所述教师网络输出的获得教师特征以及学生网络输出的学生特征;
[0025]
所述计算模块,用于计算教师特征和学生特征之间的相似性,并基于所述相似性确定注意知识迁移函数;用于基于样本数据与学生特征间的损失确定子损失函数;
[0026]
所述确定模块,用于将所述子损失函数和加权后的注意知识迁移函数之和作为学生网络训练的整体损失函数。
[0027]
基于上述内容,本发明还公开了一种计算机设备,包括:存储器,用于存储计算机程序;处理器,用于执行所述计算机程序时实现如上述任一所述的方法。
[0028]
基于上述内容,本发明还公开了一种可读存储介质,所述可读存储介质上存储有计算机程序,所述计算机程序被处理器执行时实现如上述任一所述的方法。
[0029]
基于上述技术方案,本发明的有益效果是:
[0030]
1)本发明采用全局平均池化,降低参数量,减少过拟合(正则化)、将空间信息进行汇总,增加泛化能力;采用通道池化进行降维在减少特征图数量的同时保留其显著特征;
[0031]
2)本发明分块相似性计算可以有效的在减少相似关系的计算量的同时进一步减少教师网络有差别特征对学生特征知识的引导;
[0032]
3)本发明提出的新颖的蒸馏损失函数,通过损失衰减因子,减少教师网络在学生网络后期学习过程中信息的干扰;
[0033]
4)本发明通过计算教师网络与学生网络之间的空间特征相似性计算,有效的提升了教师网络与学生网络之间知识传递的效果,提升了轻量化学生网络的性能。
附图说明
[0034]
图1是一个实施例中基于相似关系知识蒸馏的模型轻量化方法流程图;
[0035]
图2是一个实施例中教师网络和学生网络相似性确认示意图;
[0036]
图3是一个实施例中一种基于相似关系知识蒸馏的模型轻量化系统的结构示意图;
[0037]
图4是一个实施例中一种计算机设备的结构框图;
[0038]
图5是一个实施例中基于相似关系知识蒸馏的模型轻量化方法的程序产品的结构示意图。
具体实施方式
[0039]
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述。
[0040]
如图1所示,本实施例提供基于相似关系知识蒸馏的模型轻量化方法,包括如下步骤:
[0041]
步骤s101,将样本数据分别输入教师网络和学生网络,获得所述教师网络输出的获得教师特征以及学生网络输出的学生特征。
[0042]
本实施例中,考虑到教师网络与学生网络之间差异性质过大,会导致蒸馏效果一定程度的下降,因此,教师网络采用resnet50,学生网络采用resnet18。对于resnet50和resnet18 首先都会使用一个7
×
7大小,输出通道数为64,步长(stride)为2的卷积。然后resnet50 会使用4个卷积层进行计算。每一个block都包含两个3
×
3大小的卷积,采用的通道数分别是64、128、256和512。在网络的最后采用一个平均池化以及softmax损失函数。对于 resnet18同样使用4个卷积层进行计算,每一个block结构相同包括两个1
×
1大小的卷积,以及一个3
×
3大小的卷积。每一层所所使用的block的个数分别为3、4、6、3,具体的每一层的卷积通道数分别是第一层为64、64、32;第二层为128,128,512;第三层为256、 256、1024;最后一层为512、512、1024。
[0043]
步骤s102,计算教师特征和学生特征之间的相似性,并基于所述相似性确定注意知识迁移函数;基于样本数据与学生特征间的损失确定子损失函数。
[0044]
本实施例中,学生网络和教师网络分别通过卷积运算实现对不同层的特征进行提取,然后,分别对每一层提取出来的特征进行全局平均池化和通道池化,其中,
[0045]
全局平均池化:与传统的池化操作不同,所采用的池化窗口大小是整个特征图的大小,并对整个特征图的权重值求平均值,这样做的优点在于,常见的全连接层容易过拟合,会影响整个网络的泛化能力,而使用全局平均池化层可以加强特征图与类别之间的关系,能够有效的避免过拟合.并且全局平均池化可以将空间信息进行汇总,增加模型的鲁棒性。
[0046]
通道池化:其对应的作用在于降维,在减少特征图数量的同时,保留其显著特征,其还可以创建特征映射的一对一投影,以跨通道共享特征,或者增加特征映射的数量。
[0047]
设计多头注意机制模块,首先通过分块操作对学生网络和教师网络的特征图进行分块。这样做的优点在于,计算特征空间位置之间的相关性,当特征图较大时,模型的计算量与参数量将会巨大,影响模型的计算速度,降低对硬件设备的计算需求。
[0048]
为了进一步降低模型的计算复杂度,则不计算所有空间位置的相关性,而是将特征图进行了分块处理,然后在每组中进行一次蒸馏,然后进一步将分块内的特征平均到单
个向量中提取知识,该操作大大降低提出方法的复杂性。
[0049]
参见图2,图2为教师网络和学生网络相似性确认示意图,对每一个分块的特征都进行了位置编码的设计,位置编码被用来在不同的实例中共享公共信息。具体的公式如下所示:
[0050]
pe
(pos,2i)
=sin(pos/10000
2i/dmodel
)
[0051]
pe
(pos,2i)
=cos(pos/10000
2i/dmodel
)
[0052]
式中,pe代表位置编码,pos代表第i个元素的具体位置,dmodel等于512。
[0053]
为了识别教师特征与学生特征之间的相似性,采用多头注意力机制进行计算。通过三个可学习的权重矩阵q,k,v计算了不同分块位置之间的相似性。相似性具体的计算公式可用以下公式表示:
[0054][0055]
multiheadself(qs,k
t
,y
t
)=concat(head1,head2,.....,headh)wo[0056][0057]
式中,代表三个不同的权重矩阵;dk代表着k向量的维度,dv代表着v向量的维度;head1~headh代表着多头注意力机制中8个不同的头,在本发明中h=8;e代表着教师网络和学生之间关系的注意值用softmax计算,通过利用e可以知道第t个教师特征与整个学生特征之间的关系注意值,使得能够有选择地将知识转移到学生特征中,qs代表着学生矩阵,k
t
和v
t
代表着教师矩阵和教师矩阵的转置。
[0058]
为了使相似度计算特征知识能够转移到学生网络中,设计了一种attention loss用作注意知识迁移函数:
[0059][0060]
在attention loss中令代表着相似性最大的时候的学生特征向量,代表着相似性最大的时候的教师网络相似特征矩阵向量,其和基于相似性e确定,ws分别表示学生特征的权重,l(ws,x)代表着标准交叉熵损失,j代表着想要转移的相似注意力图和学生注意力图的索引,β代表着权衡参数,p代表使用p范数。
[0061]
步骤s103,将所述子损失函数和加权后的注意知识迁移函数之和作为学生网络训练的整体损失函数。
[0062]
本实施例中,对于整体损失函数的设计如下:
[0063]
l
total
=l
cls
β

l
at
[0064]
式中,l
cls
是带有真实标签的分类损失即子损失函数,β’是控制蒸馏损失影响的权衡参数。对于l
cls
使用交叉熵损失。整体模型中只使用代表所有可能特征对加权相似度的l
at
进行训练,因此l
at
不需要昂贵的hessian计算来将其参数与l
cls
连接。β’的作用在于教师网
络监督学生网络训练,会在后期一定程度上抑制学生学习,所以需要慢慢衰减的监督力度。对于β’的定义如下:
[0065]
β

=θ
×
γ
ne/n
[0066]
式中,代表着初始权重的蒸馏损失,γ是常系数,ne代表着整个训练过程中的第n次循环, n是经验值代表着循环次数。
[0067]
如图3所示,本实施例提供一种基于相似关系知识蒸馏的模型轻量化装置100,包括:处理模块110、计算模块120和确定模块130,其中,
[0068]
所述处理模块110,用于将样本数据分别输入教师网络和学生网络,获得所述教师网络输出的获得教师特征以及学生网络输出的学生特征;
[0069]
所述计算模块120,用于计算教师特征和学生特征之间的相似性,并基于所述相似性确定注意知识迁移函数;用于基于样本数据与学生特征间的损失确定子损失函数;
[0070]
所述确定模块130,用于将所述子损失函数和加权后的注意知识迁移函数之和作为学生网络训练的整体损失函数。
[0071]
如图4所示,本实施例提供一种计算机设备,包括:200包括至少一个存储器210、至少一个处理器220以及连接不同平台系统的总线230。
[0072]
存储器210可以包括易失性存储器形式的可读介质,例如随机存取存储器(ram)211和 /或高速缓存存储器212,还可以进一步包括只读存储器(rom)213。
[0073]
其中,存储器210还存储有计算机程序,计算机程序可以被处理器220执行,使得处理器220实现上述任一项方法的步骤,其具体实现方式与上述方法实施方式中记载的实施方式、所达到的技术效果一致,部分内容不再赘述。
[0074]
存储器210还可以包括具有至少一个程序模块215的实用工具214,这样的程序模块 215包括但不限于:操作系统、一个或者多个应用程序、其它程序模块以及程序数据,这些示例的每一个或某种组合中可能包括网络环境的实现。
[0075]
相应的,处理器220可以执行上述计算机程序,以及可以执行实用工具214。
[0076]
总线230可以为表示几类总线结构的一种或多种,包括存储器总线或者存储器控制器、外围总线、图形加速端口、处理器或者使用多种总线结构的任意总线结构的局域总线。
[0077]
电子设备200也可以与一个或多个外部设备240例如键盘、指向设备、蓝牙设备等通信,还可与一个或者多个能够与该电子设备200交互的设备通信,和/或与使得该电子设备 200能与一个或多个其它计算设备进行通信的任何设备(例如路由器、调制解调器等)通信。这种通信可以通过输入输出接口250进行。并且,电子设备200还可以通过网络适配器260 与一个或者多个网络(例如局域网(lan),广域网(wan)和/或公共网络,例如因特网)通信。网络适配器260可以通过总线230与电子设备200的其它模块通信。应当明白,尽管图中未示出,可以结合电子设备200使用其它硬件和/或软件模块,包括但不限于:微代码、设备驱动器、冗余处理器、外部磁盘驱动阵列、raid系统、磁带驱动器以及数据备份存储平台等。
[0078]
本技术还提供了一种计算机可读存储介质,该计算机可读存储介质用于存储计算机程序,所述计算机程序被执行时实现上述任一方法的步骤,其具体实现方式与上述方法实施方式中记载的实施方式、所达到的技术效果一致,部分内容不再赘述。
[0079]
参见图5,图5示出了本技术提供的基于相似关系知识蒸馏的模型轻量化方法的程
序产品300的结构示意图。程序产品300可以采用便携式紧凑盘只读存储器(cd-rom)并包括程序代码,并可以在终端设备,例如个人电脑上运行。然而,本发明的程序产品300不限于此,在本技术中,可读存储介质可以是任何包含或存储程序的有形介质,该程序可以被指令执行系统、装置或者器件使用或者与其结合使用。程序产品300可以采用一个或多个可读介质的任意组合。可读介质可以是可读信号介质或者可读存储介质。可读存储介质例如可以为但不限于电、磁、光、电磁、红外线、或半导体的系统、装置或器件,或者任意以上的组合。可读存储介质的更具体的例子(非穷举的列表)包括:具有一个或多个导线的电连接、便携式盘、硬盘、随机存取存储器(ram)、只读存储器(rom)、可擦式可编程只读存储器(eprom或闪存)、光纤、便携式紧凑盘只读存储器(cd-rom)、光存储器件、磁存储器件、或者上述的任意合适的组合。
[0080]
计算机可读存储介质可以包括在基带中或者作为载波一部分传播的数据信号,其中承载了可读程序代码。这种传播的数据信号可以采用多种形式,包括但不限于电磁信号、光信号或上述的任意合适的组合。可读存储介质还可以是任何可读介质,该可读介质可以发送、传播或者传输用于由指令执行系统、装置或者器件使用或者与其结合使用的程序。可读存储介质上包含的程序代码可以用任何适当的介质传输,包括但不限于无线、有线、光缆、rf等,或者上述的任意合适的组合。可以以一种或多种程序设计语言的任意组合来编写用于执行本发明操作的程序代码,程序设计语言包括面向对象的程序设计语言诸如java、 c 等,还包括常规的过程式程序设计语言诸如c语言或类似的程序设计语言。程序代码可以完全地在用户计算设备上执行、部分地在用户设备上执行、作为一个独立的软件包执行、部分在用户计算设备上部分在远程计算设备上执行、或者完全在远程计算设备或服务器上执行。在涉及远程计算设备的情形中,远程计算设备可以通过任意种类的网络,包括局域网(lan)或广域网(wan),连接到用户计算设备,或者,可以连接到外部计算设备(例如利用因特网服务提供商来通过因特网连接)。
再多了解一些

本文用于创业者技术爱好者查询,仅供学习研究,如用于商业用途,请联系技术所有人。

发表评论 共有条评论
用户名: 密码:
验证码: 匿名发表

相关文献