一种残膜回收机防缠绕挑膜装置的制 一种秧草收获机用电力驱动行走机构

实时生成对抗样本的深度学习训练数据增广方法、装置、电子设备及介质与流程

2021-10-24 03:55:00 来源:中国专利 TAG:深度 训练 增广 学习 电子设备

技术特征:
1.一种实时生成对抗样本的深度学习训练数据增广方法,其特征在于,包括以下步骤:s1:将图像样本与随机噪声输入至对抗样本训练网络,所述对抗样本训练网络生成对抗样本;s2:将对抗样本输入至经由正常训练过程训练好的深度学习网络中;s3:根据深度学习网络的输出和标签计算第一损失函数,根据深度学习网络的输出和经标签混淆后的标签计算第二损失函数;s4:利用对抗优化器对第二损失函数进行梯度回传和对抗样本训练网络参数更新操作,同时对第一损失函数进行一次梯度回传操作并进行记录与累加;s5:重复步骤s1至s4,直至对抗样本训练网络生成的对抗样本为可用于优化深度学习网络的数据增广图片或深度学习网络充分学习到该图像样本中物体的特征;s6:由训练优化器将第一损失函数梯度回传时记录与累加得到的参数集合a对深度学习网络进行一次参数更新;s7:重复步骤s1至s6,当训练集中每一图像样本至少重复上述步骤一次后,判断是否达到终止训练条件;s8:达到终止训练条件,终止训练。2.根据权利要求1所述的实时生成对抗样本的深度学习训练数据增广方法,其特征在于,所述对抗样本训练网络设置于所述深度学习网络的输入层前面,为一个对抗参数层,所述对抗参数层的形状为b x h x w x c,其中b为训练的一个批次的图像数量,h与w对应输入的训练图像的高和宽,c为输入图像的通道数,在所述对抗参数层中,输入图像首先与对抗参数层参数相加,输出截取0至255的值,再经过归一化,转为浮点数后输出,此时对抗参数层的参数集合为g。3.根据权利要求2所述的实时生成对抗样本的深度学习训练数据增广方法,其特征在于,步骤s3中根据深度学习网络的输出和经标签混淆后的标签计算第二损失函数中,根据深度学习网络的训练任务不同,标签混淆的操作方法包括:对于单分类任务可以将one

hot做随机循环移位操作;对于多分类任务可以直接对multi

hot向量取反,或者全部置零;对于分割任务可以将标签掩码置零,或者做随机类别变换;对于文本识别任务可以将其中的字符变成形状相似的字符。4.根据权利要求3所述的实时生成对抗样本的深度学习训练数据增广方法,其特征在于,所述步骤s4中优化对抗器采用sgd并配以0动量。5.根据权利要求4所述的实时生成对抗样本的深度学习训练数据增广方法,其特征在于,步骤s5中直至对抗样本训练网络生成的对抗样本为可用于优化深度学习网络的数据增广图片或深度学习网络充分学习到该图像样本中物体的特征,具体为:若原始图像经过对抗样本训练网络后生成的对抗样本,经过深度学习网络以后被误识别为经标签混淆后的标签,则表示生成的对抗样本为可用于优化深度学习网络的数据增广图片;若重复了一定次数以后,第二损失函数不再发生显著变化,,则表示深度学习网络充分学习到该图像样本中物体的特征。6.根据权利要求5所述的实时生成对抗样本的深度学习训练数据增广方法,其特征在
于,所述步骤s7中训练集中每一图像样本至少重复上述步骤一次,应使简单样本出现的概率小,而难样本出现的概率大,所述简单样本指经过网络前向后,输出的结果极为接近标签,对网络训练起不到较大作用的样本,所述难样本为经过数个周期的训练后,网络仍然不能很好识别的样本。7.根据权利要求6所述的实时生成对抗样本的深度学习训练数据增广方法,其特征在于,所述步骤s7中终止训练条件,具体包括:损失函数长时间不降低;验证集准确率长时间无明显提升;设定固定的数个周期以后停止;每个周期结束时,对验证集中的每张图像做一次对抗扰动和识别,统计其中简单样本的比例,当简单样本比例达到指定条件时,终止训练;验证集中的样本对抗后输出到屏幕显示,当人工观察发现对抗样本在人看来无法识别时,终止训练。8.一种实时生成对抗样本的深度学习训练数据增广装置,其特征在于,包括:对抗样本生成模块,所述输入模块用于将图像样本与随机噪声输入至对抗样本训练网络,所述对抗样本训练网络生成对抗样本;输入模块,所述输入模块用于将对抗样本输入至经由正常训练过程训练好的深度学习网络中;损失函数计算模块,所述损失函数计算模块根据深度学习网络的输出和标签计算第一损失函数,根据深度学习网络的输出和经标签混淆后的标签计算第二损失函数;梯度回传模块,所述梯度回传模块利用对抗优化器对第二损失函数进行梯度回传和对抗样本训练网络参数更新操作,同时对第一损失函数进行一次梯度回传操作并进行记录与累加;第一循环模块,所述第一循环模块用于将同一批次的图像样本重复经上述模块处理,直至对抗样本训练网络生成的对抗样本为可用于优化深度学习网络的数据增广图片或深度学习网络充分学习到该图像样本中物体的特征;深度学习网络参数更新模块,所述深度学习网络参数更新模块由训练优化器将第一损失函数梯度回传时记录与累加得到的参数集合a对深度学习网络进行一次参数更新;第二循环模块,所述第二循环模块用于将训练集中每一图像样本至少重复经上述模块处理一次后,判断是否达到终止训练条件;终止模块,所述终止模块用于在达到终止训练条件时,终止训练。9.一种电子设备,包括:存储器、处理器及存储在所述存储器上并可在所述处理器上运行的计算机程序,其特征在于,所述处理器运行所述计算机程序时执行以实现如权利要求1至7任一项所述的方法。10.一种计算机可读介质,其特征在于,其上存储有计算机可读指令,所述计算机可读指令可被处理器执行以实现如权利要求1至7任一项所述的方法。

技术总结
本发明提供一种实时生成对抗样本的深度学习训练数据增广方法、装置、电子设备及介质,本发明使用网络梯度回传修改输入图片,生成对抗样本,并实时用生成的对抗样本来训练深度学习网络;使用两个优化器,在训练中分别对对抗样本训练网络和深度学习网络参数进行优化,以达到一次循环迭代中就可以同时优化对抗参数和网络参数的效果,加速了对抗与训练,且只需要在原有的网络结构的基础上,增加一个对抗参数层,一个损失函数,以及一个优化器即可;有效提升训练出来的深度学习模型鲁棒性,提升模型在实际应用时的精度和召回,能有效避免模型在未知数据上出现的不可解释的误判现象。未知数据上出现的不可解释的误判现象。未知数据上出现的不可解释的误判现象。


技术研发人员:邓亮 刁艺琦
受保护的技术使用者:广东杰纳医药科技有限公司
技术研发日:2021.07.12
技术公布日:2021/10/23
再多了解一些

本文用于企业家、创业者技术爱好者查询,结果仅供参考。

发表评论 共有条评论
用户名: 密码:
验证码: 匿名发表

相关文献

  • 日榜
  • 周榜
  • 月榜