一种残膜回收机防缠绕挑膜装置的制 一种秧草收获机用电力驱动行走机构

一种目标检测模型训练方法、装置、设备和存储介质与流程

2022-11-19 16:43:22 来源:中国专利 TAG:


1.本发明涉及人工智能技术领域,尤其涉及一种目标检测模型训练方法、装置、设备和存储介质。


背景技术:

2.随着神经网络在目标检测领域的迅猛发展,目标检测模型的功能越来越强大。这些目标检测模型往往采用深度神经网络结构,深度神经网络结构会消耗大量的存储以及计算资源,并且目标检测耗时也不断增加。
3.为克服上述问题,蒸馏算法应运而生。蒸馏方法是一种模型压缩方法,其主要思想是利用已经训练完成的教师网络去辅助一个资源消耗较小的学生网络的训练,以便于在目标检测过程中减少资源消耗,且达相同的目标检测效果。
4.目前在学生网络训练过程中,瓶颈层的参数调节通常仅考虑学生模型和教师模型前向传播过程中特征值的差异,学生模型学习到的信息有限,使得训练后的学生模型在目标检测时无法达到预期效果。


技术实现要素:

5.本发明提供了一种目标检测模型训练方法、装置、设备和存储介质,以解决基于蒸馏算法训练得到的目标检测模型无法达到预期效果的问题。
6.根据本发明的一方面,提供了一种目标检测模型训练方法,包括:
7.将样本图像输入至已训练第一模型和原始第二模型,获取已训练第一模型瓶颈层提取的第一特征值和检测层输出的第一检测结果,以及原始第二模型瓶颈层提取的第二特征值和检测层输出第二检测结果;
8.依据所述第一特征值、第二特征值、第一检测结果和第二检测结果,确定原始第二模型的瓶颈层蒸馏损失值;
9.依据所述第一检测结果和第二检测结果,确定原始第二模型的检测层蒸馏损失值;
10.基于所述瓶颈层蒸馏损失值和检测层蒸馏损失值,对所述原始第二模型进行训练,将训练后的原始第二模型作为目标检测模型。
11.根据本发明的另一方面,提供了一种目标检测模型训练装置,包括:
12.特征值获取模块,用于将样本图像输入至已训练第一模型和原始第二模型,获取已训练第一模型瓶颈层提取的第一特征值和检测层输出的第一检测结果,以及原始第二模型瓶颈层提取的第二特征值和检测层输出第二检测结果;
13.瓶颈层损失值确定模块,用于依据所述第一特征值、第二特征值、第一检测结果和第二检测结果,确定原始第二模型的瓶颈层蒸馏损失值;
14.检测层损失值确定模块,用于依据所述第一检测结果和第二检测结果,确定原始第二模型的检测层蒸馏损失值;
15.目标检测模型训练模块,用于基于所述瓶颈层蒸馏损失值和检测层蒸馏损失值,对所述原始第二模型进行训练,将训练后的原始第二模型作为目标检测模型。
16.根据本发明的另一方面,提供了一种电子设备,所述电子设备包括:
17.至少一个处理器;以及
18.与所述至少一个处理器通信连接的存储器;其中,
19.所述存储器存储有可被所述至少一个处理器执行的计算机程序,所述计算机程序被所述至少一个处理器执行,以使所述至少一个处理器能够执行本发明任一实施例所述的目标检测模型训练方法。
20.根据本发明的另一方面,提供了一种计算机可读存储介质,所述计算机可读存储介质存储有计算机指令,所述计算机指令用于使处理器执行时实现本发明任一实施例所述的目标检测模型训练方法。
21.本发明实施例的技术方案,首先将样本图像输入至已训练第一模型和原始第二模型,获取已训练第一模型瓶颈层提取的第一特征值和检测层输出的第一检测结果,以及原始第二模型瓶颈层提取的第二特征值和检测层输出第二检测结果,进而依据第一特征值、第二特征值、第一检测结果和第二检测结果,确定原始第二模型的瓶颈层蒸馏损失值,并依据第一检测结果和第二检测结果,确定原始第二模型的检测层蒸馏损失值,最终基于瓶颈层蒸馏损失值和检测层蒸馏损失值,对原始第二模型进行训练,将训练后的原始第二模型作为目标检测模型。在训练过程中,将检测结果反向映射至瓶颈层确定瓶颈层蒸馏损失值,可以同时考虑瓶颈层的特征值差异以及每个位置预测结果差异,优化模型训练效果。
22.应当理解,本部分所描述的内容并非旨在标识本发明的实施例的关键或重要特征,也不用于限制本发明的范围。本发明的其它特征将通过以下的说明书而变得容易理解。
附图说明
23.为了更清楚地说明本发明实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
24.图1是根据本发明实施例提供的一种目标检测模型训练方法的流程图;
25.图2是根据本发明实施例提供的一种目标检测模型训练方法的流程图;
26.图3是根据本发明实施例提供的一种目标检测模型训练方法的流程图;
27.图4是根据本发明实施例提供的一种目标检测模型训练装置的结构示意图;
28.图5是实现本发明实施例的目标检测模型训练方法的电子设备的结构示意图。
具体实施方式
29.为了使本技术领域的人员更好地理解本发明方案,下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分的实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都应当属于本发明保护的范围。
30.需要说明的是,本发明的说明书和权利要求书及上述附图中的术语“第一”、“第二”等是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。应该理解这样使用的数据在适当情况下可以互换,以便这里描述的本发明的实施例能够以除了在这里图示或描述的那些以外的顺序实施。此外,术语“包括”和“具有”以及他们的任何变形,意图在于覆盖不排他的包含,例如,包含了一系列步骤或单元的过程、方法、系统、产品或设备不必限于清楚地列出的那些步骤或单元,而是可包括没有清楚地列出的或对于这些过程、方法、产品或设备固有的其它步骤或单元。
31.图1为本发明实施例提供了一种目标检测模型训练方法的流程图,本实施例可适用于基于知识蒸馏算法进行目标检测模型训练的情况,该方法可以由目标检测模型训练装置来执行,该目标检测模型训练装置可以采用硬件和/或软件的形式实现,该目标检测模型训练装置可配置于各种通用计算设备中。如图1所示,该方法包括:
32.s110、将样本图像输入至已训练第一模型和原始第二模型,获取已训练第一模型瓶颈层提取的第一特征值和检测层输出的第一检测结果,以及原始第二模型瓶颈层提取的第二特征值和检测层输出第二检测结果。
33.基于知识蒸馏的模型训练方法中,是用结构简单的小网络模型去模仿结构复杂的大网络模型,使得小网络模型能够接近于大网络模型的性能。具体的,首先采用训练数据和预先设定的数据标签,对大网络模型进行训练,得到训练后的大网络模型。进一步的,将训练数据同时输入至训练后的大网络模型和小网络模型,由大网络模型的输出结果作为软标签,将预先设定的数据标签作为硬标签,分别构建小网络模型基于软标签的损失函数以及基于硬标签的损失函数。最终依据软标签损失函数和硬标签损失函数构成总损失函数,对小网络模型进行训练,最终将训练后的小网络模型作为目标检测模型。
34.本发明实施例中的已训练第一模型即为训练后的大网络模型,也称教师模型,原始第二模型即待训练的小网络模型,也称学生模型。其中,已训练第一模型和原始第二模型均包括主干网络层(backbone layer)、瓶颈层(neck layer)以及检测层(head layer)。其中,主干网络层用于对输入数据(例如,图像)进行特征提取,瓶颈层位于主干网络层和检测层之间,用于进行特征融合,检测层用于为输入数据中的目标进行定位和分类。
35.本发明实施例中,将样本图像输入至已训练第一模型,获取已训练第一模型在前向传播过程中在瓶颈层提取的样本图像的第一特征值,以及在检测层输出的针对样本图像的第一检测结果。同时,将样本图像输入至原始第二模型,获取原始第二模型在前向传播过程中在瓶颈层提取的样本图像的第二特征值,以及在检测层输出的针对样本图像的第二检测结果。
36.其中,样本图像可以是包含一个或多个对象的图像,第一特征值和第二特征值分别是已训练第一模型和原始第二模型的瓶颈层对样本图像进行特征提取得到的特征数据。第一检测结果和第二检测结果分别是已训练第一模型和原始第二模型的检测层输出的分类结果。示例性的,第一检测结果和第二检测结果中均可以包括针对样本图像的预测框所在位置、预测框的预测类别以及对应置信度,除此之外,还可以包括每个预测框属于每个类别的概率。
37.s120、依据第一特征值、第二特征值、第一检测结果和第二检测结果,确定原始第二模型的瓶颈层蒸馏损失值。
38.现有技术中通常采用计算第一特征值和第二特征值之间差异的方式,确定瓶颈层蒸馏损失值。这种方式仅考虑了前向传播已训练第一模型和原始第二模型瓶颈层输出的每个特征值之间的差异,并没有考虑到每个位置特征值检测结果差异,也即没有挖掘每个位置语义上的信息。
39.本发明实施例中,为了克服上述问题,通过第一特征值、第二特征值、第一检测结果和第二检测结果,共同确定原始第二模型的瓶颈层蒸馏损失值。使得瓶颈层蒸馏损失值除了包括特征值之间的差异,还包括每个位置上的检测结果差异,也就是类别差异,使原始第二模型学习到更多的语义信息。
40.具体的,首先将第一检测结果中每个第一预测框属于每个类别的概率映射至已训练第一模型的瓶颈层,得到瓶颈层每个位置属于每个类别的概率,作为第一概率映射值。同理,将第二检测结果中每个第二预测框属于每个类别的概率映射至原始第二模型的瓶颈层,得到瓶颈层每个位置属于每个类别的概率,作为第二概率映射值。进一步的,依据第一映射概率值和第二映射概率值,计算原始第二模型在瓶颈层每个位置的类别损失值,并依据第一特征值和第二特征值,计算原始第二模型在瓶颈层每个位置的特征损失值。最终,由类别损失值和特征损失值共同构成瓶颈层每个位置的综合损失,并依据每个位置的综合损失计算瓶颈层蒸馏损失值。
41.通过将第一检测结果和第二检测结果反向映射至各自对应模型的瓶颈层,可以依据瓶颈层每个位置上的类别差异和特征差异,得到瓶颈层蒸馏损失值,使原始第二模型的瓶颈层能够学习到类别语义信息,优化原始第二模型的训练效果。
42.s130、依据第一检测结果和第二检测结果,确定原始第二模型的检测层蒸馏损失值。
43.本发明实施例中,在确定瓶颈层蒸馏损失值后,进一步依据第一检测结果和第二检测结果,计算原始第二模型的检测层蒸馏损失值。具体的,首先获取第一检测结果中每个第一检测框的位置和第二检测结果中每个第二检测框的位置。进一步的,针对每个第一检测框,确定与当前第一检测框对应的第二检测框,并基于当前第一检测框的位置和对应第二检测框的位置,计算当前第一检测框与对应第二检测框的交并比,依据交并比确定第二检测框的位置损失值。进一步的,在第一检测结果中获取每个第一检测框的置信度,并在第二检测结果中获取每个第二检测框的置信度,进而依据当前第一检测框的置信度和对应第二检测框的置信度,计算当前第二检测框的置信度损失值。最终,基于位置损失值和置信度损失值,共同确定原始第二模型的检测层蒸馏损失值。
44.由检测层输出的每个预测框的位置信息,计算位置损失值,并依据每个预测框的置信度,计算置信度损失值,最终由位置损失值和置信度损失值共同确定检测层蒸馏损失值,可以综合考虑检测框的位置边界以及检测框的置信度,提高原始第二模型训练效果。
45.s140、基于瓶颈层蒸馏损失值和检测层蒸馏损失值,对原始第二模型进行训练,将训练后的原始第二模型作为目标检测模型。
46.本发明实施例中,基于上述步骤中确定的瓶颈层蒸馏损失值和检测层蒸馏损失值,共同对原始第二模型进行训练,并将训练后的原始第二模型作为目标检测模型。具体的,原始第二模型的总损值可以由瓶颈层蒸馏损失值、检测层蒸馏损失值、人工标签分类损失以及人工标签回归损失共同确定。总损失值loss
sum_student
计算公式如下:
47.loss
sum_student
=loss
class_tag
loss
bbox_tag
loss
预测框
loss
fpn
48.其中,loss
class_tag
是人工标签分类损失值,loss
bbox_tag
是人工标签回归损失值,loss
预测框
是检测层蒸馏损失值,loss
fpn
是瓶颈层蒸馏损失值。
49.最终训练得到的目标检测模型可以用于对输入图像进行目标检测:将待检测图像输入至目标检测模型,利用目标检测模型输出针对待检测图像的检测结果。其中,检测结果包括待检测图像中的预测框、预测框中包含对象的类别,以及针对该类别的置信度。
50.本发明实施例的技术方案,首先将样本图像输入至已训练第一模型和原始第二模型,获取已训练第一模型瓶颈层提取的第一特征值和检测层输出的第一检测结果,以及原始第二模型瓶颈层提取的第二特征值和检测层输出第二检测结果,进而依据第一特征值、第二特征值、第一检测结果和第二检测结果,确定原始第二模型的瓶颈层蒸馏损失值,并依据第一检测结果和第二检测结果,确定原始第二模型的检测层蒸馏损失值,最终基于瓶颈层蒸馏损失值和检测层蒸馏损失值,对原始第二模型进行训练,将训练后的原始第二模型作为目标检测模型。在训练过程中,将检测结果反向映射至瓶颈层确定瓶颈层蒸馏损失值,可以同时考虑瓶颈层的特征值差异以及每个位置预测结果差异,优化模型训练效果。
51.可选的,在完成目标检测模型训练之后,还包括:
52.获取待检测图像;
53.将待检测图像输入至目标检测模型中,基于目标检测模型的输出确定待检测图像中的至少一个目标;目标检测模型基于本发明任一实施例所述的目标检测模型训练方法训练得到。
54.本可选的实施例中,目标检测模型由本发明任一实施例提供的目标检测模型训练方法训练得到。待检测图像可以是包含至少一个目标的图像,将待检测图像输入至训练好的目标检测模型,目标检测模型可以输出针对待检测图像中指定目标的检测结果。其中,检测结果中可以包括目标所在预测框、目标所述的类型信息以及置信度等。
55.图2为本发明实施例提供的一种目标检测模型训练方法的流程图,本实施例在上述实施例的基础上进一步细化,提供了依据第一特征值、第二特征值、第一检测结果和第二检测结果,确定原始第二模型的瓶颈层蒸馏损失值的具体步骤。如图2所示,该方法包括:
56.s210、将样本图像输入至已训练第一模型和原始第二模型,获取已训练第一模型瓶颈层提取的第一特征值和检测层输出的第一检测结果,以及原始第二模型瓶颈层提取的第二特征值和检测层输出第二检测结果。
57.s220、将第一检测结果中第一预测框的类别概率映射至已训练第一模型的瓶颈层,得到第一概率映射值,将第二检测结果中第二预测框的类别概率映射至原始第二模型的瓶颈层,得到第二概率映射值。
58.本发明实施例中,将第一检测结果中第一预测框属于每个类别的概率映射至已训练第一模型的瓶颈层,得到瓶颈层每个位置属于每个类别的概率作为第一概率映射值。同理,将第二检测结果中第二预测框属于每个类别的概率映射至原始第二模型的瓶颈层,得到瓶颈层每个位置属于每个类别的概率作为第二映射概率值。
59.在一个具体的例子中,已训练第一模型的瓶颈层的宽度为weight,高度为height。将第一检测结果中每个第一预测框属于每个类别的概率映射至已训练第一模型的瓶颈层,可以得到瓶颈层每个位置(i,j)属于每个类别的概率,例如,总类别数量为10,则瓶颈层每
个位置对应10个类别概率。
60.s230、依据第一概率映射值和第二概率映射值,确定原始第二模型在瓶颈层每个位置的类别损失值。
61.本发明实施例中,在将第一检测结果和第二检测结果映射至瓶颈层得到第一概率映射值和第二概率映射值之后,依据第一概率映射值和第二概率映射值确定原始第二模型在瓶颈层每个位置处的类别损失值。具体的,可以基于第一映射概率值和第二映射概率值,采用kl散度(kullback-leible divergence)计算方式,计算原始第二模型在瓶颈层每个位置的类别损失值。针对瓶颈层的位置(i,j)的类别损失值的计算公式如下:
[0062][0063]
其中,i是瓶颈层的横轴坐标,j是瓶颈层的纵轴坐标,c是类别总数量,是已训练第一模型瓶颈层的位置(i,j)属于第z类别的概率,是原始第二模型瓶颈层的位置(i,j)属于第z类别的概率。
[0064]
s240、依据第一特征值和第二特征值之间的特征差异,确定原始第二模型在瓶颈层每个位置的特征损失值。
[0065]
本发明实施例中,计算已训练第一模型瓶颈层输出的第一特征值和原始第二模型瓶颈层输出的第二特征值之间的差异,并基于该差异确定原始第二模型在瓶颈层每个位置的特征损失值。
[0066]
示例性的,已训练第一模型瓶颈层位置(i,j)对应的第一特征值是原始第二模型瓶颈层位置(i,j)对应的第二特征值是则该位置处的特征损失值为
[0067]
s250、基于类别损失值和特征损失值,确定原始第二模型的瓶颈层蒸馏损失值。
[0068]
本发明实施例中,在分别获取到原始第二模型在瓶颈层每个位置的类别损失值和特征损失值之后,基于类别损失值和特征损失值,确定原始第二模型的瓶颈层蒸馏损失。具体的,可以将每个位置处的类别损失值和特征损失值相加,得到每个位置的综合损失值,最终将全部位置的综合损失值相加,得到瓶颈层蒸馏损失。具体的,还可以是将瓶颈层每个位置处的类别损失值和特征损失值相加,得到每个位置的综合损失值,进而依据标定掩膜矩阵,将瓶颈层位置分为前景位置和背景位置,分别计算前景位置的损失均值和背景位置的损失均值,最终将前景位置的损失均值和背景位置的损失均值求和,得到瓶颈层蒸馏损失。其中,前景位置均值可以依据前景位置的综合损失值、前景位置数量以及前景位置对应的难例权重计算得到,难例权重可以依据该前景位置处原始第二模型输出的置信度确定,置信度越高,难例权重越低。
[0069]
可选的,基于类别损失值和特征损失值,确定原始第二模型的瓶颈层蒸馏损失值,包括:
[0070]
基于类别损失值和特征损失值,确定原始第二模型在瓶颈层每个位置的综合损失值;
[0071]
依据综合损失值以及标定掩膜矩阵,确定瓶颈层的前景损失均值和背景损失均值;标定掩膜矩阵依据预先标定的前景标签和背景标签确定;
[0072]
由前景损失均值和背景损失均值,构成原始第二模型的瓶颈层蒸馏损失值。
[0073]
本可选的实施例中,提供一种基于类别损失值和特征损失值,确定原始第二模型的瓶颈层蒸馏损失值的具体方式:首先,基于瓶颈层每个位置的类别损失值和特征损失值,确定原始第二模型在瓶颈层每个位置的综合损失值,具体为,可以对每个位置的类别损失值和特征损失值求和,将求和结果作为该位置的综合损失值。由于瓶颈层中的位置包含前景所在位置和背景所在位置,为了降低背景参数的影响,可以在计算综合损失值时,对前景位置和背景位置采用不同的计算方式:当前位置是前景位置时,直接将该位置的类别损失值和特征损失值求和,得到该位置的综合损失值;当前位置是背景位置时,可以通过超参对该位置的特征损失值进行调整,将调整后的特征损失值和背景损失值求和,得到该位置的综合损失值。针对瓶颈层位置(i,j)的综合损失值feature_loss
i,j
计算公式如下:
[0074][0075][0076]
其中,是原始第二模型在位置(i,j)的类别损失值,m_teacher
i,j
是基于已训练第一模型的第一预测结果确定的掩膜矩阵,是已训练第一模型在(i,j)位置的第一特征值,是原始第二模型在(i,j)位置的第二特征值,是背景权重。
[0077]
在计算得到瓶颈层每个位置的综合损失值之后,依据标定掩膜矩阵在综合损失值中确定属于前景位置的前景损失均值,以及属于背景位置的背景损失均值。最终由前景损失均值和背景损失均值求和,等得到原始第二模型的瓶颈层蒸馏损失值。其中,标定掩膜矩阵是依据人工标定的前景标签和背景标签确定,标定掩膜矩阵中属于前景的位置数值为1,属于背景的位置数值为0。
[0078]
在计算瓶颈层蒸馏损失值时,综合考虑前景位置以及背景位置的综合损失值,一方面,可以平衡前景与背景蒸馏损失对原始第二模型的影响,另一方面,可以使得所有背景信息参加蒸馏损失计算,使得原始第二模型学习到更多背景信息,优化模型训练效果。
[0079]
可选的,依据综合损失值以及标定掩膜矩阵,确定瓶颈层的前景损失均值,包括:
[0080]
依据标定掩膜矩阵,在瓶颈层每个位置的综合损失值中,确定每个前景位置的综合损失值;
[0081]
将第二检测结果中第二预测框的置信度映射至原始第二模型的瓶颈层,得到第二置信度映射值,并依据所述第二置信度映射值,确定瓶颈层中每个前景位置的难例权重;第二置信度映射值越高的前景位置,对应难例权重越小;
[0082]
依据难例权重,计算瓶颈层中前景位置对应综合损失值的加权平均值,并将加权平均值作为瓶颈层的前景损失均值。
[0083]
本可选的实施例中,提供了一种依据综合损失值以及标定掩膜矩阵,确定瓶颈层的前景损失均值的具体方式:首先,依据标定掩膜矩阵,在瓶颈层每个位置的综合损失值中,确定每个前景位置的综合损失值。进一步的,将第二检测结果中第二预测框的置信度映射至原始第二模型的瓶颈层,得到第二置信度映射值,并基于第二置信度映射值,确定瓶颈层中每个前景位置的难例权重,其中,第二置信度映射值越高的前景位置,对应难例权重越
小。最终依据难例权重,将瓶颈层中每个前景位置的综合损失值进行加权求和,并对加权求和的结果求取平均值,得到瓶颈层的前景损失均值。通过第二检测结果中第二预测框的置信度来确定瓶颈层每个前景位置的难例权重,可以在瓶颈层蒸馏学习时,通过难例和简单特征值的区分,使得原始第二模型能够提高对于难例的学习权重,针对自身缺陷进行针对性学习。
[0084]
可选的,背景损失均值可以是直接计算瓶颈层中全部背景位置的综合损失值的平均值。
[0085]
最终前景损失均值和背景损失均值求和,得到原始第二模型的瓶颈层蒸馏损失值loss
fpn
,具体计算公式如下:
[0086][0087][0088]
其中,num
背景
是原始第二模型瓶颈层背景的数量,num
前景
是原始滴入模型瓶颈层前景的数量,weight是原始第二模型瓶颈层的宽度,height是原始第二模型瓶颈层的高度,m
i,j
是基于人工标签确定的掩膜矩阵,是原始第二模型在(i,j)位置的置信度,feature_loss
i,j
是原始第二模型在(i,j)位置的综合损失值。
[0089]
s260、依据第一检测结果和第二检测结果,确定原始第二模型的检测层蒸馏损失值。
[0090]
s270、基于瓶颈层蒸馏损失值和检测层蒸馏损失值,对原始第二模型进行训练,将训练后的原始第二模型作为目标检测模型。
[0091]
本发明实施例的技术方案,将检测结果映射到瓶颈层,进而依据瓶颈层中每个位置对应特征值之间的差异,以及每个位置对应检测结果之间的差异来确定瓶颈层蒸馏损失值,最终基于瓶颈层蒸馏损失值和检测层蒸馏损失值对原始第二模型进行训练,可以使得原始第二模型学习到除特征值之外的更多语义信息,提升模型训练效果。
[0092]
图3为本发明实施例提供的一种目标检测模型训练方法的流程图,本实施例在上述实施例的基础上进一步细化,提供了依据第一检测结果和第二检测结果,确定原始第二模型的检测层蒸馏损失值的具体步骤。如图3所示,该方法包括:
[0093]
s310、将样本图像输入至已训练第一模型和原始第二模型,获取已训练第一模型瓶颈层提取的第一特征值和检测层输出的第一检测结果,以及原始第二模型瓶颈层提取的第二特征值和检测层输出第二检测结果。
[0094]
s320、依据第一特征值、第二特征值、第一检测结果和第二检测结果,确定原始第二模型的瓶颈层蒸馏损失值。
[0095]
s330、依据第一预测结果中第一预测框的位置和第二预测结果中的第二预测框的位置,确定每个第二预测框的位置损失值。
[0096]
交并比是进行目标检测时常用的检测效果衡量指标。本发明实施例中,依据第一预测结果中第一预测框的位置和第二预测结果中第二预测框的位置,确定每个第二预测框
的位置损失值。具体的,首先在第一检测结果包含的至少一个预测框中,分别确定与第二预测结果中每个第二预测框对应的第一预测框。进一步的,计算每个第二检测框以及与其对应的第一检测框的交并比,最终依据交并比确定每个第二检测框的位置损失值。其中,交并比越大,对应的位置损失值越小。第二检测结果中第i个第二预测框的位置损失值loss_bboxi的计算公式如下:
[0097][0098]
其中,是第i个第二预测框,是与第i个第二预测框对应的第一预测框。
[0099]
s340、依据第一预测结果中第一预测框的置信度,以及第二预测结果中第二预测框的置信度,确定每个第二预测框的置信度损失值。
[0100]
本发明实施例中,除了计算原始第二模型输出的每个第二预测框的位置损失值,还需要进一步依据第一预测结果中第一预测框的置信度,以及第二预测结果中第二预测框的置信度,计算每个第二预测框的置信度损失值。
[0101]
可选的,依据第一预测结果中第一预测框的置信度,以及第二预测结果中第二预测框的置信度,确定每个第二预测框的置信度损失值,包括:
[0102]
依据第一预测结果中第一预测框的置信度,确定第一预测框的置信度损失权重;第一预测框的置信度越高,置信度损失权重越小;
[0103]
在第二预测框属于前景的情况下,依据前景超参、置信度损失权重、第二预测框的置信度以及对应第一预测框的置信度,确定第二预测框的置信度损失值;
[0104]
在第二预测框属于背景的情况下,依据背景超参、置信度损失权重、第二预测框的置信度以及对应第一预测框的置信度,确定第二预测框的置信度损失值。
[0105]
本可选的实施例中,提供一种依据第一预测结果中第一预测框的置信度,以及第二预测结果中第二预测框的置信度,确定每个第二预测框的置信度损失值的具体方式:首先,依据第一预测结果中第一预测框的置信度,确定第一预测框的置信度损失权重,其中,第一预测框的置信度越高,对应的置信度损失权重越小。进一步的,基于置信度损失权重计算每个第二预测框的置信度损失值:在第二预测框属于前景的情况下,依据预先设定的前景超参、当前第二预测框对应的第一预测框的置信度损失权重、第二预测框的置信度以及对应第一预测框的置信度,确定第二预测框的置信度损失值;在第二预测框属于背景的情况下,依据预先设定的背景超参、当前第二预测框对应的第一预测框的置信度损失权重、第二预测框的置信度以及对应第一预测框的置信度,确定第二预测框的置信度损失值。
[0106]
具体的,针对第二预测框i的置信度损失值loss_pi的计算公式如下:
[0107][0108]
其中,是第i个第二预测框的置信度,是与第i个第二预测框对应的第一预测框的置信度,和r是超参,用于调节前景和背景的权重。用于调试难例样本,降低简单样本权重,提高困难样本权重,使得原始第二模型能针对自身缺陷进行优化。
[0109]
s350、基于位置损失值和置信度损失值,确定原始第二模型的检测层蒸馏损失值。
[0110]
本发明实施例中,在计算得到原始第二模型的检测层每个第二预测框的位置损失值和置信度损失值之后,可以对位置损失值和置信度损失值求和得到每个第二预测框的综
合损失值,最终将全部第二预测框的综合损失值求和得到原始第二模型的检测层蒸馏损失值。
[0111]
还可以是计算与每个第二预测框对应的第一预测框,与对应标注框的交并比,来确定每个第一预测框的权重,其中,第一预测框与标注框的交并比越大,该第一预测框的权重越高,同时,第一预测框的置信度越高,该第一预测框的权重越高。最终由第一预测框的权重来调节与其对应的第二预测框的位置损失值,避免在已训练第一模型输出的预测框置信度较低的情况下,原始第二模型学习到已训练第一模型的错误黑盒知识。
[0112]
可选的,基于位置损失值和置信度损失值,确定原始第二模型的检测层蒸馏损失值,包括:
[0113]
基于第一检测结果中第一预测框位置、对应标注框的位置以及第一预测框的置信度,确定与每个第二预测框匹配的第一预测框的位置损失权重;
[0114]
通过位置损失权重对位置损失值进行处理,依据置信度损失值和处理后的位置损失值,确定原始第二模型的检测层蒸馏损失值。
[0115]
本可选的实施例中,提供一种基于位置损失值和置信度损失值,确定原始第二模型的检测层蒸馏损失值的具体方式:首先,基于第一检测结果中第一预测框的位置、对应标注框(人为标注框)以及第一预测框的置信度,确定与每个第二预测框匹配的第一预测框的位置损失权重。每个第二预测框匹配的第一预测框的位置损失权重weighti计算公式如下:
[0116][0117]
其中,是与第i个第二预测框对应的第一预测框的置信度,与第i个第二预测框对应的第一预测框,是与第i个第二预测框对应的标注框。
[0118]
在计算得到每个第二预测框对应的第一预测框的位置损失权重后,通过位置损失权重对相应的第二预测框的位置损失值进行处理,依据置信度损失值和处理后的位置损失值,确定原始第二模型的检测层蒸馏损失值。具体的,可以将每个第二预测框的置信度损失值和处理后的位置损失值求和,得到每个第二预测框的综合损失值,最终计算全部第二预测框的综合损失值之和,作为原始第二模型的检测层蒸馏损失值。检测层蒸馏损失值loss
预测框
的计算公式如下:
[0119][0120]
其中,s_sum是原始第二模型输出预测框的数量,weighti是第i个第二预测框对应的第一预测框的位置损失权重,loss_bboxi是第i个第二预测框的位置损失值,loss_pi是第i个第二预测框的置信度损失值。
[0121]
s360、基于瓶颈层蒸馏损失值和检测层蒸馏损失值,对原始第二模型进行训练,将训练后的原始第二模型作为目标检测模型。
[0122]
本发明实施例的技术方案,在确定检测层蒸馏损失值时,一方面依据预测结果中预测框的位置信息来确定位置损失值,另一方面依据检测框的置信度损失值确定置信度损失值,最终由置信度损失值和位置损失值共同确定检测层蒸馏损失值,可以优化模型训练效果。
[0123]
图4为本发明实施例提供的一种目标检测模型训练装置的结构示意图。如图4所
示,该装置包括:
[0124]
特征值获取模块410,用于将样本图像输入至已训练第一模型和原始第二模型,获取已训练第一模型瓶颈层提取的第一特征值和检测层输出的第一检测结果,以及原始第二模型瓶颈层提取的第二特征值和检测层输出第二检测结果;
[0125]
瓶颈层损失值确定模块420,用于依据所述第一特征值、第二特征值、第一检测结果和第二检测结果,确定原始第二模型的瓶颈层蒸馏损失值;
[0126]
检测层损失值确定模块430,用于依据所述第一检测结果和第二检测结果,确定原始第二模型的检测层蒸馏损失值;
[0127]
目标检测模型训练模块440,用于基于所述瓶颈层蒸馏损失值和检测层蒸馏损失值,对所述原始第二模型进行训练,将训练后的原始第二模型作为目标检测模型。
[0128]
本发明实施例的技术方案,首先将样本图像输入至已训练第一模型和原始第二模型,获取已训练第一模型瓶颈层提取的第一特征值和检测层输出的第一检测结果,以及原始第二模型瓶颈层提取的第二特征值和检测层输出第二检测结果,进而依据第一特征值、第二特征值、第一检测结果和第二检测结果,确定原始第二模型的瓶颈层蒸馏损失值,并依据第一检测结果和第二检测结果,确定原始第二模型的检测层蒸馏损失值,最终基于瓶颈层蒸馏损失值和检测层蒸馏损失值,对原始第二模型进行训练,将训练后的原始第二模型作为目标检测模型。在训练过程中,将检测结果反向映射至瓶颈层确定瓶颈层蒸馏损失值,可以同时考虑瓶颈层的特征值差异以及每个位置预测结果差异,优化模型训练效果。
[0129]
可选的,瓶颈层损失值确定模块420,包括:
[0130]
反向映射值确定单元,用于将所述第一检测结果中第一预测框的类别概率映射至已训练第一模型的瓶颈层,得到第一概率映射值,将所述第二检测结果中第二预测框的类别概率映射至原始第二模型的瓶颈层,得到第二概率映射值;
[0131]
类别损失值确定单元,用于依据所述第一概率映射值和第二概率映射值,确定原始第二模型在瓶颈层每个位置的类别损失值;
[0132]
特征损失值确定单元,用于依据所述第一特征值和所述第二特征值之间的特征差异,确定原始第二模型在瓶颈层每个位置的特征损失值;
[0133]
瓶颈层损失确定单元,用于基于所述类别损失值和特征损失值,确定原始第二模型的瓶颈层蒸馏损失值。
[0134]
可选的,瓶颈层损失确定单元,包括:
[0135]
综合损失值确定子单元,用于基于所述类别损失值和特征损失值,确定原始第二模型在瓶颈层每个位置的综合损失值;
[0136]
前景损失确定子单元,用于依据所述综合损失值以及标定掩膜矩阵,确定瓶颈层的前景损失均值和背景损失均值;所述标定掩膜矩阵依据预先标定的前景标签和背景标签确定;
[0137]
蒸馏层损失值确定子单元,用于由所述前景损失均值和背景损失均值,构成原始第二模型的瓶颈层蒸馏损失值。
[0138]
可选的,前景损失确定子单元,具体用于;
[0139]
依据所述标定掩膜矩阵,在所述瓶颈层每个位置的综合损失值中,确定每个前景位置的综合损失值;
[0140]
将所述第二检测结果中第二预测框的置信度映射至原始第二模型的瓶颈层,得到第二置信度映射值,并依据所述第二置信度映射值,确定瓶颈层中每个前景位置的难例权重;第二置信度映射值越高的前景位置,对应难例权重越小;
[0141]
依据所述难例权重,计算瓶颈层中前景位置对应综合损失值的加权平均值,并将所述加权平均值作为瓶颈层的前景损失均值。
[0142]
可选的,检测层损失值确定模块430,包括:
[0143]
位置损失值确定单元,用于依据第一预测结果中第一预测框的位置和第二预测结果中的第二预测框的位置,确定每个第二预测框的位置损失值;
[0144]
置信度损失值确定单元,用于依据第一预测结果中第一预测框的置信度,以及第二预测结果中第二预测框的置信度,确定每个第二预测框的置信度损失值;
[0145]
检测层损失值确定单元,用于基于所述位置损失值和置信度损失值,确定原始第二模型的检测层蒸馏损失值。
[0146]
可选的,检测层损失值确定单元,具体用于:
[0147]
基于第一检测结果中第一预测框位置、对应标注框的位置以及第一预测框的置信度,确定与每个第二预测框匹配的第一预测框的位置损失权重;
[0148]
通过所述位置损失权重对所述位置损失值进行处理,依据置信度损失值和处理后的位置损失值,确定原始第二模型的检测层蒸馏损失值。
[0149]
可选的,置信度损失值确定单元,具体用于:
[0150]
依据第一预测结果中第一预测框的置信度,确定第一预测框的置信度损失权重;所述第一预测框的置信度越高,置信度损失权重越小;
[0151]
在第二预测框属于前景的情况下,依据前景超参、置信度损失权重、第二预测框的置信度以及对应第一预测框的置信度,确定所述第二预测框的置信度损失值;
[0152]
在第二预测框属于背景的情况下,依据背景超参、置信度损失权重、第二预测框的置信度以及对应第一预测框的置信度,确定所述第二预测框的置信度损失值。
[0153]
本发明实施例所提供的目标检测模型训练装置可执行本发明任意实施例所提供的目标检测模型训练方法,具备执行方法相应的功能模块和有益效果。
[0154]
图5示出了可以用来实施本发明的实施例的电子设备10的结构示意图。电子设备旨在表示各种形式的数字计算机,诸如,膝上型计算机、台式计算机、工作台、个人数字助理、服务器、刀片式服务器、大型计算机、和其它适合的计算机。电子设备还可以表示各种形式的移动装置,诸如,个人数字处理、蜂窝电话、智能电话、可穿戴设备(如头盔、眼镜、手表等)和其它类似的计算装置。本文所示的部件、它们的连接和关系、以及它们的功能仅仅作为示例,并且不意在限制本文中描述的和/或者要求的本发明的实现。
[0155]
如图5所示,电子设备10包括至少一个处理器11,以及与至少一个处理器11通信连接的存储器,如只读存储器(rom)12、随机访问存储器(ram)13等,其中,存储器存储有可被至少一个处理器执行的计算机程序,处理器11可以根据存储在只读存储器(rom)12中的计算机程序或者从存储单元18加载到随机访问存储器(ram)13中的计算机程序,来执行各种适当的动作和处理。在ram 13中,还可存储电子设备10操作所需的各种程序和数据。处理器11、rom 12以及ram 13通过总线14彼此相连。输入/输出(i/o)接口15也连接至总线14。
[0156]
电子设备10中的多个部件连接至i/o接口15,包括:输入单元16,例如键盘、鼠标
等;输出单元17,例如各种类型的显示器、扬声器等;存储单元18,例如磁盘、光盘等;以及通信单元19,例如网卡、调制解调器、无线通信收发机等。通信单元19允许电子设备10通过诸如因特网的计算机网络和/或各种电信网络与其他设备交换信息/数据。
[0157]
处理器11可以是各种具有处理和计算能力的通用和/或专用处理组件。处理器11的一些示例包括但不限于中央处理单元(cpu)、图形处理单元(gpu)、各种专用的人工智能(ai)计算芯片、各种运行机器学习模型算法的处理器、数字信号处理器(dsp)、以及任何适当的处理器、控制器、微控制器等。处理器11执行上文所描述的各个方法和处理,例如目标检测模型训练方法。
[0158]
在一些实施例中,目标检测模型训练方法可被实现为计算机程序,其被有形地包含于计算机可读存储介质,例如存储单元18。在一些实施例中,计算机程序的部分或者全部可以经由rom 12和/或通信单元19而被载入和/或安装到电子设备10上。当计算机程序加载到ram 13并由处理器11执行时,可以执行上文描述的目标检测模型训练方法的一个或多个步骤。备选地,在其他实施例中,处理器11可以通过其他任何适当的方式(例如,借助于固件)而被配置为执行目标检测模型训练方法。
[0159]
本文中以上描述的系统和技术的各种实施方式可以在数字电子电路系统、集成电路系统、现场可编程门阵列(fpga)、专用集成电路(asic)、专用标准产品(assp)、芯片上系统的系统(soc)、复杂可编程逻辑设备(cpld)、计算机硬件、固件、软件、和/或它们的组合中实现。这些各种实施方式可以包括:实施在一个或者多个计算机程序中,该一个或者多个计算机程序可在包括至少一个可编程处理器的可编程系统上执行和/或解释,该可编程处理器可以是专用或者通用可编程处理器,可以从存储系统、至少一个输入装置、和至少一个输出装置接收数据和指令,并且将数据和指令传输至该存储系统、该至少一个输入装置、和该至少一个输出装置。
[0160]
用于实施本发明的方法的计算机程序可以采用一个或多个编程语言的任何组合来编写。这些计算机程序可以提供给通用计算机、专用计算机或其他可编程数据处理装置的处理器,使得计算机程序当由处理器执行时使流程图和/或框图中所规定的功能/操作被实施。计算机程序可以完全在机器上执行、部分地在机器上执行,作为独立软件包部分地在机器上执行且部分地在远程机器上执行或完全在远程机器或服务器上执行。
[0161]
在本发明的上下文中,计算机可读存储介质可以是有形的介质,其可以包含或存储以供指令执行系统、装置或设备使用或与指令执行系统、装置或设备结合地使用的计算机程序。计算机可读存储介质可以包括但不限于电子的、磁性的、光学的、电磁的、红外的、或半导体系统、装置或设备,或者上述内容的任何合适组合。备选地,计算机可读存储介质可以是机器可读信号介质。机器可读存储介质的更具体示例会包括基于一个或多个线的电气连接、便携式计算机盘、硬盘、随机存取存储器(ram)、只读存储器(rom)、可擦除可编程只读存储器(eprom或快闪存储器)、光纤、便捷式紧凑盘只读存储器(cd-rom)、光学储存设备、磁储存设备、或上述内容的任何合适组合。
[0162]
为了提供与用户的交互,可以在电子设备上实施此处描述的系统和技术,该电子设备具有:用于向用户显示信息的显示装置(例如,crt(阴极射线管)或者lcd(液晶显示器)监视器);以及键盘和指向装置(例如,鼠标或者轨迹球),用户可以通过该键盘和该指向装置来将输入提供给电子设备。其它种类的装置还可以用于提供与用户的交互;例如,提供给
用户的反馈可以是任何形式的传感反馈(例如,视觉反馈、听觉反馈、或者触觉反馈);并且可以用任何形式(包括声输入、语音输入或者、触觉输入)来接收来自用户的输入。
[0163]
可以将此处描述的系统和技术实施在包括后台部件的计算系统(例如,作为数据服务器)、或者包括中间件部件的计算系统(例如,应用服务器)、或者包括前端部件的计算系统(例如,具有图形用户界面或者网络浏览器的用户计算机,用户可以通过该图形用户界面或者该网络浏览器来与此处描述的系统和技术的实施方式交互)、或者包括这种后台部件、中间件部件、或者前端部件的任何组合的计算系统中。可以通过任何形式或者介质的数字数据通信(例如,通信网络)来将系统的部件相互连接。通信网络的示例包括:局域网(lan)、广域网(wan)、区块链网络和互联网。
[0164]
计算系统可以包括客户端和服务器。客户端和服务器一般远离彼此并且通常通过通信网络进行交互。通过在相应的计算机上运行并且彼此具有客户端-服务器关系的计算机程序来产生客户端和服务器的关系。服务器可以是云服务器,又称为云计算服务器或云主机,是云计算服务体系中的一项主机产品,以解决了传统物理主机与vps服务中,存在的管理难度大,业务扩展性弱的缺陷。
[0165]
应该理解,可以使用上面所示的各种形式的流程,重新排序、增加或删除步骤。例如,本发明中记载的各步骤可以并行地执行也可以顺序地执行也可以不同的次序执行,只要能够实现本发明的技术方案所期望的结果,本文在此不进行限制。
[0166]
上述具体实施方式,并不构成对本发明保护范围的限制。本领域技术人员应该明白的是,根据设计要求和其他因素,可以进行各种修改、组合、子组合和替代。任何在本发明的精神和原则之内所作的修改、等同替换和改进等,均应包含在本发明保护范围之内。
再多了解一些

本文用于创业者技术爱好者查询,仅供学习研究,如用于商业用途,请联系技术所有人。

发表评论 共有条评论
用户名: 密码:
验证码: 匿名发表

相关文献