一种残膜回收机防缠绕挑膜装置的制 一种秧草收获机用电力驱动行走机构

模型训练方法、装置、计算机设备及存储介质与流程

2022-06-11 02:18:32 来源:中国专利 TAG:


1.本技术涉及人工智能技术领域,更具体地,涉及一种模型训练方法、装置、计算机设备及存储介质。


背景技术:

2.随着科技水平的迅速发展,多标签分类引起了人们极大的研究兴趣,并同时在很多应用产品中部署,智能化地解决了很多日常生活中的问题。多标签分类首先需要确定出需要分类的类别,再由人工对数据集标注出样本数据中所拥有的类别,之后使用标注后的数据集来训练模型。但是,完备的多标签数据集的获取成本极高,因为当类别数量过多时,数据集的标注往往是不完整的,导致训练得到的模型的准确性不足。


技术实现要素:

3.本技术提出了一种模型训练方法、装置、计算机设备及存储介质,可以实现模型训练过程中自动修正负样本所带来的损失,提升模型训练的效率和准确性。
4.第一方面,本技术实施例提供了一种模型训练方法,所述方法包括:获取训练样本集,所述训练样本集包括多个样本数据,所述训练样本集包括多个样本数据标签的标签概率;分别将所述训练样本集中每个样本数据输入至初始识别模型,得到每个样本数据对应的每个候选标签的标签概率;将所述训练样本集中的目标样本数据作为目标候选标签的正样本数据,并基于所述训练样本集中每个样本数据对应的每个候选标签的标签概率以及所述每个样本数据标注的标签,确定总损失值,所述目标样本数据未标注有所述目标候选标签,且所述目标样本数据对应所述目标候选标签的标签概率大于预设概率;根据所述总损失值,对所述初始识别模型进行迭代训练,得到训练后的识别模型。
5.第二方面,本技术实施例提供了一种模型训练装置,所述装置包括:样本获取模块、模型输入模块、损失确定模块以及迭代训练模块,其中,所述样本获取模块用于获取训练样本集,所述训练样本集包括多个样本数据;所述模型输入模块用于分别将所述训练样本集中每个样本数据输入至初始识别模型,得到每个样本数据对应的每个候选标签的标签概率,所述多个样本数据中包括至少一个候选标签对应的负样本数据;所述损失确定模块用于将所述训练样本集中的目标样本数据作为目标候选标签的正样本数据,并基于所述训练样本集中每个样本数据对应的每个候选标签的标签概率以及所述每个样本数据对应的标签,确定总损失值,所述目标样本数据未标注有所述目标候选标签,且所述目标样本数据对应所述目标候选标签的标签概率大于预设概率;所述迭代训练模块用于根据所述总损失值,对所述初始识别模型进行迭代训练,得到训练后的识别模型。
6.第三方面,本技术实施例提供了一种计算机设备,包括:一个或多个处理器;存储器;一个或多个应用程序,其中所述一个或多个应用程序被存储在所述存储器中并被配置为由所述一个或多个处理器执行,所述一个或多个应用程序配置用于执行上述第一方面提供的模型训练方法。
7.第四方面,本技术实施例提供了一种计算机可读取存储介质,所述计算机可读取存储介质中存储有程序代码,所述程序代码可被处理器调用执行上述第一方面提供的模型训练方法。
8.本技术提供的方案,通过获取训练样本集,分别将训练样本集中每个样本数据输入至初始识别模型,得到每个样本数据对应的每个候选标签的标签概率,然后将训练样本集中的目标样本数据作为目标候选标签的正样本数据,并基于训练样本集中每个样本数据对应的每个候选标签的标签概率以及每个样本数据标注的标签,确定总损失值,该目标样本数据未标注有所述目标候选标签,且目标样本数据对应目标候选标签的标签概率大于预设概率,再根据总损失值,对初始识别模型进行迭代训练,得到训练后的识别模型。由此,可以实现模型训练中,对于标注不完整的负样本数据,自动按照正样本计算损失,实现自动修正负样本所带来的损失,从而提升模型训练的效率和准确性。
附图说明
9.为了更清楚地说明本技术实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本技术的一些实施例,对于本领域技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
10.图1示出了根据本技术一个实施例的模型训练方法流程图。
11.图2示出了本技术实施例提供的模型训练方法的原理示意图。
12.图3示出了根据本技术另一个实施例的模型训练方法流程图。
13.图4示出了根据本技术又一个实施例的模型训练方法流程图。
14.图5示出了根据本技术一个实施例的模型训练装置的一种框图。
15.图6是本技术实施例的用于执行根据本技术实施例的模型训练方法的计算机设备的框图。
16.图7是本技术实施例的用于保存或者携带实现根据本技术实施例的模型训练方法的程序代码的存储单元。
具体实施方式
17.为了使本技术领域的人员更好地理解本技术方案,下面将结合本技术实施例中的附图,对本技术实施例中的技术方案进行清楚、完整地描述。
18.自然界中,图片大多蕴含不同的物品、场景、关系等。相比于单标签分类中图像上只有一个类别,多标签分类指的是识别出图像上所有需要的目标、场景等类别,更加适应于自然界图像的分布和人们的使用习惯。多标签图像分类首先需要确定出需要分类的类别,再由人工对数据集标注出图像中所拥有的类别,之后使用标注后的数据集来训练模型。完备的多标签数据集的获取成本极高,因为当类别数量过多时,由于人工自身标注时的偏见和疏漏,数据集的标注往往是不完整的。
19.因此,使用不完备的数据集训练多标签分类是该领域的常态,这种“脏”数据集会对多标签分类模型的识别精度造成严重的影响。相关技术中,通常使用伪标签(pseudo label)的方式进行标签修正,以提升训练的模型的精度。其中,伪标签的方式中,先使用标
签缺失的数据集来训练一个模型,再使用这个训练出来的模型对数据集打上缺失的标签,判断模型预测值是否超过设定的阈值,最后利用补标签后的数据集,再次训练模型。但是,使用标注不完备的数据集(即噪声数据)训练后期的模型来打伪标签,该模型极大可能过拟合了带噪数据,导致标注的伪标签精度一般;并且,训练过程中,需要先训练打标签模型,然后再训练新模型,需要多次训练模型,影响了模型训练效率,带来了额外的训练开销。
20.针对上述问题,发明人提出了本技术实施例提供的模型训练方法、装置、计算机设备以及存储介质,可以实现模型训练中,对于标注不完整的负样本数据,自动按照正样本计算损失,实现自动修正负样本所带来的损失,从而提升模型训练的效率和准确性。其中,具体的识别模型的训练方法在后续的实施例中进行详细的说明。
21.下面再结合附图对本技术实施例提供的模型训练方法进行详细介绍。
22.请参阅图1,图1示出了本技术一个实施例提供的模型训练方法的流程示意图。在具体的实施例中,所述模型训练方法应用于如图5所示的模型训练装置400以及配置有所述模型训练装置400的计算机设备100(图6)。下面将以计算机设备为例,说明本实施例的具体流程,当然,可以理解的,本实施例所应用的计算机设备可以为服务器、智能手机、平板电脑、智能手表、电子书、pc电脑、笔记本电脑等,在此不做限定。下面将针对图1所示的流程进行详细的阐述,所述模型训练方法具体可以包括以下步骤:
23.步骤s110:获取训练样本集,所述训练样本集包括多个样本数据,所述多个样本数据中包括至少一个候选标签对应的负样本数据。
24.其中,训练样本集中的样本数据可以为图像、音频等类型的数据。可选地,在需求识别模型识别图像中存在的目标、场景等多个候选标签(即待识别的类别)时,例如需求识别图像中是否存在人、树、天空、猪、羊等候选标签时,则样本数据可以为样本图像,样本图像可以来自于coco dataest数据集,imageclef数据集等;可选地,在需求音频数据中是否存在不同类别的声音时,例如需求识别音频中是否存在人声、汽车声、飞机声、广播声等多个候选标签时,则样本数据可以为音频数据,音频数据可以从服务器获取,或者从不同用户终端搜集得到,具体获取方式可以不做限定。
25.在本技术实施例中,训练样本集的多个样本数据中可以包括至少一个候选标签对应的负样本数据。可选地,训练样本集包括多个正第一样本数据以及多个负第二样本数据,其中,第一正样本数据被标注有其存在的所有候选标签,第二负样本数据未被标注有其不存在的候选标签,或者未标注有其存在的候选标签中的至少一个候选标签。可以理解地,在需求识别多个候选标签的识别场景中,完备的多标签数据集的获取成本极高,因为当标签种类过多时,由于人工自身标注时的偏见和疏漏,数据集的标注往往是不完整的,因此上述训练样本集中可以包括标注完整的第一样本数据,另外,一些样本数据中本身不存在候选标签,或者存在候选标签却未被标注,即第二样本数据。
26.其中,训练样本集中的样本数据中,针对任一候选标签,若样本数据未标注有该候选标签,则该样本数据对于该候选标签而言为负样本数据;若样本数据标注有该候选标签,则该样本数据对于该候选标签而言为正样本数据。可以理解的是,负样本数据中,可能为上述第二样本数据中本身不存在候选标签的样本数据,也可能由于人工自身标注时的偏见和疏漏,而导致漏标的样本数据,漏标的样本数据也可以被认为是假的负样本数据。
27.在一些实施方式中,上述样本数据中,可以包括存在各个候选标签的正样本数据
logical regression),也可以是支持向量机(support vector machine,svm)等,具体的分类模块可以不作为限定。
37.可选地,假设共有k个候选标签,x∈rk为模型的输出,其中第i个元素xi预测第i个候选标签存在的可能性。用sigmoid函数归一化每个xi,得到标签概率pi即输入模型的数据中存在第i个标签的概率,pi越大,则输入数据中存在第i个标签的可能性越大,反之,输入数据中存在第i个标签的可能性越小。
38.步骤s130:将所述训练样本集中的目标样本数据作为目标候选标签的正样本数据,并基于所述训练样本集中每个样本数据对应的每个候选标签的标签概率以及所述每个样本数据标注的标签,确定总损失值,其中,所述目标样本数据未标注有所述目标候选标签,且所述目标样本数据对应所述目标候选标签的标签概率大于预设概率。
39.在得到初始识别模型输出的结果之后,则可以基于初始识别模型输出的结果,以及各个样本数据标注的标签,确定初始识别模型的总损失值,以便根据总损失值对初始识别模型进行更新。在本技术实施例中,在确定总损失值时,由于训练样本集中存在漏标的样本数据,即样本数据中本身存在候选标签,但并未标注该候选标签,因此在获取总损失值时,需要考虑该影响。
40.其中,在k个候选标签的识别场景中,令yi∈{0,1}为第i个候选标签的监督(即样本数据的标签),若样本数据标注有该标签,则yi=1,反之则yi=0,那么,损失值可按照该公式计算:
[0041][0042]
其中,和分别表示第i个候选标签的正样本损失值和负样本损失值。
[0043]
在该公式中,对于第i个候选标签而言,若一个样本数据标注有该候选标签,则yi=1,代入该公式计算时,该样本数据对于该候选标签所产生的损失为:即按照该候选标签的正样本损失值的计算方式计算该损失;若一个样本数据未标注有该候选标签,则yi=0,代入该公式计算时,该样本数据对于该候选标签所产生的损失为:即按照该候选标签的负样本损失值的计算方式计算该损失。
[0044]
计算机设备获取总损失值时,可以根据各个样本数据对应的每个候选标签的标签概率,针对每个样本数据未标注的各个候选标签,确定是否其对应的标签概率大于预设概率;若任一未标注的候选标签中,其对应的标签概率大于预设概率,则表示该样本数据存在漏标,即假的负样本数据,实际计算损失时,应当按照目标候选标签的正样本数据,计算损失,由此避免因漏标导致按照负样本数据计算损失值,进而避免模型按照错误的损失值更新参数,而避免模型准确性受到影响;而对于其他的初始识别模型输出的标签概率,则按照通常的损失计算方式计算损失,即损失计算保持与标注的标签一致。其中,预设概率可以为0.5、0.6、0.7、0.7等,具体数值可以不作为限定;目标候选标签为任一候选标签。
[0045]
可以理解地,通常多个候选标签的识别问题可看作并行的多元二分类问题,即依次预测每个候选标签在数据中存在的概率,即存在或者不存在。这意味着,在多标签学习
中,负样本是远多于正样本的,正负样本严重失衡,而标签缺失问题(即正样本缺失标注的标签)会进一步加剧这种不平衡。在这种情况下,若模型仍能给出正样本的预测(即预测出候选标签的概率大于预设概率),则这个预测的可信度应该是很高的,因此很有可能是标签缺失的样本(即假的负样本数据)。当然随着训练的持续,模型过拟合至不完备的训练数据,就会失去这种分辨能力,因此可以利用模型早起的这种分辨能力来计算损失。也就是说,对于某一样本数据而言,若目标样本数据未标注目标候选标签,但模型输出该目标样本数据对应该目标候选标签的标签概率大于预设概率,则表示其缺失该目标候选标签,因此将其按照目标候选标签的正样本数据计算损失(即认为该目标样本数标注有目标候选标签,并根据该标签计算损失值),从而保证获取的损失值的准确性,也可以理解为对该目标样本数据的标签进行了修正,使得模型训练时,按照正确的标签计算损失值。因此,结合上述公式而言,若目标候选标签为第i个标签,则该目标样本数据对应该目标候选标签的损失值为:
[0046]
步骤s140:根据所述总损失值,对所述初始识别模型进行迭代训练,得到训练后的识别模型。
[0047]
在本技术实施例中,在获得总损失值之后,则可以根据总损失值对初始识别模型进行迭代训练,得到最终的识别模型。
[0048]
在一些实施方式中,可以根据计算得到的总损失值,调整初始识别模型的模型参数;返回步骤s120,重复步骤s120~s140,直至满足训练结束条件,得到训练后的识别模型。
[0049]
其中,在分别将训练样本集中每个样本数据输入至初始识别模型,得到每个样本数据对应的每个候选标签的标签概率,然后根据每个样本数据对应的每个候选标签的标签概率以及每个样本数据标注的标签,确定总损失值,再根据计算得到的总损失值,调整初始识别模型的模型参数之后,即完成了1个epoch;然后再返回步骤s120,完成下一个epoch,如此重复,即可完成多个epoch。其中,epoch指使用训练集中的全部样本训练的次数,通俗的讲epoch的值就是整个数据集被轮几次,1个epoch等于使用训练集中的全部样本训练1次。
[0050]
可以理解地,由于每个epoch中,计算损失值时,均对存在漏标的负样本所产生的损失进行了修正,即将其按照目标候选标签的正样本计算其产生的损失,因此,根据得到的总损失值调整初始识别模型的模型参数时,能够保证模型更新的准确性;在进行下一个epoch时,则是将样本数据输入至模型参数更新后的初始识别模型(也即更新后的模型参数传递到下一个epoch中),并再次重复根据初始识别模型的输出结果以及样本数据标注的标签,确定总损失值,并根据总损失值调整模型参数的过程,从而可以不断地更新初始识别模型的模型参数,并且由于每个epoch中,会针对漏标的负样本所产生的损失进行修正,进而保证每个epoch中更新模型参数的准确性,最终得到准确的训练后的识别模型。
[0051]
在一种可能的实施方式中,可以根据总损失值,使用adam优化器对初始识别模型进行迭代训练,直至初始识别模型的输出结果的损失值收敛,并将此时的模型进行保存,得到训练后的识别模型。其中,adam优化器,结合了adagra(adaptive gradient,自适应梯度)和rmsprop两种优化算法的优点,对梯度的一阶矩估计(first moment estimation,即梯度的均值)和二阶矩估计(second moment estimation,即梯度的未中心化的方差)进行综合考虑,计算出更新步长。
[0052]
在一些实施方式中,迭代训练的训练结束条件可以包括:迭代训练的次数达到目
标次数;或者初始识别模型的输出结果的总损失值满足设定条件。
[0053]
在一种具体实施方式中,收敛条件是让总损失值尽可能小,使用初始学习率1e-3,学习率随步数余弦衰减,batch_size=8,训练16个epoch后,即可认为收敛完成。其中,batch_size可以理解为批处理参数,它的极限值为训练集样本总数。
[0054]
在另一种具体实施方式中,总损失值满足设定条件可以包括:总损失值小于设定阈值。当然,具体设定条件可以不作为限定。
[0055]
在一些实施方式中,若识别模型在被应用时由电子设备执行,则训练得到的识别模型可以存储于电子设备本地;当然,该训练得到的识别模型也可以在与电子设备通信连接的服务器,将识别模型存储在服务器的方式,可以减少占用电子设备的存储空间,提升电子设备运行效率。
[0056]
在一些实施方式中,识别模型还可以周期性的或者不定期的获取新的样本数据,对该识别模型进行训练和更新。例如,在存在输入数据被误识别时,则可以将该输入数据作为样本数据,对样本数据进行标注后,通过以上训练方式,再进行训练,从而可以提升识别模型的辨识度和识别准确度。
[0057]
在一些实施方式中,由于识别模型是用于识别数据中存在的多个候选标签,因此当用户需求识别模型识别的标签发生变化时,还可以增加新的候选标签,或者删除某个候选标签;并根据变更后的候选标签,对识别模型再进行训练。
[0058]
值得说明的是,对识别模型的训练可以是根据获取的训练样本集预先进行的,后续在每次需要识别待识别数据中是否存在候选标签时,则可以利用训练得到的识别模型进行识别,而无需每次对识别模型进行训练。
[0059]
下面再通过图2对本技术实施例涉及的识别模型的训练方法进行介绍。
[0060]
在对初始识别模型的训练过程中,将样本数据输入至初始识别模型后,可以得到样本数据对应各个候选标签的标签概率;然后基于标签概率以及样本数据标识的标签,若目标样本数据未标注有目标候选标签,且该目标样本数据对应该目标候选标签的标签概率大于预设概率,则将其作为目标候选标签的正样本数据,对其对应的损失值进行修正,从而确定出总损失值;然后根据总损失值更新初始识别模型的模型参数,如此重复,直至满足训练结束条件,完成对初始识别模型的训练,得到训练后的识别模型。
[0061]
本技术实施例提供的模型训练方法,可以实现模型训练中,在每个训练epoch中,对于标注不完整的负样本数据,自动按照正样本计算损失,实现自动修正负样本所带来的损失,从而提升模型准确性;另外,由于训练过程中,是在每一个训练epoch中修正了负样本数据所带来的损失,相当于是修正了样本数据的标注,因此无需多次训练模型,减少了训练开销,提升了模型训练的效率。
[0062]
请参阅图3,图3示出了本技术另一个实施例提供的模型训练方法的流程示意图。该模型训练方法应用于上述电子设备,下面将针对图3所示的流程进行详细的阐述,所述模型训练方法具体可以包括以下步骤:
[0063]
步骤s210:获取训练样本集,所述训练样本集包括多个样本数据,所述多个样本数据中包括至少一个候选标签对应的负样本数据。
[0064]
步骤s220:分别将所述训练样本集中每个样本数据输入至初始识别模型,得到每个样本数据对应的每个候选标签的标签概率。
[0065]
在本技术实施例中,步骤s210以及步骤s220可以参阅前述实施例的内容,在此不再赘述。
[0066]
步骤s230:基于所述每个样本数据标注的标签,确定所述每个候选标签的正样本数据以及负样本数据。
[0067]
在本技术实施例中,计算机设备可以基于每个样本数据标注的标签,针对每个候选标签,确定每个候选标签的正样本数据以及负样本数据。其中,训练样本集中的样本数据中,针对任一候选标签,若样本数据未标注有该候选标签,则该样本数据对于该候选标签而言为负样本数据;若样本数据标注有该候选标签,则该样本数据对于该候选标签而言为正样本数据。
[0068]
步骤s240:针对每个候选标签,基于所述每个候选标签的每个正样本数据,以及所述每个正样本数据对应的候选标签的标签概率,确定每个候选标签对应的正样本损失值。
[0069]
在本技术实施例中,计算机设备在确定出每个候选标签的正样本数据后,则可以针对每个候选标签,计算每个候选标签对应的正样本损失值。其中,对于任一正样本数据,可以针对每个候选标签,基于每个候选标签的每个正样本数据,以及每个正样本数据对应的候选标签的标签概率,确定每个候选标签对应的正样本损失值。
[0070]
在一些实施方式中,计算机设备可以针对每个候选标签,基于候选标签的正样本数据对应所述候选标签的标签概率,并按照正样本损失值的确定方式获取损失值,作为正样本数据对应所述候选标签的正样本损失值。
[0071]
其中,正样本损失值可以基于交叉熵损失函数、焦点损失函数、或者自动损失函数确定。
[0072]
其中,交叉熵(binary cross entropy,bce)损失函数,是常用的多标签损失函数,它的定义为:
[0073][0074]
其中,为正样本损失值,为负样本损失值,p为标签概率,为方便描述,以及p均省去了下标i(第i个候选标签)。
[0075]
在通过交叉熵损失函数确定正样本损失值的方式中,在针对每个候选标签,基于候选标签的正样本数据对应所述候选标签的标签概率,并按照正样本损失值的确定方式获取损失值,作为正样本数据对应所述候选标签的正样本损失值时,则可以对于每个样本数据,针对每个候选标签,根据该样本数据对应每个候选标签的标签概率,并通过计算该样本数据对应每个候选标签的正样本损失值。
[0076]
焦点(focal)损失函数是交叉熵损失函数基础上的改进,用来解决正负样本不平衡的问题,并且加入了难样本挖掘机制,其公式为:
[0077][0078]
其中,为正样本损失值,为负样本损失值,p为标签概率,为方便描述,
以及p均省去了下标i(第i个候选标签);参数γ为调制系数的幂,γ>0时,难样本的权重较大,起到难样本挖掘的作用;参数α

和α-用于平衡正样本损失值与负样本损失值的数值关系。
[0079]
在通过focal损失函数确定正样本损失值的方式中,在针对每个候选标签,基于候选标签的正样本数据对应所述候选标签的标签概率,并按照正样本损失值的确定方式获取损失值,作为正样本数据对应所述候选标签的正样本损失值时,则可以对于每个样本数据,针对每个候选标签,根据该样本数据对应每个候选标签的标签概率,并通过计算该样本数据对应每个候选标签的正样本损失值。
[0080]
自动损失函数(auto seg-loss,asl)是对focal损失函数的改进,它能减轻正负样本不平衡带来的影响,其定义为:
[0081][0082]
其中,为正样本损失值,为负样本损失值,p为标签概率,为方便描述,以及p均省去了下标i(第i个候选标签);其中γ

、γ-分别是正负样本的参数,设m是一个概率阈值,pm表示为max(p-m,0);γ

、γ-使得简单负样本的权重减少,此外通过m的设置,丢弃了预测概率很低的负样本所产生的损失。
[0083]
在通过asl损失函数确定正样本损失值的方式中,在针对每个候选标签,基于候选标签的正样本数据对应所述候选标签的标签概率,并按照正样本损失值的确定方式获取损失值,作为正样本数据对应所述候选标签的正样本损失值时,则可以对于每个样本数据,针对每个候选标签,根据该样本数据对应每个候选标签的标签概率,并通过计算该样本数据对应每个候选标签的正样本损失值。
[0084]
当然,具体确定正样本损失值的方式并不局限于利用上述损失函数进行计算,也可以利用其他损失函数进行计算。
[0085]
步骤s250:针对每个候选标签,基于所述每个候选标签的每个负样本数据,以及所述每个负样本数据对应的候选标签的标签概率,确定每个候选标签对应的负样本损失值,其中,目标样本数据对应目标候选标签的负样本损失值按照正样本损失值的确定方式获取。
[0086]
其中,所述目标样本数据未标注有所述目标候选标签,且所述目标样本数据对应所述目标候选标签的标签概率大于预设概率。
[0087]
在本技术实施例中,计算机设备还可以针对每个候选标签,基于每个候选标签的每个负样本数据,以及每个负样本数据对应的候选标签的标签概率,确定每个候选标签对应的负样本损失值,但是,由于可能一些样本数据存在漏标,因此存在目标样本数据未标注的目标候选标签,但是目标样本数据对应目标候选标签的标签概率大于预设概率的情况,该情况下,则需要使得计算该目标样本数据对应目标候选标签的负样本损失值时,使其按照正样本损失值的计算方式进行,从而修正损失值,使其产生的损失值按照正确标注的标签而计算得到。
[0088]
在一些实施方式中,步骤s250可以包括:
[0089]
针对每个候选标签,若候选标签的负样本数据对应的标签概率小于或等于预设概率,则基于所述负样本数据对应所述候选标签的标签概率,并按照负样本损失值的确定方式获取损失值,作为所述负样本数据对应所述候选标签的负样本损失值;若候选标签的负样本数据对应的标签概率大于预设概率,则基于所述负样本数据对应所述候选标签的标签概率,并按照正样本损失值的确定方式获取损失值,作为所述负样本数据对应所述候选标签的负样本损失值。
[0090]
也就是说,对于每个负样本数据,计算候选标签对应的负样本损失时,由于可能一些样本数据存在漏标,因此候选标签的负样本数据对应的标签概率大于预设概率时,则按照正样本损失值的确定方式获取损失值,作为该负样本数据对应该候选标签的负样本损失值。而模型识别负样本数据为小于或等于预设概率时,则表示该负样本数据本身不存在该候选标签,因此按照负样本损失值的确定方式,获取损失值,作为该负样本数据对应该候选标签的负样本损失值。
[0091]
其中,按照正样本损失值的确定方式获取损失值可以参阅步骤s240中确定正样本损失值的方式;负样本损失值的确定方式可以采用交叉熵损失函数、焦点损失函数、或者自动损失函数中的负样本损失值的计算方式,例如,对于每个样本数据,针对每个候选标签,根据该样本数据对应每个候选标签的标签概率,并通过计算该样本数据对应每个候选标签的正样本损失值。
[0092]
步骤s260:基于所述每个候选标签对应的正样本损失值,以及所述每个候选标签对应的负样本损失值,确定总损失值。
[0093]
在本技术实施例中,在对于每个样本数据,确定出每个候选标签对应的正样本损失值,以及每个候选标签对应的负样本损失值之后,则可以对于每个样本数据,确定其对应产生的损失值,并获取每个样本数据所产生的损失值的总和,从而得到总损失值。
[0094]
其中,对于每个样本数据,确定其对应产生的损失值时,可以按照该公式计算:
[0095][0096]
和分别表示正样本损失值和负样本损失值,k为候选标签的数量,yi∈{0,1}为第i个候选标签的监督(即样本数据的标签),若样本数据标注有该标签,则yi=1,反之则yi=0。
[0097]
下面再通过定义的自步损失修正(self-paced loss correction,splc)损失函数说明本技术实施例中损失值的确定过程。
[0098]
对于每个样本数据,确定其对应产生的损失值时,可以按照该公式计算:
[0099][0100]
splc损失函数表示为:
[0101]
[0102]
其中,为便于表述,splc损失函数省去下标i,即为即为p为第i个候选标签的标签概率pi,
[0103]
π(p≤τ)用于对p的二值化,即p≤τ时,π(p≤τ)取值为1,反之,p>τ时,π(p≤τ)取值为0,τ为上述预设概率,
[0104]
loss

(p)以及loss-(p)通过上述的交叉熵损失函数、焦点损失函数、或者自动损失函数确定。
[0105]
例如,通过交叉熵损失函数时,则loss

(p)为loss-(p)为可以理解地,对于任一样本数据而言,若标注有第i个候选标签,则yi=1,此时,该样本数据对应第i个候选标签的损失为:-log(p),若想要损失值越小,则p的数值越大;若未标注有第i个候选标签,则yi=0,此时,该样本数据对应第i个候选标签的损失为:-log(1-p),若想要损失值越小,则p的数值越小。若样本数据被标注完备,即不存在样本数据缺失标签的情况下,按照则可以保证计算的损失值准确,即负样本数据而言,模型输出的概率应当越小,对于正样本数据而言,模型输出的概率应当越大。但是,若存在负样本数据缺失标签,此时模型会输出的概率大于预设概率,通过splc函数的约束,则计算负样本数据对应的负样本损失值时,则计算的损失值为:-log(p),由此,也相当于是,想要损失值越小,则p的数值越大,使计算的损失值符合该负样本数据的真实情况,从而保证准确性。
[0106]
步骤s270:根据所述总损失值,对所述初始识别模型进行迭代训练,得到训练后的识别模型。
[0107]
在本技术实施例中,步骤s270可以参阅前述实施例的内容,在此不再赘述。
[0108]
本技术实施例提供的模型训练方法,可以实现模型训练中,在每个训练epoch中,对于标注不完整的负样本数据,按照正样本损失值的确定方式计算该负样本数据所产生的损失值(即负样本损失值),使计算的损失值符合该负样本数据的真实情况,从而保证准确性。
[0109]
请参阅图4,图4示出了本技术又一个实施例提供的模型训练方法的流程示意图。该模型训练方法应用于上述计算机设备,下面将针对图4所示的流程进行详细的阐述,所述模型训练方法具体可以包括以下步骤:
[0110]
步骤s310:获取训练样本集,所述训练样本集包括多个样本数据,所述多个样本数据中包括至少一个候选标签对应的负样本数据。
[0111]
步骤s320:分别将所述训练样本集中每个样本数据输入至初始识别模型,得到每个样本数据对应的每个候选标签的标签概率。
[0112]
在本技术实施例中,步骤s310以及步骤s320可以参阅前述实施例的内容,在此不再赘述。
[0113]
步骤s330:获取所述训练样本集中的目标样本数据,所述目标样本数据未标注有所述目标候选标签,且所述目标样本数据对应所述目标候选标签的标签概率大于预设概率。
[0114]
在本技术实施例中,由于可能一些样本数据存在漏标,因此存在目标样本数据未标注有目标候选标签,但目标样本数据对应目标候选标签的标签概率大于预设概率,故计算机设备可以针对每个样本数据,基于其对应的每个候选标签的标签概率,针对每个样本数据未标注的各个候选标签,确定是否其对应的标签概率大于预设概率,若任一未标注的候选标签中,其对应的标签别概率大于预设概率,则可以将该样本数据作为目标样本数据;若每个未标注的候选标签,其对应的标签概率均小于或等于预设概率,则表示该样本数据为正确标注的数据,可以不对其进行处理。
[0115]
步骤s340:将所述目标样本数据修正为所述目标候选标签的正样本数据,得到更新后的训练样本集。
[0116]
在本技术实施例中,在确定出上述目标样本数据后,由于目标样本数据是缺失目标候选标签,因此未被标注目标候选标签,但出现了模型识别出的目标候选标签的标签概率大于预设概率,因此,计算机设备将目标样本数据修正为目标候选标签的正样本数据,得到更新后的训练样本集。
[0117]
具体地,计算机设备可以对训练样本集中的目标样本数据添加所述目标候选标签,得到更新后的训练样本集。由此,在模型训练过程中,为缺失标签的样本数据添加模型识别出的标签概率大于预设概率的候选标签,实现对样本数据的自动修正。
[0118]
步骤s350:基于所述训练样本集中每个样本数据对应的每个候选标签的标签概率,以及所述更新后的训练样本集中每个样本数据对应的标签,确定每个候选标签对应的正样本损失值,以及每个候选标签对应的负样本损失值。
[0119]
步骤s360:基于所述每个候选标签对应的正样本损失值,以及所述每个候选标签对应的负样本损失值,确定总损失值。
[0120]
在本技术实施例中,计算机设备在对上述目标样本数据更新为目标候选标签的正样本数据后,则其缺失的标签得到了添加,因此按照传统的损失值的确定方式,确定总损失值,可以保证计算的损失值的准确性,从而能够保证模型更新的准确性。
[0121]
其中,计算机设备可以基于每个样本数据对应的每个候选标签的标签概率,以及更新后的训练样本集中每个样本数据标注的标签,确定每个候选标签对应的正样本损失值,以及每个候选标签对应的负样本损失值;然后基于每个候选标签对应的正样本损失值,以及每个候选标签对应的负样本损失值,确定总损失值。
[0122]
在一些实施方式中,计算机设备可以按照前述实施例中的公式计算对于每个样本数据,确定其对应产生的损失值;然后再获取每个样本数据所产生的损失值的总和,从而得到总损失值。其中,和可以基于交叉熵损失函数、焦点损失函数、或者自动损失函数确定。
[0123]
步骤s370:根据所述总损失值,对所述初始识别模型进行迭代训练,得到训练后的识别模型。
[0124]
在本技术实施例中,步骤s370可以参阅前述实施例的内容,在此不再赘述。
[0125]
在一些实施方式中,由于模型训练过程中,对于任一样本数据而言,针对上述目标样本数据对应的未标注的目标候选标签的标签概率大于预设概率的情况,将其更新为目标候选标签的正样本数据,也就是说,是不断地对训练样本集更新,因此模型训练完成后,其
训练样本集也大概率标注完备,故计算机设备还可以将训练样本集上传至服务器,以供其他设备训练识别模型时进行使用。
[0126]
本技术实施例提供的模型训练方法,可以实现模型训练中,在每个训练epoch中,确定出标注不完整的负样本数据,并对其标注的标签进行更新,使得该负样本数据的标签准确,从而能够保证计算的损失值准确,进而提升训练得到的模型的准确性。
[0127]
请参阅图5,其示出了本技术实施例提供的一种模型训练装置400的结构框图。该模型训练装置400应用上述的计算机设备,该模型训练装置400包括:样本获取模块410、模型输入模块420、损失确定模块430以及迭代训练模块440。其中,所述样本获取模块410用于获取训练样本集,所述训练样本集包括多个样本数据,所述多个样本数据中包括至少一个候选标签对应的负样本数据;所述模型输入模块420用于分别将所述训练样本集中每个样本数据输入至初始识别模型,得到每个样本数据对应的每个候选标签的标签概率;所述损失确定模块430用于将所述训练样本集中的目标样本数据作为目标候选标签的正样本数据,并基于所述训练样本集中每个样本数据对应的每个候选标签的标签概率以及所述每个样本数据标注的标签,确定总损失值,所述目标样本数据未标注有所述目标候选标签,且所述目标样本数据对应所述目标候选标签的标签概率大于预设概率;所述迭代训练模块440用于根据所述总损失值,对所述初始识别模型进行迭代训练,得到训练后的识别模型。
[0128]
在一些实施方式中,损失确定模块430可以用于:基于所述每个样本数据标注的标签,确定所述每个候选标签的正样本数据以及负样本数据;针对每个候选标签,基于所述每个候选标签的每个正样本数据,以及所述每个正样本数据对应的候选标签的标签概率,确定每个候选标签对应的正样本损失值;针对每个候选标签,基于所述每个候选标签的每个负样本数据,以及所述每个负样本数据对应的候选标签的标签概率,确定每个候选标签对应的负样本损失值,其中,所述目标样本数据对应所述目标候选标签的负样本损失值按照正样本损失值的确定方式获取;基于所述每个候选标签对应的正样本损失值,以及所述每个候选标签对应的负样本损失值,确定总损失值。
[0129]
作为一种可能的实施方式,损失确定模块430针对每个候选标签,基于所述每个候选标签的每个负样本数据,以及所述每个负样本数据对应的候选标签的标签概率,确定每个候选标签对应的负样本损失值,可以包括:针对每个候选标签,若候选标签的负样本数据对应的标签概率小于或等于预设概率,则基于所述负样本数据对应所述候选标签的标签概率,并按照负样本损失值的确定方式获取损失值,作为所述负样本数据对应所述候选标签的负样本损失值;若候选标签的负样本数据对应的标签概率大于预设概率,则基于所述负样本数据对应所述候选标签的标签概率,并按照正样本损失值的确定方式获取损失值,作为所述负样本数据对应所述候选标签的负样本损失值。
[0130]
作为一种可能的实施方式,损失确定模块430针对每个候选标签,基于所述每个候选标签的每个正样本数据,以及所述每个正样本数据对应的候选标签的标签概率,确定每个候选标签对应的正样本损失值,包括:针对每个候选标签,基于所述候选标签的正样本数据对应所述候选标签的标签概率,并按照正样本损失值的确定方式获取损失值,作为所述正样本数据对应所述候选标签的正样本损失值。
[0131]
在一些实施方式中,损失确定模块430可以用于:获取所述训练样本集中的目标样本数据,所述目标样本数据未标注有所述目标候选标签,且所述目标样本数据对应所述目
标候选标签的标签概率大于预设概率;将所述目标样本数据修正为所述目标候选标签的正样本数据,得到更新后的训练样本集;基于所述训练样本集中每个样本数据对应的每个候选标签的标签概率,以及所述更新后的训练样本集中每个样本数据对应的标签,确定每个候选标签对应的正样本损失值,以及每个候选标签对应的负样本损失值;基于所述每个候选标签对应的正样本损失值,以及所述每个候选标签对应的负样本损失值,确定总损失值。
[0132]
作为一种可能的实施方式,损失确定模块430将所述目标样本数据修正为所述目标候选标签的正样本数据,得到更新后的训练样本集,可以包括:对所述训练样本集中的所述目标样本数据添加所述目标候选标签的标签,得到更新后的训练样本集。
[0133]
在一些实施方式中,所述正样本损失值以及所述负样本损失值基于交叉熵损失函数、焦点损失函数、或者自动损失函数确定。
[0134]
在一些实施方式中,迭代训练模块440可以具体用于:根据所述总损失值,调整所述初始识别模型的模型参数;返回所述分别将所述训练样本集中每个样本数据输入至初始识别模型,得到每个样本数据对应的每个候选标签的标签概率的步骤,直至满足训练结束条件,得到训练后的识别模型。
[0135]
所属领域的技术人员可以清楚地了解到,为描述的方便和简洁,上述描述装置和模块的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。
[0136]
在本技术所提供的几个实施例中,模块相互之间的耦合可以是电性,机械或其它形式的耦合。
[0137]
另外,在本技术各个实施例中的各功能模块可以集成在一个处理模块中,也可以是各个模块单独物理存在,也可以两个或两个以上模块集成在一个模块中。上述集成的模块既可以采用硬件的形式实现,也可以采用软件功能模块的形式实现。
[0138]
综上所述,本技术提供的方案,通过获取训练样本集,分别将训练样本集中每个样本数据输入至初始识别模型,得到每个样本数据对应的每个候选标签的标签概率,然后将训练样本集中的目标样本数据作为目标候选标签的正样本数据,并基于训练样本集中每个样本数据对应的每个候选标签的标签概率以及每个样本数据标注的标签,确定总损失值,该目标样本数据未标注有所述目标候选标签,且目标样本数据对应目标候选标签的标签概率大于预设概率,再根据总损失值,对初始识别模型进行迭代训练,得到训练后的识别模型。由此,可以实现模型训练中,对于标注不完整的负样本数据,自动按照正样本计算损失,实现自动修正负样本所带来的损失,从而提升模型训练的效率和准确性。
[0139]
请参考图6,其示出了本技术实施例提供的一种电子设备的结构框图。该电子设备100可以是服务器、智能手机、平板电脑、智能手表、电子书、pc电脑、笔记本电脑等能够运行应用程序的电子设备。本技术中的电子设备100可以包括一个或多个如下部件:处理器110、存储器120、以及一个或多个应用程序,其中一个或多个应用程序可以被存储在存储器120中并被配置为由一个或多个处理器110执行,一个或多个应用程序配置用于执行如前述方法实施例所描述的方法。
[0140]
处理器110可以包括一个或者多个处理核。处理器110利用各种接口和线路连接整个电子设备100内的各个部分,通过运行或执行存储在存储器120内的指令、程序、代码集或指令集,以及调用存储在存储器120内的数据,执行电子设备100的各种功能和处理数据。可选地,处理器110可以采用数字信号处理(digital signal processing,dsp)、现场可编程
门阵列(field-programmable gate array,fpga)、可编程逻辑阵列(programmable logic array,pla)中的至少一种硬件形式来实现。处理器110可集成中央处理器(central processing unit,cpu)、图形处理器(graphics processing unit,gpu)和调制解调器等中的一种或几种的组合。其中,cpu主要处理操作系统、用户界面和应用程序等;gpu用于负责显示内容的渲染和绘制;调制解调器用于处理无线通信。可以理解的是,上述调制解调器也可以不集成到处理器110中,单独通过一块通信芯片进行实现。
[0141]
存储器120可以包括随机存储器(random access memory,ram),也可以包括只读存储器(read-only memory)。存储器120可用于存储指令、程序、代码、代码集或指令集。存储器120可包括存储程序区和存储数据区,其中,存储程序区可存储用于实现操作系统的指令、用于实现至少一个功能的指令(比如触控功能、声音播放功能、图像播放功能等)、用于实现下述各个方法实施例的指令等。存储数据区还可以存储电子设备100在使用中所创建的数据(比如电话本、音视频数据、聊天记录数据)等。
[0142]
请参考图7,其示出了本技术实施例提供的一种计算机可读存储介质的结构框图。该计算机可读介质800中存储有程序代码,所述程序代码可被处理器调用执行上述方法实施例中所描述的方法。
[0143]
计算机可读存储介质800可以是诸如闪存、eeprom(电可擦除可编程只读存储器)、eprom、硬盘或者rom之类的电子存储器。可选地,计算机可读存储介质800包括非易失性计算机可读介质(non-transitory computer-readable storage medium)。计算机可读存储介质800具有执行上述方法中的任何方法步骤的程序代码810的存储空间。这些程序代码可以从一个或者多个计算机程序产品中读出或者写入到这一个或者多个计算机程序产品中。程序代码810可以例如以适当形式进行压缩。
[0144]
最后应说明的是:以上实施例仅用以说明本技术的技术方案,而非对其限制;尽管参照前述实施例对本技术进行了详细的说明,本领域的普通技术人员当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不驱使相应技术方案的本质脱离本技术各实施例技术方案的精神和范围。
再多了解一些

本文用于企业家、创业者技术爱好者查询,结果仅供参考。

发表评论 共有条评论
用户名: 密码:
验证码: 匿名发表

相关文献