一种残膜回收机防缠绕挑膜装置的制 一种秧草收获机用电力驱动行走机构

域适应图像分类网络的训练方法、系统、设备及存储介质

2022-04-14 01:30:03 来源:中国专利 TAG:


1.本发明涉及图像分类技术领域,尤其涉及一种域适应图像分类网络的训练方法、系统、设备及存储介质。


背景技术:

2.近年来,以深度神经网络为基础的全监督学习策略已在图像分类领域取得显著成就。这类全监督学习算法需要训练数据与测试数据分布相同。然而在实际应用中,训练(源域)数据与测试(目标域)数据往往存在差异。域适应方法旨在将源域知识迁移到目标域,以解决上述问题。
3.一般而言,分类模型需要在尽可能的聚类相同的语义特征的同时使其在特征空间中分布于分类权重附近。对于无监督及半监督域适应任务,由于目标数据缺乏监督信息,目标域特征难以按照语义聚类。而基于infonce损失的实例对比学习方法能够有效的在语义层面上聚集相似的特征,同时具有良好的可迁移性。但是,直接应用基于infonce损失的实例对比学习方法的增益非常有限,在一些具有强一致性约束的模型上甚至无明显增益。究其原因,因为之前的对比学习方法,普遍使用分类器之前的特征计算对比损失,分类权重没有参与到优化过程中,这类对比学习方法无法使特征分布于分类权重周围。因此,训练效果不佳,影响分类准确率。
4.在公开号为cn113673555a的中国专利申请《一种基于记忆体的无监督域适应图片分类方法》中,使用神经网络模型提取数据集中图片的特征,使用聚类算法构建每一类特征的类内结构,并将其存储到对应域的辅助记忆体中,以特征分布相似性作为约束条件迭代地训练模型。在公开号为cn113610105a的中国专利申请《基于动态加权学习和元学习的无监督域适应图像分类方法》中,将样本加权后构造动态平衡因子,分别计算源域和目标域数据分布对齐程度和可判别性并进行归一化处理,再使用元学习计算域对齐损失更新网络参数和分类进行模型优化。在公开号为cn113469273a的中国专利申请《基于双向生成及中间域对齐的无监督域适应图像分类方法》中,使用双向生成网络生成伪目标域图像和伪源域图像,生成过程中由任务网络提供监督信息提高生成图像的质量,完成训练后将伪图像输入分类网络,在此过程中不断缩减伪源域与源域图像的分布差异,从而提高分类网络的准确性。然而上述方法往往是通过添加额外的模块来提高分类模型的准确率,因此模型的参数量较大,训练时间较长,训练效率也受到一定限制。


技术实现要素:

5.本发明的目的是提供一种域适应图像分类网络的训练方法、系统、设备及存储介质,可以提升训练效率,提高图像分类的准确率。
6.本发明的目的是通过以下技术方案实现的:一种域适应图像分类网络的训练方法,包括:获取源域图像集合,并根据训练方式获取目标域图像集合;对目标域图像集合中
每一未标注的目标域图像进行两种不同的图像变换,获得的第一变换图像与第二变换图像构成一个目标域图像对,由所述源域图像集合、目标域图像集合以及所有目标域图像对构成训练数据集;将所述训练数据集输入至所述域适应图像分类网络;根据训练方式以及设定的损失类别,利用所述域适应图像分类网络特征提取器的输出、分类器的输出与softmax层的输出中的一种或多种计算相应的类别损失,构成所述域适应图像分类网络的基线损失;对于每一个目标域图像对,提取出所述域适应图像分类网络中softmax层的输出,构成一对概率向量;将每一对概率向量中第一变换图像对应的概率向量作为第一查询向量,将与第一查询向量不属于同一对概率向量中的其他所有概率向量作为相应第一查询向量的负样本,以及将每一对概率向量中第二变换图像对应的概率向量作为第二查询向量,将与第二查询向量不属于同一对概率向量中的其他所有概率向量作为相应第二查询向量的负样本;利用所有第一查询向量及对应的负样本,以及所有第二查询向量及对应的负样本计算总的概率对比损失;联合所述基线损失与总的概率对比损失训练所述域适应图像分类网络。
7.一种域适应图像分类网络的训练系统,包括:训练数据集构造单元,用于获取源域图像集合,并根据训练方式获取目标域图像集合;对目标域图像集合中每一未标注的目标域图像进行两种不同的图像变换,获得的第一变换图像与第二变换图像构成一个目标域图像对,由所述源域图像集合、目标域图像集合以及所有目标域图像对构成训练数据集;训练数据集输入单元,用于将所述训练数据集输入至所述域适应图像分类网络;基线损失计算单元,用于根据训练方式以及设定的损失类别,利用所述域适应图像分类网络特征提取器的输出、分类器的输出与softmax层的输出中的一种或多种计算相应的类别损失,构成所述域适应图像分类网络的基线损失;总的概率对比损失计算单元,对于每一个目标域图像对,提取出所述域适应图像分类网络中softmax层的输出,构成一对概率向量;将每一对概率向量中第一变换图像对应的概率向量作为第一查询向量,将与第一查询向量不属于同一对概率向量中的其他所有概率向量作为相应第一查询向量的负样本,以及将每一对概率向量中第二变换图像对应的概率向量作为第二查询向量,将与第二查询向量不属于同一对概率向量中的其他所有概率向量作为相应第二查询向量的负样本;利用所有第一查询向量及对应的负样本,以及所有第二查询向量及对应的负样本计算总的概率对比损失;训练单元,用于联合所述基线损失与总的概率对比损失训练所述域适应图像分类网络。
8.一种处理设备,包括:一个或多个处理器;存储器,用于存储一个或多个程序;其中,当所述一个或多个程序被所述一个或多个处理器执行时,使得所述一个或多个处理器实现前述的方法。
9.一种可读存储介质,存储有计算机程序,其特征在于,当计算机程序被处理器执行时实现前述的方法。
10.由上述本发明提供的技术方案可以看出,在原有训练方式的基础上,引入对比学
习将相同语义的特征聚类,解决域适应图像分类任务在目标域标签不足的问题;本发明将特征对比学习改进为概率对比学习,通过在概率空间进行对比学习,减小聚类后的同语义特征与类权重之间的距离,提高分类的准确率;并且,仅添加了一个对比学习的损失(即总的概率对比损失),并未添加复杂的附加模块,参数量与之前的方法相比没有增加。总体来说,本发明在不添加其他附加模块的情况下提升模型整体性能,能够获得更精确的图像分类结果。
附图说明
11.为了更清楚地说明本发明实施例的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域的普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他附图。
12.图1为本发明实施例提供的一种域适应图像分类网络的训练方法的流程图;图2为本发明实施例提供的不同方案语义特征分布示意图;图3为本发明实施例提供的特征对比学习及概率对比学习框架图;图4为本发明实施例提供的添加不同对比学习方法后的特征分布对比结果示意图;图5为本发明实施例提供的一种域适应图像分类网络的训练系统的示意图;图6为本发明实施例提供的一种处理设备的示意图。
具体实施方式
13.下面结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明的保护范围。
14.首先对本文中可能使用的术语进行如下说明:术语“和/或”是表示两者任一或两者同时均可实现,例如,x和/或y表示既包括“x”或“y”的情况也包括“x和y”的三种情况。
15.术语“包括”、“包含”、“含有”、“具有”或其它类似语义的描述,应被解释为非排它性的包括。例如:包括某技术特征要素(如原料、组分、成分、载体、剂型、材料、尺寸、零件、部件、机构、装置、步骤、工序、方法、反应条件、加工条件、参数、算法、信号、数据、产品或制品等),应被解释为不仅包括明确列出的某技术特征要素,还可以包括未明确列出的本领域公知的其它技术特征要素。
16.下面对本发明所提供的一种域适应图像分类网络的训练方法、系统、设备及存储介质进行详细描述。本发明实施例中未作详细描述的内容属于本领域专业技术人员公知的现有技术。本发明实施例中未注明具体条件者,按照本领域常规条件或制造商建议的条件进行。
17.实施例一 本发明实施例提供一种域适应图像分类网络的训练方法,与现有方案中直接用
分类器前的特征计算对比损失不同,本发明利用分类器后的概率计算概率对比损失,以便在聚类同类别的特征的同时有效地约束类权重和特征之间的距离。具体来说,本发明将对比学习从特征空间迁移到概率空间,并删除l2范数归一化来约束概率呈现one-hot形式。如图1所示,本发明提供的方法主要包括如下步骤:步骤1、获取源域图像集合,并根据训练方式获取目标域图像集合;对目标域图像集合中每一未标注的目标域图像进行两种不同的图像变换,获得的第一变换图像与第二变换图像构成一个目标域图像对,由所述源域图像集合、目标域图像集合以及所有目标域图像对构成训练数据集。
18.本发明实施例中,可以从现有公开数据集中收集源域图像集合与目标域图像集合,源域图像集合中所有源域图像均具有对应的类别标签。目标域图像集中所有目标域图像均为未标注图像,或者一部分是带有类别标签的目标域图像,另一部分是未标注的目标域图像;具体来说:当使用半监督训练方式时,一部分目标域图像(具体数量可根据实际情况或者经验自行设定)有对应的类别标签,另一部分是未标注的目标域图像;当使用无监督训练方式时,所有目标域图像都没有对应的类别标签,即均为未标注的目标域图像。本步骤中,针对每一未标注的目标域图像进行两种不同的图像变换,显然,如果使用无监督训练方式,那么所有目标域图像都需要进行两种不同的图像变换;两种不同的图像变换方式可以采用已有的常规图像变换方法实现。
19.步骤2、将所述训练数据集输入至所述域适应图像分类网络。
20.本发明实施例中,域适应图像分类网络主要包含特征提取器与分类器,以及softmax层。其中,特征提取器与分类器可以根据需要使用目前已有的网络结构来实现。
21.步骤3、根据训练方式以及设定的损失类别,利用所述域适应图像分类网络特征提取器的输出、分类器的输出与softmax层的输出中的一种或多种计算相应的类别损失,联合计算出的所有类别的损失,构成所述域适应图像分类网络的基线损失。
22.本领域技术人可以理解,特征提取器的输出是从输入图像中提取的图像特征,分类器的输出为logits,它是一种对于输入图像输出的未归一化的分类概率;softmax层的输出是对分类器的输出进行转换得到的概率向量。此处所述的输入图像是指训练数据集的图像。根据训练方式以及设定的损失类别,利用以上三类输出中的一种或多种计算相应的损失。
23.本发明实施例中,所述基线损失包含了多个类别的损失,根据设定的损失类别,需要使用训练数据集中的三部分图像(即源域图像集合中的图像、目标域图像集合中的图像,以及目标域图像对中的图像)。例如,针对源域图像与目标域图像的分类损失(利用分类器的输出计算),针对源域图像与目标域图像计算的对抗损失(利用特征提取器的输出、分类器的输出或者softmax层的输出计算),针对目标域图像计算的最小最大化熵损失(利用分类器的输出计算)等等。本步骤中基线损失所包含的损失类型以及各类型损失的计算方式都可参照常规技术。
24.为了便于理解,下面针对上文列举的各类损失的计算原理进行说明,未进行详细介绍的部分(例如,具体的计算公式)都可参照常规技术。
25.1、源域图像与目标域图像的分类损失。
26.1)有标注图像上的分类损失:利用有标注图像对应的分类器的输出,以及图像对
应的类别标签计算分类损失,例如,计算分类器的输出与类别标签的交叉熵损失作为分类损失。考虑到不同的训练方式,有标注图像存在一定的区别:使用无监督训练方式时,有标注图像是指源域图像集合中的源域图像;使用半监督训练方式时,有标注图像包括:源域图像集合中的源域图像,以及目标域图像集合中带有类别标签的目标域图像。
27.2)未标注图像上的分类损失:未标注图像主要是指目标域图像集合中未标注的目标域图像(无监督训练方式时,整个目标域图像集合均为未标注的目标域图像),以及目标域图像对中的第一变换图像与第二变换图像。通过伪标签生成技术生成未标注的目标域图像的伪标签,利用未标注的目标域图像对应的分类器的输出,以及相应的伪标签计算分类损失。此处所涉及的伪标签生成技术可使用常规技术来实现,实现方式有多种。例如,单独训练一个分类模型,利用分类模型为未标注的目标域图像进行分类,分类结果作为伪标签,利用伪标签来计算未标注的目标域图像对应目标域图像对的分类损失;当然,也可以利用分类模型为目标域图像对的第一变换图像进行分类,分类结果作为伪标签来计算同一目标域图像对的中第二变换图像的分类损失。当然,也可以不计算此部分分类损失,具体的由用户自行设定。
28.2、对抗损失。
29.对抗损失是为了使得源域图像与目标域的图像的分布接近,目标域的图像可以是目标域图像集合中的目标域图像,也可以是目标域图像对中的各个变换图像;计算对抗损失时,会引入一个判别器,通过判别器对输入的源域图像与目标域图像的信息进行判别,以区分源域与目标域,所述源域图像与目标域图像的信息为特征提取器的输出、分类器的输出或softmax层的输出。
30.3、最小最大化熵损失。
31.最小最大化熵损失计算对象是目标域的图像,它是一种基于对目标域的图像的条件熵以及任务损失的优化极小极大损失,能够减少分布差异,又能学习任务具有区别性的特征。与对抗损失类似的,目标域的图像可以是目标域图像集合中的目标域图像,也可以是目标域图像对中的各个变换图像。
32.计算以上对抗损失与最小最大化熵损失,无需使用图像的标注信息(类别标签),也就是说,无需考虑半监督或有监督训练方式。
33.需要说明的是,基线损失可以包括图像的分类损失,以及对抗损失和/或最小最大化熵损失,当然还可以包括其他类别的损失,最终构成所述域适应图像分类网络的基线损失。本领域技术人员可以理解,基线损失属于现有域适应图像分类网络的已有损失,本发明的核心是在域适应图像分类网络已有损失的基础上设计了一个新的损失(也即后文计算的总的概率对比损失),在不明显增加训练负担的条件下,提升图像分类的准确率。
34.步骤4、对于每一个目标域图像对,提取出所述域适应图像分类网络中softmax层的输出,构成一对概率向量;将每一对概率向量中第一变换图像对应的概率向量作为第一查询向量,将与第一查询向量不属于同一对概率向量中的其他所有概率向量作为相应第一查询向量的负样本,以及将每一对概率向量中第二变换图像对应的概率向量作为第二查询向量,将与第二查询向量不属于同一对概率向量中的其他所有概率向量作为相应第二查询向量的负样本;利用所有第一查询向量及对应的负样本,以及所有第二查询向量及对应的负样本计算总的概率对比损失。
35.本发明实施例中,针对未标注目标域图像,设计了一种概率对比损失,结合概率对比损失与目前已有损失(即步骤3计算的基线损失)进行网络训练,能够极大的提升网络性能。通俗的理解,可以将步骤3与步骤4作为两个线程,步骤3中根据需要利用域适应图像分类网络各个部分的输出计算各类损失,构成基线损失,步骤4则从域适应图像分类网络末端的softmax层提取每一目标域图像对对应的一对概率向量计算概率对比损失,最终形成总的概率对比损失。本步骤的优选实施方式如下: 将一个目标域图像对记为,其中,xi表示第一变换图像,表示第二变换图像,i为目标域图像的索引符号;一个目标域图像对经所述域适应图像分类网络中的特征提取器,提取出一对图像特征,再经所述域适应图像分类网络中的分类器,获得一对logits,其中,w表示分类器的权重参数,t为转置符号;一对logits 经softmax层转换为一对概率向量,其中,pi=( p
i,1
,

,pi,c), ,pi表示第一变换图像xi的概率向量,p
i,c
表示第一变换图xi属于类别c的概率,表示第二变换图像的概率向量,表示第二变换图像属于类别c的概率,,c为总类别数。
36.以一对概率向量为例,将pi作为所述第一查询向量,将与pi不属于同一对概率向量中的其他所有概率向量作为相应第一查询向量的负样本,也就是说,将除去及其自身(pi)之外的其他所有概率向量均作为第一查询向量pi的负样本;将作为所述第二查询向量,将与第二查询向量不属于同一对概率向量中的其他所有概率向量作为相应第二查询向量的负样本,也就是说,将除去pi及其自身()之外的其他所有概率向量均作为第二查询向量的负样本;所述第一查询向量与第二查询向量构成一个查询向量对,利用所述第一查询向量pi及对应的负样本计算第一概率对比损失,利用所述第二查询向量及对应的负样本计算第二概率对比损失,计算方式表示为:计算方式表示为:其中,pj表示第j个目标域图像对应的第一变换图像xj的概率向量,表示第k个目标域图像对应的第二变换图像的概率向量,s为比例系数,t为转置符号。
37.联合所述第一概率对比损失与所述第二概率对比损失,作为一个查询向量对计算的概率对比损失:;综合所有查询向量对计算的概率对比损失,作为总的概率对比损失l
pcl
:需要说明的是,步骤3与步骤4不区分执行的先后顺序,二者可以同步执行,也可以按照需要先后执行。
38.步骤5、联合所述基线损失与总的概率对比损失训练所述域适应图像分类网络。
39.本发明实施例中,联合所述基线损失与总的概率对比损失构建总损失函数,利用所述总损失函数训练所述域适应图像分类网络;所述总损失函数表示为:l=l
ori
λl
pcl
其中,l
ori
表示基线损失,l
pcl
表示总的概率对比损失,λ为权重系数。
40.本发明实施例上述方案,在常规的基线损失基础上,添加一个简洁的对比学习方法,在基本不改变模型结构的条件下,聚类相同语义的特征,解决域间迁移的问题,从而提高图像分类的准确度。
41.为了更加清晰地展现出本发明所提供的技术方案及所产生的技术效果,下面针对本发明上述方法改进的原理进行说明。
42.考虑到,基线损失是由目前域适应图像分类网络训练过程中计算的各类损失构成,因此不再赘述,下面主要针对本发明设计的概率对比损失原理以及其相对于现有特征对比损失的优越性进行说明。
43.众所周知,在缺少目标域数据标签的情况下,目标域中的各种类别难以区分,如图2的(a)部分所示,展示了各数据的初始概率分布(initial distibution)。与此同时,现有的基于infonce损失的特征对比学习方法(fcl)可以学习无标签数据的语义紧凑特征表示,这意味着对比学习倾向于将语义相似的特征聚类在一起。且大量工作表明,对比学习方法训练的模型具有良好的可迁移性。因此,一个简单的想法是在源域使用全监督学习而在目标域使用对比学习,从而使目标域中每个类别都可以通过语义紧凑的表示区分,如图2的(b)部分所示。然而直接应用上述方法的增益非常有限,在一些具有强一致性约束的模型上甚至无明显增益。究其原因,对于包含特征提取器e和分类器f的模型,当前的对比学习方法通常使用分类器前的特征来计算对比损失,而优化的过程中不涉及类别权重信息,这导致特征对比学习的聚类结果不是围绕类权重的。对于域适应任务,在源域和目标域之间存在一个明显的域偏移问题,使得源域的类权重难以定位在目标域数据的类中心。因此,本发明采用概率对比学习方式(pcl),在优化过程中获取类信息,使特征聚类在类权重附近,达到如图2的(c)部分所示的效果。图2的三个部分中,深色的三角形与五角星表示源域图像(source data),浅色的三角形与五角星表示目标域图像(target data),带有圆圈的三角形与五角星表示分类权重(class weight),三角形与五角星用于区分不同的类别。
44.如图3所示,展示了现有的基于infonce损失的对比学习方法中使用的特征对比学习及与本发明使用的概率对比学习的框架图。图3中,底部的xi表示第i个目标域图像,
encoder表示编码器,用于提取图像特征,classifier表示分类器,根据输入的图像特征输出logits,maximize agreement表示最大化左右两侧的与、pi与;图3的(a)部分为现有的基于infonce损失的对比学习方法中使用的特征对比学习框架,图3的(b)部分为本发明使用的概率对比学习的框架。如图3的(b)部分所示,本发明的整个技术架构十分简洁,由于特征与类权重的距离越小,其对应的概率就越类似于一个one-hot形式,因此希望在优化对比损失时,特征对应的概率值会逼近one-hot形式。经分析发现,只需将特征替换为概率且对常规损失函数做一定修改,即可自动使特征的概率近似于一个one-hot形式,即只需要将原始的特征对比学习中的特征转换为概率并且删除l2范数归一化(l
2-morm)。
45.记为一批目标域图像对,其中n为批尺寸,为一个目标域图像对,如之前所述,通过对目标域图像xi进行变换,得到变换图像;定义域适应图像分类网络,其中e为特征提取器(例如,编码器encoder),f为分类器;特征提取器从一批目标域图像对b中提取特征;分类器有参数w=(w1,

,wc),其中c为类别数,wc为第c个类的类权重。对于一个查询特征fi,特征为正样本,其余所有样本均为负样本,infonce损失为: 其中,f、为相应图像的特征,角标j与k为相应目标域图像的索引符号;s为比例系数, 为标准l2范数归一化操作。
46.另一方面,考虑到在infonce损失中,特征fi中不涉及类权重信息,因此不可能在优化过程中将特征集中到分类权重周围。而直接使用分类器后的得分进行对比学习也无法约束类权重与特征之间的距离,导致学习效果不佳,因此,需要用包含权重信息的新特征替换特征fi。将新的特征记为fi',本发明试图使用新的特征fi'计算对比损失来使特征fi接近分类权重。使用新的特征fi'后,的损失函数定义为:新的特征fi'的设计目的是希望上述损失越小,特征fi与分类权重越接近。最小化上式的一个可行解法为最大化。由于特征fi与分类权重越接近,特征fi对应的概率向量pi=softmax(w
tfi
)与one-hot形式越相似,即pi=(0,

,1,

,0)。也就是说,可以通过强制特征的输出概率近似于一个one-hot形式来缩小特征和类权重之间的距离。
47.同时,注意到,对于概率向量pi=( p
i,1
,

,pi,c)以及有: 且pi及的l1范数归一化等于1,即有及,显然有:当且仅当时等号成立,此时二者均为one-hot形式。换言之,为了最大化,需要二者同时满足one-hot形式,因此,损失函数中新的特征fi'可以直接定义为概率向量pi。从推导过程中可以看出,概率向量的l1范数归一化等于1这一特性保证了只有当pi及同属为one-hot形式时,取得最大值,故而不能使用传统特征对比学习中的l2范数归一化操作。
48.基于上述原理,本发明提出的新的对比损失(即第一概率对比损失)为:上述新的对比损失与infonce损失相比有两个主要的不同:1)上式使用概率向量pi代替了特征提取器输出的特征fi来进行对比学习;其次,为了保证概率呈现one-hot形式,移除l2范数归一化操作g。
49.基于同样的原理,可以计算出第二概率对比损失,从而得到总的概率对比损失l
pcl
,再结合基线损失l
ori
,计算总损失函数:l=l
ori
λl
pcl
其中,λ表示权重系数。
50.下面结合本发明提供的上述方法,介绍完整的训练与测试过程。
51.一、准备训练数据集和测试集。
52.首先,获取源域图像集合,并根据训练方式获取目标域图像集合。其中,所有源域图像均具有对应的标签;根据训练方式的不同,所有目标域图像均为未标注图像,或者一部分是带有类别标签的目标域图像,另一部分是未标注的目标域图像;对每一未标注的目标域图像随机使用两种不同的常规图像变换,获得的第一变换图像与第二变换图像构成一个目标域图像对。一个目标域图像对中的两个变换图像作为彼此的正样本,其余变换图像均视为负样本,由所述源域图像集合、目标域图像集合以及所有目标域图像对构成训练数据集。
53.二、使用深度学习框架,建立基于概率对比学习的域适应图像分类网络,主要包括:特征提取器、分类器及softmax层,其中前两个模块可以是当前的主流分类网络。
54.示例性的,对于无监督训练方式,使用在imagenet上预训练的resnet-50模型作为主干网络;对于半监督训练方式,使用移除最后一个线性层的alexnet及resnet-34,并在其
后添加一个新的分类器f。
55.三、将训练数据集输入至域适应图像分类网络,计算出基线损失l
ori
。基线损失包含了若干损失,具体的损失类型、损失数据可根据训练方式与用户需求进行设定,利用域适应图像分类网络各部分的输出计算相应损失,所涉及的各损失计算方案在前文已经做了详细的介绍,故不再赘述。
56.四、对于目标域的每一目标域图像对,提取出softmax层的输出,得到一对概率向量,按照前述步骤4介绍的方式计算总的概率对比损失。
57.五、根据前述三与四的两部分损失构建总损失函数,通过反向传播算法以及小批量随机梯度下降法,使得总损失函数l最小化,更新各特征提取器和分类器的权重。其中,在最小化提出的概率对比损失后,数据集中每个语义类对应的特征均会聚类且聚集在类权重周围,为了便于直观的呈现本发明上述方案的效果,下面以源域图像的分类损失与最小最大化熵损失(mme)作为基线损失为例进行了对比实验,可视化效果如图4所示,(a)、(b)、(c)三个部分分别展示了仅使用基线损失、使用基线损失 fcl(基于infonce损失的特征对比学习方法)、使用基线损失 pcl(即本发明提出的方法)的特征聚类结果;图4中,basket、bathtub表示两种类别,(a)部分左下角部分对应bathtub类别,右上角对应basket类别;(b)部分右下角部分对应bathtub类别,左上角对应basket类别;(c)部分左下角部分对应basket类别,右上角对应bathtub类别;各个类别中的圆形符号表示分类权重。与其他方法对比,可以看出本发明所提出的方法有效的聚类了特征,且相较于直接添加特征对比损失而言,显著减少了特征与类权重的距离。需要说明的是,此处的对比实验中基线损失虽然只包含源域图像的分类损失与最小最大化熵损失,但是,在实际应用中,基线损失所包含的损失类型、损失数目都可以根据实际情况来设定。
58.六、输入测试数据集(由目标域图像构成),计算训练后的域适应图像分类网络分类准确度。
59.本发明实施例提供的上述方案,主要获得如下有益效果:1)将对比学习方法引入无监督或半监督域适应图像分类任务中,将相同语义的特征聚类,解决该任务在目标域标签不足的问题;2)本发明将特征对比学习改进为概率对比学习,通过在概率空间进行对比学习,迫使概率向量逼近one-hot形式,减小聚类后的同语义特征与类权重之间的距离,提高分类的准确率;3)改进后的分类网络是简洁而有效的,本发明只在传统的域适应分类网络上添加了一个概率对比学习损失,并未添加复杂的附加模块,改进后网络的参数量与之前的方法相比没有增加。总体来说,本发明在不添加其他附加模块的情况下提升模型整体性能,能够获得更精确的图像分类结果。实施例二本发明还提供一种域适应图像分类网络的训练系统,其主要基于前述实施例一提供的方法实现,如图5所示,该系统主要包括:训练数据集构造单元,用于获取源域图像集合,并根据训练方式获取目标域图像集合;对目标域图像集合中每一未标注的目标域图像进行两种不同的图像变换,获得的第一变换图像与第二变换图像构成一个目标域图像对,由所述源域图像集合、目标域图像集合以及所有目标域图像对构成训练数据集;训练数据集输入单元,用于将所述训练数据集输入至所述域适应图像分类网络;
基线损失计算单元,用于根据训练方式以及设定的损失类别,利用所述域适应图像分类网络特征提取器的输出、分类器的输出与softmax层的输出中的一种或多种计算相应的类别损失,构成所述域适应图像分类网络的基线损失;总的概率对比损失计算单元,对于每一个目标域图像对,提取出所述域适应图像分类网络中softmax层的输出,构成一对概率向量;将每一对概率向量中第一变换图像对应的概率向量作为第一查询向量,将与第一查询向量不属于同一对概率向量中的其他所有概率向量作为相应第一查询向量的负样本,以及将每一对概率向量中第二变换图像对应的概率向量作为第二查询向量,将与第二查询向量不属于同一对概率向量中的其他所有概率向量作为相应第二查询向量的负样本;利用所有第一查询向量及对应的负样本,以及所有第二查询向量及对应的负样本计算总的概率对比损失;训练单元,用于联合所述基线损失与总的概率对比损失训练所述域适应图像分类网络。
60.所属领域的技术人员可以清楚地了解到,为描述的方便和简洁,仅以上述各功能模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能模块完成,即将系统的内部结构划分成不同的功能模块,以完成以上描述的全部或者部分功能。
61.需要说明的是,上述系统中各单元的主要原理在之前的实施例一中已经做了详细的介绍,故不再赘述。
62.实施例三本发明还提供一种处理设备,如图6所示,其主要包括:一个或多个处理器;存储器,用于存储一个或多个程序;其中,当所述一个或多个程序被所述一个或多个处理器执行时,使得所述一个或多个处理器实现前述实施例提供的方法。
63.进一步的,所述处理设备还包括至少一个输入设备与至少一个输出设备;在所述处理设备中,处理器、存储器、输入设备、输出设备之间通过总线连接。
64.本发明实施例中,所述存储器、输入设备与输出设备的具体类型不做限定;例如:输入设备可以为触摸屏、图像采集设备、物理按键或者鼠标等;输出设备可以为显示终端;存储器可以为随机存取存储器(random access memory,ram),也可为非不稳定的存储器(non-volatile memory),例如磁盘存储器。
65.实施例四本发明还提供一种可读存储介质,存储有计算机程序,当计算机程序被处理器执行时实现前述实施例提供的方法。
66.本发明实施例中可读存储介质作为计算机可读存储介质,可以设置于前述处理设备中,例如,作为处理设备中的存储器。此外,所述可读存储介质也可以是u盘、移动硬盘、只读存储器(read-only memory,rom)、磁碟或者光盘等各种可以存储程序代码的介质。
67.以上所述,仅为本发明较佳的具体实施方式,但本发明的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本发明披露的技术范围内,可轻易想到的变化或替换,都应涵盖在本发明的保护范围之内。因此,本发明的保护范围应该以权利要求书的保护范围为准。
再多了解一些

本文用于企业家、创业者技术爱好者查询,结果仅供参考。

发表评论 共有条评论
用户名: 密码:
验证码: 匿名发表

相关文献