一种残膜回收机防缠绕挑膜装置的制 一种秧草收获机用电力驱动行走机构

一种迁移学习训练方法、装置、电子设备及存储介质与流程

2022-02-22 17:38:15 来源:中国专利 TAG:


1.本公开涉及数据处理技术领域,尤其涉及深度学习技术领域。


背景技术:

2.现有的迁移学习训练方法是直接利用源域样本的数据与目标域样本的数据对目标域模型进行联合训练,但源域的场景的数据分布与当前场景也就是目标域的场景的数据分布并不一致,甚至有较大的差异性,直接进行全量的联合训练会对当前场景的模型产生负迁移现象,不但无法提升反而会降低当前场景的模型的性能,而抽样的联合训练控制了引入的源域样本的数量,但由于并没有解决源域样本数据分布与当前场景数据分布的差异性,只能从一定程度上减轻负迁移现象,同样很难改变数据差异对当前场景模型的影响,而利用迁移学习效果较好的网络结构,虽然能够有效地减缓负迁移现象,但是针对不同的业务场景需要耗费大量的人力成本去调整网络结构的模型参数。


技术实现要素:

3.本公开提供了一种迁移学习训练方法、装置、电子设备及存储介质。
4.根据本公开的一方面,提供了一种迁移学习训练方法,包括:
5.获取源域样本,所述源域样本中包含多个源域数据和与源域数据对应的标签值;
6.利用一阶段模型计算每个源域数据的第一交叉熵并根据第一交叉熵计算相似度权重;
7.获取目标域样本,所述目标域样本中包含多个目标域数据和与目标域数据对应的标签值;
8.利用二阶段模型计算每个源域数据和每个目标域数据的第二交叉熵;
9.根据每个源域数据和每个目标域数据的第二交叉熵与相似度权重计算每个源域数据和每个目标域数据的第三交叉熵;
10.根据每个源域数据和每个目标域数据的第三交叉熵对所述二阶段模型进行参数更新;
11.将参数更新后的二阶段模型对业务数据进行预估或排序。
12.根据本公开的另一方面,提供了一种迁移学习训练装置,包括:
13.采集模块,用于获取源域样本,所述源域样本中包含多个源域数据和与源域数据对应的标签值;
14.计算模块,用于利用一阶段模型计算每个源域数据的第一交叉熵并根据第一交叉熵计算相似度权重;
15.所述采集样本,还用于获取目标域样本,所述目标域样本中包含多个目标域数据和与目标域数据对应的标签值;
16.所述计算模块,还用于利用二阶段模型计算每个源域数据和每个目标域数据的第二交叉熵;
17.所述计算模块,还用于根据每个源域数据和每个目标域数据的第二交叉熵与相似度权重计算每个源域数据和每个目标域数据的第三交叉熵;
18.训练模块,用于根据每个源域数据和每个目标域数据的第三交叉熵对所述二阶段模型进行参数更新;
19.处理模块,用于将参数更新后的二阶段模型对业务数据进行预估或排序。
20.根据本公开的另一方面,提供了一种电子设备,包括:
21.至少一个处理器;以及
22.与所述至少一个处理器通信连接的存储器;其中,
23.所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行上述任一项所述的方法。
24.根据本公开的另一方面,提供了一种存储有计算机指令的非瞬时计算机可读存储介质,其中,所述计算机指令用于使所述计算机执行上述任一项所述的方法。
25.根据本公开的另一方面,提供了一种计算机程序产品,包括计算机程序,所述计算机程序在被处理器执行时实现上述任一项所述的方法。
26.应当理解,本部分所描述的内容并非旨在标识本公开的实施例的关键或重要特征,也不用于限制本公开的范围。本公开的其它特征将通过以下的说明书而变得容易理解。
附图说明
27.附图用于更好地理解本方案,不构成对本公开的限定。其中:
28.图1是根据本公开实施例提供的迁移学习训练方法的流程示意图;
29.图2是根据本公开实施例提供的迁移学习训练装置的结构示意图;
30.图3是用来实现本公开实施例的迁移学习训练方法的电子设备的框图。
具体实施方式
31.以下结合附图对本公开的示范性实施例做出说明,其中包括本公开实施例的各种细节以助于理解,应当将它们认为仅仅是示范性的。因此,本领域普通技术人员应当认识到,可以对这里描述的实施例做出各种改变和修改,而不会背离本公开的范围和精神。同样,为了清楚和简明,以下的描述中省略了对公知功能和结构的描述。
32.为了在不耗费大量人力成本的前提下减缓负迁移效果,提高迁移学习效果的稳定性,如图1所示,本公开一实施例提供了一种迁移学习训练方法,该方法包括:
33.步骤101,获取源域样本,所述源域样本中包含多个源域数据和与源域数据对应的标签值。
34.获取源域样本,源域样本中包含多个源域数据和与源域数据对应的标签值,与源域数据对应的标签值为0或者1,业务场景规模较小或者建立时间较短的推荐系统(即目标域)中的模型(即目标域模型,在本实施例中为二阶段模型)面临着训练数据较少而导致模型排序预估精度不高,迁移学习是通过从数据较多的源域中获取源域样本来给目标域模型进行训练从而提升目标域模型的排序预估精度。
35.步骤102,利用一阶段模型计算每个源域数据的第一交叉熵并根据第一交叉熵计算相似度权重。
36.利用一阶段模型计算每个源域数据的第一交叉熵,具体可根据以下公式计算第一交叉熵l
1i

[0037][0038]
其中,y
1i
为第i个源域数据的标签值,为第i个源域数据的预估值;
[0039]
计算出源域数据的第一交叉熵后,可根据以下公式计算相似度权重w
1i

[0040][0041]
其中,e为自然常数。
[0042]
步骤103,获取目标域样本,所述目标域样本中包含多个目标域数据和与目标域数据对应的标签值。
[0043]
获取目标域样本,目标域样本中包含多个目标域数据和与目标域数据对应的标签值,与目标域数据对应的标签值为0或者1。
[0044]
步骤104,利用二阶段模型计算每个源域数据和每个目标域数据的第二交叉熵。
[0045]
根据以下公式计算每个源域数据的第二交叉熵l
2i

[0046][0047]
其中,y
2i
为第i个源域数据的标签值,为第i个源域数据的预估值;
[0048]
根据以下公式计算每个目标域数据的第二交叉熵l
2j

[0049][0050]
其中,y
2j
为第j个目标域数据的标签值,为第j个目标域数据的预估值;
[0051]
因为二阶段模型对于源域数据进行预估得到的预估值可能与一阶段模型的不同,所以这里也需要重新计算源域数据的第二交叉熵。
[0052]
步骤105,根据每个源域数据和每个目标域数据的第二交叉熵与相似度权重计算每个源域数据和每个目标域数据的第三交叉熵。
[0053]
再根据每个源域数据和每个目标域数据的第二交叉熵与每个源域数据的相似度权重计算每个源域数据的第三交叉熵l
3i
和每个目标域数据的第三交叉熵l
3j

[0054]
l
3i
=w
1i
l
2i
[0055]
l
3j
=l
2j
[0056]
其中,w
1i
为第i个源域数据的相似度权重,l
2i
为第i个源域数据的第二交叉熵,l
2j
为第j个目标域数据的第二交叉熵。
[0057]
步骤106,根据每个源域数据和每个目标域数据的第三交叉熵对所述二阶段模型进行参数更新。
[0058]
步骤107,将参数更新后的二阶段模型对业务数据进行预估或排序。
[0059]
利用源域的源域样本与目标域的目标域样本对二阶段模型进行训练,使得业务场
景规模较小或者建立时间较短的目标域中的模型不会因为训练数据较少而导致模型排序预估精度不高,而利用一阶段模型对与源域样本进行预估,根据预估值和标签值计算源域样本的第一交叉熵,并根据第一交叉熵进一步计算出源域样本的相似度权重,在训练二阶段模型时,先计算每个源域数据和每个目标域数据的第二交叉熵,再根据相似度权重和第二交叉熵计算第三交叉熵,最后根据第三交叉熵对二阶段模型进行参数更新,能够有效减缓因为源域与目标域的差异性而导致的负迁移现象,相比现有技术只是对源域样本进行筛选后对二阶段模型进行训练的方法,减缓负迁移现象的效果更好,并且不需要利用特定的网络结构,不需要针对不同的业务场景而耗费巨大的人力成本去调整网络结构的模型参数。
[0060]
在步骤102中,利用一阶段模型计算每个源域数据的第一交叉熵并根据第一交叉熵计算相似度权重,在一可实施方式中,利用一阶段模型提取每个源域数据在预设n个维度上的特征数据;
[0061]
根据每个源域数据中在预设n个维度上的特征数据对该源域数据进行预估,得到每个源域数据对应的预估值;
[0062]
根据每个源域数据的标签值和预估值计算该源域数据的第一交叉熵;
[0063]
根据每个源域数据的第一交叉熵计算该源域数据的相似度权重。
[0064]
利用一阶段模型提取每个源域数据在预设n个维度上的特征数据,预设n个维度可以包括用户特征、文章特征、行为序列特征、请求特征和其他业务相关特征等,具体可以根据源域和目标域的场景类型进行设置,一阶段模型根据源域数据在这些多个维度上的特征数据进行预估,可以提高一阶段模型对源域数据预估的准确度,在得到预估值后根据每个源域数据的标签值和预估值计算该源域数据的第一交叉熵,再根据每个源域数据的第一交叉熵计算该源域数据的相似度权重。
[0065]
在步骤104中,利用二阶段模型计算每个源域数据和每个目标域数据的第二交叉熵,在一可实施方式中,利用二阶段模型提取每个源域数据和每个目标域数据在预设m个维度上的特征数据,所述预设m个维度比预设n个维度多出l个维度,所述l个维度上的特征数据能够表征源域数据和目标域数据的差异性;
[0066]
根据每个源域数据和每个目标域数据在预设m个维度上的特征数据对每个源域数据和每个目标域数据进行预估,得到每个源域数据和每个目标域数据的预估值;
[0067]
根据每个源域数据和每个目标域数据的标签值与预估值计算每个源域数据和每个目标域数据的第二交叉熵。
[0068]
利用二阶段模型提取每个源域数据和每个目标域数据在预设m个维度上的特征数据,所述预设m个维度比预设n个维度多出l个维度,所述l个维度上的特征数据能够表征源域数据和目标域数据的差异性,预设m个维度比预设n个维度多出的l个维度可以是app特征数据、场景特征数据等能够显著地体现出源域跟目标域之间差异性的特征数据,在二阶段模型对每个源域数据和每个目标域数据进行预估时,这些特征数据能够使得二阶段模型清楚地区分出差异性较大的源域数据,更准确地对这些数据进行预估,最终得到的交叉熵也更准确,在提取出特征数据后,根据每个源域数据和每个目标域数据在预设m个维度上的特征数据对每个源域数据和每个目标域数据进行预估,得到每个源域数据和每个目标域数据的预估值,根据每个源域数据和每个目标域数据的标签值与预估值计算每个源域数据和每
个目标域数据的第二交叉熵。
[0069]
在步骤102中,利用一阶段模型计算每个源域数据的第一交叉熵之后,在一可实施方式中,利用所述一阶段模型计算每个源域数据的第一交叉熵之后不对所述一阶段模型进行参数更新。
[0070]
在利用所述一阶段模型计算每个源域数据的第一交叉熵之后不对所述一阶段模型进行参数更新,可以使一阶段模型不会被差异性较大的源域数据所影响,导致一阶段模型本身产生负迁移效果,防止一阶段模型对源域数据的预估能力下降而导致计算的源域数据的相似度权重不准确。
[0071]
在步骤102中,利用一阶段模型计算每个源域数据的第一交叉熵并根据第一交叉熵计算相似度权重之后,在一可实施方式中,将源域样本中相似度权重小于等于预设阈值的源域数据剔除。
[0072]
将源域样本中相似度权重小于等于预设阈值的源域数据剔除,相当于根据相似度权重对源域样本进行筛选,剔除掉源域样本中相似度权重较小也就是与目标域差异性较大的源域数据,能够使得最终应用于训练二阶段模型的源域数据与目标域的差异性在一定可控范围内,更有效地减缓负迁移现象的产生,进一步提高迁移学习的效果。
[0073]
本公开一实施例提供了一种迁移学习训练装置,如图2所示,该装置包括:
[0074]
采集模块10,用于获取源域样本,所述源域样本中包含多个源域数据和与源域数据对应的标签值;
[0075]
计算模块20,用于利用一阶段模型计算每个源域数据的第一交叉熵并根据第一交叉熵计算相似度权重;
[0076]
所述采集样本10,还用于获取目标域样本,所述目标域样本中包含多个目标域数据和与目标域数据对应的标签值;
[0077]
所述计算模块20,还用于利用二阶段模型计算每个源域数据和每个目标域数据的第二交叉熵;
[0078]
所述计算模块20,还用于根据每个源域数据和每个目标域数据的第二交叉熵与相似度权重计算每个源域数据和每个目标域数据的第三交叉熵;
[0079]
训练模块30,用于根据每个源域数据和每个目标域数据的第三交叉熵对所述二阶段模型进行参数更新;
[0080]
处理模块40,用于将参数更新后的二阶段模型对业务数据进行预估或排序。
[0081]
其中,所述计算模块20,还用于利用一阶段模型提取每个源域数据在预设n个维度上的特征数据;
[0082]
所述计算模块20,还用于根据每个源域数据中在预设n个维度上的特征数据对该源域数据进行预估,得到每个源域数据对应的预估值;
[0083]
所述计算模块20,还用于根据每个源域数据的标签值与预估值计算该源域数据的第一交叉熵;
[0084]
所述计算模块20,还用于根据每个源域数据的第一交叉熵计算该源域数据的相似度权重。
[0085]
其中,所述计算模块20,还用于利用二阶段模型提取每个源域数据和每个目标域数据在预设m个维度上的特征数据,所述预设m个维度比预设n个维度多出l个维度,所述l个
维度上的特征数据能够表征源域数据和目标域数据的差异性;
[0086]
所述计算模块20,还用于根据每个源域数据和每个目标域数据在预设m个维度上的特征数据对每个源域数据和每个目标域数据进行预估,得到每个源域数据和每个目标域数据的预估值;
[0087]
所述计算模块20,还用于根据每个源域数据和每个目标域数据的标签值与预估值计算每个源域数据和每个目标域数据的第二交叉熵。
[0088]
其中,所述训练模块30,还用于利用所述一阶段模型计算每个源域数据的第一交叉熵之后不对所述一阶段模型进行参数更新。
[0089]
其中,所述计算模块20,还用于将源域样本中相似度权重小于等于预设阈值的源域数据剔除。
[0090]
本公开的技术方案中,所涉及的用户个人信息的获取,存储和应用等,均符合相关法律法规的规定,且不违背公序良俗。
[0091]
根据本公开的实施例,本公开还提供了一种电子设备、一种可读存储介质和一种计算机程序产品。
[0092]
图3示出了可以用来实施本公开的实施例的示例电子设备300的示意性框图。电子设备旨在表示各种形式的数字计算机,诸如,膝上型计算机、台式计算机、工作台、个人数字助理、服务器、刀片式服务器、大型计算机、和其它适合的计算机。电子设备还可以表示各种形式的移动装置,诸如,个人数字处理、蜂窝电话、智能电话、可穿戴设备和其它类似的计算装置。本文所示的部件、它们的连接和关系、以及它们的功能仅仅作为示例,并且不意在限制本文中描述的和/或者要求的本公开的实现。
[0093]
如图3所示,设备300包括计算单元301,其可以根据存储在只读存储器(rom)302中的计算机程序或者从存储单元308加载到随机访问存储器(ram)303中的计算机程序,来执行各种适当的动作和处理。在ram303中,还可存储设备300操作所需的各种程序和数据。计算单元301、rom302以及ram303通过总线304彼此相连。输入/输出(i/o)接口305也连接至总线304。
[0094]
设备300中的多个部件连接至i/o接口305,包括:输入单元306,例如键盘、鼠标等;输出单元307,例如各种类型的显示器、扬声器等;存储单元308,例如磁盘、光盘等;以及通信单元309,例如网卡、调制解调器、无线通信收发机等。通信单元309允许设备300通过诸如因特网的计算机网络和/或各种电信网络与其他设备交换信息/数据。
[0095]
计算单元301可以是各种具有处理和计算能力的通用和/或专用处理组件。计算单元301的一些示例包括但不限于中央处理单元(cpu)、图形处理单元(gpu)、各种专用的人工智能(ai)计算芯片、各种运行机器学习模型算法的计算单元、数字信号处理器(dsp)、以及任何适当的处理器、控制器、微控制器等。计算单元301执行上文所描述的各个方法和处理,例如迁移学习训练方法。例如,在一些实施例中,迁移学习训练方法可被实现为计算机软件程序,其被有形地包含于机器可读介质,例如存储单元308。在一些实施例中,计算机程序的部分或者全部可以经由rom302和/或通信单元309而被载入和/或安装到设备300上。当计算机程序加载到ram303并由计算单元301执行时,可以执行上文描述的迁移学习训练方法的一个或多个步骤。备选地,在其他实施例中,计算单元301可以通过其他任何适当的方式(例如,借助于固件)而被配置为执行迁移学习训练方法。
[0096]
本文中以上描述的系统和技术的各种实施方式可以在数字电子电路系统、集成电路系统、场可编程门阵列(fpga)、专用集成电路(asic)、专用标准产品(assp)、芯片上系统的系统(soc)、负载可编程逻辑设备(cpld)、计算机硬件、固件、软件、和/或它们的组合中实现。这些各种实施方式可以包括:实施在一个或者多个计算机程序中,该一个或者多个计算机程序可在包括至少一个可编程处理器的可编程系统上执行和/或解释,该可编程处理器可以是专用或者通用可编程处理器,可以从存储系统、至少一个输入装置、和至少一个输出装置接收数据和指令,并且将数据和指令传输至该存储系统、该至少一个输入装置、和该至少一个输出装置。
[0097]
用于实施本公开的方法的程序代码可以采用一个或多个编程语言的任何组合来编写。这些程序代码可以提供给通用计算机、专用计算机或其他可编程数据处理装置的处理器或控制器,使得程序代码当由处理器或控制器执行时使流程图和/或框图中所规定的功能/操作被实施。程序代码可以完全在机器上执行、部分地在机器上执行,作为独立软件包部分地在机器上执行且部分地在远程机器上执行或完全在远程机器或服务器上执行。
[0098]
在本公开的上下文中,机器可读介质可以是有形的介质,其可以包含或存储以供指令执行系统、装置或设备使用或与指令执行系统、装置或设备结合地使用的程序。机器可读介质可以是机器可读信号介质或机器可读储存介质。机器可读介质可以包括但不限于电子的、磁性的、光学的、电磁的、红外的、或半导体系统、装置或设备,或者上述内容的任何合适组合。机器可读存储介质的更具体示例会包括基于一个或多个线的电气连接、便携式计算机盘、硬盘、随机存取存储器(ram)、只读存储器(rom)、可擦除可编程只读存储器(eprom或快闪存储器)、光纤、便捷式紧凑盘只读存储器(cd-rom)、光学储存设备、磁储存设备、或上述内容的任何合适组合。
[0099]
为了提供与用户的交互,可以在计算机上实施此处描述的系统和技术,该计算机具有:用于向用户显示信息的显示装置(例如,crt(阴极射线管)或者lcd(液晶显示器)监视器);以及键盘和指向装置(例如,鼠标或者轨迹球),用户可以通过该键盘和该指向装置来将输入提供给计算机。其它种类的装置还可以用于提供与用户的交互;例如,提供给用户的反馈可以是任何形式的传感反馈(例如,视觉反馈、听觉反馈、或者触觉反馈);并且可以用任何形式(包括声输入、语音输入或者、触觉输入)来接收来自用户的输入。
[0100]
可以将此处描述的系统和技术实施在包括后台部件的计算系统(例如,作为数据服务器)、或者包括中间件部件的计算系统(例如,应用服务器)、或者包括前端部件的计算系统(例如,具有图形用户界面或者网络浏览器的用户计算机,用户可以通过该图形用户界面或者该网络浏览器来与此处描述的系统和技术的实施方式交互)、或者包括这种后台部件、中间件部件、或者前端部件的任何组合的计算系统中。可以通过任何形式或者介质的数字数据通信(例如,通信网络)来将系统的部件相互连接。通信网络的示例包括:局域网(lan)、广域网(wan)和互联网。
[0101]
计算机系统可以包括客户端和服务器。客户端和服务器一般远离彼此并且通常通过通信网络进行交互。通过在相应的计算机上运行并且彼此具有客户端-服务器关系的计算机程序来产生客户端和服务器的关系。服务器可以是云服务器,也可以为分布式系统的服务器,或者是结合了区块链的服务器。
[0102]
应该理解,可以使用上面所示的各种形式的流程,重新排序、增加或删除步骤。例
如,本发公开中记载的各步骤可以并行地执行也可以顺序地执行也可以不同的次序执行,只要能够实现本公开公开的技术方案所期望的结果,本文在此不进行限制。
[0103]
上述具体实施方式,并不构成对本公开保护范围的限制。本领域技术人员应该明白的是,根据设计要求和其他因素,可以进行各种修改、组合、子组合和替代。任何在本公开的精神和原则之内所作的修改、等同替换和改进等,均应包含在本公开保护范围之内。
再多了解一些

本文用于企业家、创业者技术爱好者查询,结果仅供参考。

发表评论 共有条评论
用户名: 密码:
验证码: 匿名发表

相关文献