一种残膜回收机防缠绕挑膜装置的制 一种秧草收获机用电力驱动行走机构

一种抽烟打电话分类注意力模型的知识蒸馏方法

2022-12-20 00:54:12 来源:中国专利 TAG:


1.本发明涉及计算机视觉技术领域,具体涉及一种抽烟打电话分类注意力模型的知识蒸馏方法。


背景技术:

2.由于注意力机制在自然语言处理领域成功的应用,注意力在视觉领域被寄予较大的期望,受到越来越多的学者的关注。基于注意力机制的图片分类算法凭借独特的注意力机制在生物结构上更像人类视觉,在大型公共数据集imagenet上超越了卷积神经网络。目前在大多数实际的视觉任务中,注意力模型的稳定性仍然不如卷积神经网络。在实际应用中运用注意力模型是一个非常有价值的研究课题。
3.目前,知识蒸馏是比较容易实现训练注意力模型的方法。知识蒸馏是深度学习模型中一种训练方法,利用强大的教师网络引导学生网络进行联合训练,学生网络结合教师网络输出的软标签概率分布进行学习。故在不增加学生网络复杂度的情况下,减少数据依赖性的同时提升学生网络的性能。因此,许多实际工程和研究员都使用知识蒸馏来进行模型训练。
4.当前主流的知识蒸馏算法的主要步骤如图1所示。首先,图片数据经过教师网络和学生网络得到各自的预测标签;接着,将教师网络的预测标签与学生网络的预测标签计算损失;然后,学生网络的预测标签再与真实图片标签计算损失;最后,将两部分损失进行加权组合得到最终的损失值。
5.对于复杂的实际工程场景中,良好的知识蒸馏方法往往可以让注意力模型学习到更多的图片语义信息,从而提升注意力模型的性能。然而,现有的抽烟打电话分类场景中知识蒸馏训练注意力模型还存在以下问题:(1)抽烟打电话分类场景中,数据集的量有限,注意力模型过早陷入局部最优的问题;(2)作为学生网络的注意力模型的性能无法达到对应的cnn教师网络。


技术实现要素:

6.本发明为了克服以上技术的不足,提供了一种抽烟打电话分类注意力模型的知识蒸馏方法。通过本发明所述的知识蒸馏方法训练的注意力模型减少了对数据量的依赖;对于注意力模型的性能,本发明相比较传统的知识蒸馏方法,具有更好的分类效果。
7.本发明克服其技术问题所采用的技术方案是:
8.一种抽烟打电话分类注意力模型的知识蒸馏方法,包括如下步骤:
9.s1、获取抽烟打电话数据集并对抽烟打电话数据集进行预处理;
10.s2、将预处理后的抽烟打电话数据集输入到预先建立的教师网络中进行训练;
11.s3、根据抽烟打电话数据集,配置额外数据集进行混合,得到混合数据集,其中,额外数据集是与抽烟打电话数据集具有互斥性的数据集;
12.s4、将混合数据集分别输入到步骤s2中训练好的教师网络和基于知识蒸馏的学生
网络中,分别得到预测结果;
13.s5、将教师网络的预测结果和学生网络的预测结果分别与真实标签进行比较,对预测出错的标签消化纠错;
14.s6、通过知识蒸馏训练迭代,将步骤s5中得到的消化纠错结果与学生网络的预测结果进行损失计算,将计算得到的损失值进行反向传播,以此完成知识蒸馏训练过程。
15.进一步地,教师网络采用resnet50;基于知识蒸馏的学生网络采用vit注意力模型。
16.进一步地,步骤s1具体包括:
17.在不同时间、不同天气下的实际工程场景中采集不同角度、不同年龄段的行人样本图片,行人样本图片类别包括抽烟行人、打电话行人和既不抽烟也不打电话的行人;
18.从样本集中将人脸作为关键局部信息中心点选取预设大小的像素图像,将选取的关键局部信息图像作为后续相似度计算输入图像;
19.提取关键局部信息图像的hog特征,并将提取出来的hog特征向量与相似度队列中的特征向量计算余弦相似度;
20.根据计算得到的余弦相似度进行判断:若相似度高于设定阈值,则该样本为无效样本,不作为知识蒸馏的训练样本;若相似度低于设定阈值,则该样本为有效样本,作为知识蒸馏的训练样本,同时该样本加入相似度队列中作为相似度比对样本。
21.进一步地,hog特征提取包括如下步骤:
22.1)将rgb三维图像灰度化:
23.i(x,y)=0.3*ir(x,y) 0.59ig(x,y) 0.11ib(x,y)
ꢀꢀ
(1)
24.上式中,i(x,y)是灰度化之后的图像像素值,ir(x,y),ig(x,y),ib(x,y)分别是图像r、g、b三维通道上的像素值;
25.2)对灰度化之后的图像进行gamma校正:
[0026][0027]
上式中,g(x,y)是校正之后的图像像素值;
[0028]
3)计算图像上每个像素的梯度:
[0029][0030]
上式中,t
x
(x,y),ty(x,y)分别为图像像素值水平方向和垂直方向的梯度,t(x,y)为图像每个像素上的梯度值;
[0031]
4)将图像划分成小cell,同时统计每个cell中像素的梯度,形成cell的描述符;
[0032]
5)将若干个cell组成为一个block,block中的cell描述符组合得到该block的描述符;
[0033]
6)将所有block的描述符串联起来就形成该图像的hog特征。
[0034]
进一步地,步骤s3中,设抽烟打电话数据集与额外数据集的混合比例为n1:n2,在知识蒸馏训练过程中,n1与n2的总和保持不变。
[0035]
进一步地,步骤s6中,在知识蒸馏训练的开始阶段,n1<n2;随着知识蒸馏训练的
迭代次数增加,n1动态增加、同时n2动态减小,且n1中抽烟行人、打电话行人和既不抽烟也不打电话的行人三种类别的数量也动态增加;在知识蒸馏训练的最后阶段,n1>n2。
[0036]
进一步地,在知识蒸馏训练的开始阶段,n1:n2=14:50,n1中抽烟行人、打电话行人和既不抽烟也不打电话的行人三种类别的数量比例为5:5:4;到知识蒸馏训练的最后阶段,n1:n2=50:14,n1中抽烟行人、打电话行人和既不抽烟也不打电话的行人三种类别的数量比例为17:17:16。
[0037]
进一步地,步骤s4中,
[0038]
教师网络的预测结果表示为
[0039]
学生网络的预测结果为
[0040]
其中,n代表样本的类别数,t代表教师网络预测类别索引值,s代表教师网络预测类别索引值;
[0041]
因抽烟数据样本与打电话数据样本之间并不互斥,即图片可同时存在抽烟和打电话,故教师网络和学生网络都采用sigmoid激活函数,sigmoid激活函数的表达式如下:
[0042][0043]
上式中,z为网络最后一层的输入标签。
[0044]
进一步地,步骤s5中,设教师网络预测标签最大值的下标索引为m1,学生网络预测标签最大值的下标索引为m2,真实标签最大值的下标索引为m3,若m1≠m3,则说明教师网络的预测标签出错,若m2≠m3,则说明学生网络的预测标签出错;
[0045]
对于预测出错的标签,则需要进行消化纠错,消化纠错的公式如下:
[0046]
上式中,y为预测标签,α为随机数且0<α≤1。
[0047]
进一步地,步骤s6中,损失计算的损失函数为余弦损失函数σ(x),其计算式如下:
[0048][0049]
上式中,x1为学生网络的预测标签,x2为经过消化纠错之后的预测标签;
[0050]
设学生网络的预测标签为ys,教师网络的预测标签y
t
,消化纠错之后学生网络的预测标签为消化纠错之后教师网络的预测标签为当前知识蒸馏训练迭代次数为k,ys,k代入到公式(6),求得最后的损失,则最后的损失函数计算如下:
[0051][0052]
本发明的有益效果是:
[0053]
本发明通过调节知识蒸馏的训练输入的混合数据集,具体是配置额外数据集以及设置抽烟打电话数据集和额外数据集的比例,使得网络即使在有限数据样本的情况下也能发挥不错的效果,减少了对数据量的依赖。
[0054]
本发明通过调节教师网络和学生网络的预测标签,具体是将教师网络的预测结果和学生网络的预测结果分别与真实标签进行比较,对预测出错的标签消化纠错,使得学生
网络在正确的知识上进行学习,从而避免学生网络得到教师网络错误的信息,使得学生网络的性能达到教师网络。本发明在知识蒸馏训练过程中,充分考虑到教师网络和学生网络对抽烟打电话分类的表现效果不一致性,使得样本在知识蒸馏中的作用发挥最大,相比于传统的知识蒸馏方法,具有更好的分类效果。
[0055]
本发明的方法除了在抽烟打电话分类应用之外,还可以向其他工业场景图片分类领域推广。基于本发明的知识蒸馏方法可应用于简单场景中的图片分类,例如:安全帽分类、行人性别分类等。本发明的方法在训练过程中根据教师网络和学生网络对样本的表现效果,增强教师网络和学生网络在训练过程中不同的作用,这种方式可以应用在大多数知识蒸馏中,具有一定的推广意义。
附图说明
[0056]
图1为传统的知识蒸馏算法的流程示意图。
[0057]
图2为本发明实施例的抽烟打电话数据集预处理过程示意图。
[0058]
图3为本发明实施例的抽烟打电话数据集和额外数据集混合的流程示意图。
[0059]
图4为本发明实施例的知识蒸馏方法的流程示意图。
[0060]
图5为本发明实施例的试验验证结果图,其中,图5(a)为训练时学生网络在训练集上的准确度变化曲线,图5(b)为训练时学生网络在测试集上的损失值变化曲线。
具体实施方式
[0061]
为了便于本领域人员更好的理解本发明,下面结合附图和具体实施例对本发明做进一步详细说明,下述仅是示例性的不限定本发明的保护范围。
[0062]
本实施例所述的一种抽烟打电话分类注意力模型的知识蒸馏方法,包括如下步骤:
[0063]
s1、获取抽烟打电话数据集并对抽烟打电话数据集进行预处理。
[0064]
(1)本实施例中,获取抽烟打电话数据集具体包括如下:
[0065]
从监控视频中获取行人样本图片,具体是在不同时间、不同天气下的实际工程场景中采集不同角度、不同年龄段的行人样本图片,行人样本图片类别包括抽烟行人、打电话行人和既不抽烟也不打电话的行人。
[0066]
(2)考虑到采集的方式为摄像头实时拍摄,拍摄过多同一行人的图片样本,造成图片之间相似度过高。因此本实施例对拍摄的图片进行预处理,去掉相似度过高的样本。如图2所示,对抽烟打电话数据集进行预处理具体包括如下:
[0067]
应理解地,获取的行人样本图片大小不一致,图片中除关键信息之外,干扰信息也比较多,例如,图像上光照影响、行人距离影响等,故本实施例在行人样本图片的基础上,将人脸作为关键局部信息中心点,从样本集中选取像素大小为128*128的区域作为关键局部信息图像,作为后续相似度计算输入图像。
[0068]
提取关键局部信息图像的hog特征,并将提取出来的hog特征向量与相似度队列中的特征向量计算余弦相似度,其中,hog特征是方向梯度直方图,其核心思想是图像局部目标的形状和表象可以被梯度或者边缘方向分布很好的描述。
[0069]
根据计算得到的余弦相似度进行判断:若相似度高于设定阈值,则该样本为无效
样本,不作为知识蒸馏的训练样本;若相似度低于设定阈值,则该样本为有效样本,作为知识蒸馏的训练样本,同时该样本加入相似度队列中作为相似度比对样本。具体地,相似度队列的容量预设为100维,余弦相识度计算时,输入样本与队列中所有样本都进行计算,只有所有的结果都《0.9,输入样本才为有效样本;有效样本进入队列时,位于队列首位的样本则出队列;将队列中其余样本向前平移一位,有效样本存放在队列的末尾。
[0070]
具体地,hog特征提取包括如下步骤:
[0071]
1)将rgb三维图像灰度化:
[0072]
i(x,y)=0.3*ir(x,y) 0.59ig(x,y) 0.11ib(x,y)
ꢀꢀ
(1)
[0073]
上式中,i(x,y)是灰度化之后的图像像素值,ir(x,y),ig(x,y),ib(x,y)分别是图像r、g、b三维通道上的像素值;
[0074]
2)对灰度化之后的图像进行gamma校正,减少图像噪音的干扰:
[0075][0076]
上式中,g(x,y)是校正之后的图像像素值;
[0077]
3)计算图像上每个像素的梯度:
[0078][0079]
上式中,t
x
(x,y),ty(x,y)分别为图像像素值水平方向和垂直方向的梯度,t(x,y)为图像每个像素上的梯度值;
[0080]
4)将图像划分成小cell,优选地,每个小cell为16*16个像素,同时统计每个cell中像素的梯度,形成cell的描述符;
[0081]
5)将若干个cell组成为一个block,优选地,4*4个cell组成一个block,block中的cell描述符组合得到该block的描述符;
[0082]
6)将所有block的描述符串联起来就形成该图像的hog特征。
[0083]
s2、将预处理后的抽烟打电话数据集输入到预先建立的教师网络中进行训练,优选地,教师网络采用resnet50,采用resnet50的目的是,resnet50将残差结构引入到传统深度学习模型中解决了梯度消失问题,使其性能在网络深度增加的同时也得到提升。
[0084]
s3、根据抽烟打电话数据集,配置额外数据集进行混合,得到混合数据集,其中,额外数据集是与抽烟打电话数据集具有互斥性的数据集。
[0085]
具体地,额外数据集为图片分类数据集,为了避免对抽烟打电话数据集造成干扰,额外数据集与抽烟打电话数据集需要做到互斥性,即图片类别需要互斥,同时额外数据集的数据量也不能太小,故本实施例选取的额外数据集为cifar-10数据集,cifar-10数据集有10个类别,每个类别有6000张图片,额外数据集共计6万张图片,其中5万张图片为训练集,1万张图片为测试集。
[0086]
将抽烟打电话数据集和额外数据集混合得到混合数据集,设抽烟打电话数据集与额外数据集的混合比例为n1:n2,如图3所示为本实施例的混合流程。
[0087]
n1和n2的总和为每批次知识蒸馏训练输入的样本数,n1与n2的总和保持不变,本实施例中每批次知识蒸馏训练输入的样本数为64,故n1 n2=64,每批次样本中某个数据集
占比越大则使得该数据集对知识蒸馏训练的引导作用越大。因此,在本实施例的知识蒸馏训练的开始阶段,设置抽烟打电话数据集n1<额外数据集n2,优选n1:n2=14:50,加大额外数据集在训练时对学生网络的引导作用,可以使学生网络预热更快。其次,随着知识蒸馏训练的迭代次数增加,将n1动态增加、同时n2动态减小,以此来加强抽烟打电话数据集在训练时对学生网络的引导作用;同时,在训练过程中,n1中抽烟行人、打电话行人和既不抽烟也不打电话的行人三种类别的数量占比也在动态变化,则数据集的每个类别的样本数也需要动态调整,在知识蒸馏训练的开始阶段,抽烟打电话数据集n1的占比为14,n1中抽烟行人、打电话行人和既不抽烟也不打电话的行人三种类别的数量比例为5:5:4。到知识蒸馏训练的最后阶段,n1>n2,具体是n1:n2=50:14,n1中抽烟行人、打电话行人和既不抽烟也不打电话的行人三种类别的数量比例为17:17:16。故抽烟行人、打电话行人和既不抽烟也不打电话的行人三种类别数量在训练时从5:5:4动态增加到17:17:16。
[0088]
s4、将混合数据集分别输入到步骤s2中训练好的教师网络和基于知识蒸馏的学生网络中,分别得到预测结果。
[0089]
本实施例中,基于知识蒸馏的学生网络优选采用vit注意力模型,采用vit注意力模型的目的是,vit作为将注意力结构进入到图片分类领域中,在许多基准测试中达到了sota(最好,最佳)性能。
[0090]
教师网络的预测结果表示为
[0091]
学生网络的预测结果为
[0092]
其中,n代表样本的类别数,t代表教师网络预测类别索引值,s代表教师网络预测类别索引值;
[0093]
因抽烟数据样本与打电话数据样本之间并不互斥,即图片可同时存在抽烟和打电话,故教师网络和学生网络都采用sigmoid激活函数,可以提高网络对多标签图片的预测能力,sigmoid激活函数的表达式如下:
[0094][0095]
上式中,z为网络(具体是教师网络或学生网络)最后一层的输入标签。
[0096]
s5、将教师网络的预测结果和学生网络的预测结果分别与真实标签进行比较,对预测出错的标签消化纠错。
[0097]
如图4所示为本实施例知识蒸馏方法的流程示意图,知识蒸馏方法过程包括:
[0098]
设教师网络预测标签最大值的下标索引为m1,学生网络预测标签最大值的下标索引为m2,真实标签最大值的下标索引为m3,若m1≠m3,则说明教师网络的预测标签出错,若m2≠m3,则说明学生网络的预测标签出错。
[0099]
对于预测出错的标签,则需要进行消化纠错,消化纠错的公式如下:
[0100][0101]
上式中,y为预测标签,α为随机数且0<α≤1。
[0102]
s6、通过知识蒸馏训练迭代,将步骤s5中得到的消化纠错结果与学生网络的预测结果进行损失计算,将计算得到的损失值进行反向传播,以此完成知识蒸馏训练过程。
[0103]
步骤s6中,损失计算的损失函数为余弦损失函数σ(x),其计算式如下:
[0104][0105]
上式中,x1为学生网络的预测标签,x2为经过消化纠错之后的预测标签;
[0106]
设学生网络的预测标签为ys,教师网络的预测标签y
t
,消化纠错之后学生网络的预测标签为消化纠错之后教师网络的预测标签为当前知识蒸馏训练迭代次数为k,ys,k代入到公式(6),求得最后的损失,则最后的损失函数计算如下:
[0107][0108]
为了保证本实验的客观性,本实施例优选将数据集互斥随机划分成训练集和测试集,划分比例为7:3,即训练集为70%和测试集为30%,分别进行知识蒸馏的训练和测试。
[0109]
表1是本实施例所述的注意力模型的知识蒸馏方法与常规的图片分类方法的对比,所选用的图片分类方法有resnet不同版本、efficientnet、vit不使用知识蒸馏的版本和deit不同版本。
[0110]
表1
[0111][0112]
如表1所示,本发明方法将vit注意力模型的精度从原来的70.3%提升到88.71%,同时比知识蒸馏的教师网络(resnet-50)87.78%高出约1个点,具有更好的鲁棒性和更高的准确率。本发明学生网络知识蒸馏准确度和损失值曲线变化分别如图5的图5(a)和图5(b)所示。
[0113]
以上仅描述了本发明的基本原理和优选实施方式,本领域人员可以根据上述描述做出许多变化和改进,这些变化和改进应该属于本发明的保护范围。
再多了解一些

本文用于创业者技术爱好者查询,仅供学习研究,如用于商业用途,请联系技术所有人。

发表评论 共有条评论
用户名: 密码:
验证码: 匿名发表

相关文献