一种残膜回收机防缠绕挑膜装置的制 一种秧草收获机用电力驱动行走机构

一种联邦主动学习模型训练方法及系统

2022-09-14 23:45:38 来源:中国专利 TAG:


1.本发明涉及一种信息技术,特别涉及一种联邦主动学习模型训练方法及系统。


背景技术:

2.随着人工智能的快速发展,尤其是深度学习模型在各个领域中的广泛应用,极大的提高了社会生产效率。但是针对某领域想要得到一个性能好的深度学习模型的前提是拥有高质量、高数据量的数据集,数据的质量是保证深度学习模型泛化性能的前提。现今医疗领域、工业领域、金融领域或者其他领域中的诸如医院、工厂、互联网企业、证券公司、保险公司等每天都会产生许多没有标签的数据,且这些数据具有重复、价值密度低等特点。如何能够合理且高效地利用这些无标签地数据训练深度学习模型是一大现实难题。最直接的方法是将这些数据上传到高性能服务器中心并对这些数据进行人工标注,接着进行深度学习模型的训练。但是考虑到隐私保护、行业竞争、法律约束、知识产权保护、数据的存储和通讯等问题,集中各方数据非常困难,且对所有的数据进行人工标注也是一项巨大的开销。因此这些企业或工厂就像一座座数据孤岛,这一现象在医疗和金融行业更加明显。这极大的限制了相关领域的发展。
3.联邦学习是近年来提出的一种分布式机器学习范式。可以在各个参与方的私有数据不出本地的前提下,通过同态加密、差分隐私、安全多方计算等手段,各个参与方将本地模型参数上传到一个服务器上聚合出全局模型。在每一轮的训练中,各个参与方用自己本地的数据来训练本地模型,通过同态加密等方式对模型参数进行加密后传输给服务器。服务器根据相应的聚合算法(如fedavg算法)将各个参与方传来的模型参数进行聚合生成全局模型并将此时的全局模型广播到各个参与方,完成本轮训练。
4.主动学习是一种机器学习的一个子领域,在统计学领域也叫查询学习或最优实验设计。主动学习方法尝试解决样本的标注瓶颈,通过交互式地从尚未标注的样本池中主动挑选出一部分需要标注的样本子集由领域或行业专家进行人工标注。主动学习最主要的目标是挑选出高质量样本,在模型达到目标性能的前提下,尽可能的降低标注的成本。


技术实现要素:

5.针对数据孤岛和高质量数据标注问题,提出了一种联邦主动学习模型训练方法及系统,将联邦学习和主动学习有机结合,在保护数据隐私的前提下,充分利用各个参与方的未标注数据信息,并且尽可能地减少样本标注成本来提高模型的精度和泛化性能。
6.本发明的技术方案为:一种联邦主动学习模型训练方法,包括如下步骤:
7.1)各个参与方并行地从本地未标注数据集中随机挑选一部分样本请专家进行标注,将标注后的数据样本从本地数据集移动到已标注数据集d
l
中,剩下的未标注数据集为du,与此同时中央服务器随机生成全局模型参数w下发到各个参与方;
8.2)各个参与方并行地从中央服务器接收全局模型参数w加载到本地模型,利用d
l
对本地模型w
local
进行训练,利用本地模型w
local
中的特征嵌入作为输入,本地模型w
local
的损
失作为标签对损失预测模型w
loss
进行训练,学习训练中更新模型参数;
9.3)各个参与方并行地利用训练后损失预测模型w
loss
对du中所有数据样本执行一次模型推理,得到当前所有未标注数据的预测损失值pl;
10.4)各个参与方并行地利用训练后本地模型w
local
对d
l
和du中所有数据样本执行一次模型推理,抽取当前所有数据的潜在空间中的特征嵌入,并对所有已标注数据样本的特征嵌入进行平均计算得到其质心,接着计算得到所有未标注数据样本的特征嵌入到所有已标注数据样本的特征嵌入的质心的欧几里得距离dist;
11.5)各个参与方并行地利用步骤3)、4)得到的pl和dist加权计算每一个未标注数据样本的rank值,根据rank的大小降序排序,挑选出前b个样本由专家进行标注,并从du移动到d
l

12.6)各个参与方并行地上传训练更新后本地模型参数;
13.7)中央服务器接收各个参与方上传的本地模型参数利用模型聚合算法更新全局模型参数w并下发到各个参与方;
14.8)迭代训练,重复2)至7),直到满足迭代次数,获得最终的本地标注数据集和数据识别本地模型w
local

15.进一步,所述步骤2)损失预测模型w
loss
结构为:输入已标注数据样本在本地模型w
local
中的多个潜在空间中的特征嵌入,通过对数据样本的多个特征嵌入进行平均池化和全连接层的操作得到多个特征向量,接着将多个特征向量拼接在一起经过全连接层得到损失预测模型w
loss
的输出,即数据样本的损失预测。
16.进一步,所述步骤2)损失预测模型w
loss
和本地模型w
local
的训练中参数更新方法为:采用梯度

下降方式,采用均方误差损失函数,真实值为训练本地模型w
local
时的训练损失,更新w
loss
和w
local
参数的计算公式为:
[0017][0018][0019][0020]
其中t表示联邦学习第t轮训练,η为学习率,l
loss
为损失预测模型w
loss
训练集的损失函数,n为d
l
中的样本个数,predicted lossi为第i个样本的预测损失,train lossi为第i个样本在本地模型w
local
中的训练损失,l
local
为本地模型w
local
训练集的损失函数,根据任务的不同损失函数不同。
[0021]
进一步,所述步骤4)dist具体计算方法:第j个未标注数据样本的特征嵌入到所有已标注数据样本的特征嵌入的质心的计算公式为:
[0022][0023]
其中函数d(x,y)用于计算x和y的欧几里得距离,n表示d
l
中数据样本的个数,f(x)表示数据样本x的feature embedding,x
li
表示第i个已标注数据样本,x
uj
表述第j个未标注
的数据样本。
[0024]
进一步,所述步骤4)前b个样本的取值最多为初始未标注数据集中样本数量的百分之几。
[0025]
进一步,所述步骤7)聚合算法公式为:
[0026][0027]
其中m为参与方的个数,表示第i个参与方上传的本地模型参数。
[0028]
一种联邦主动学习模型训练系统,包括中央服务器以及若干个参与方设备,所述参与方设备保存有大量的未标注数据集,每个参与方与有相关的领域专家通讯,领域专家对数据样本进行高质量的标注;所述中央服务器中包括全局模型,全局模型根据上传的各个参与方的本地模型参数进行聚合更新;本地各个参与方设备各带本地数据识别模型,本地数据识别模型在中央服务器下发的模型参数和领域专家识别标注下,进行训练,用于对本地数据进行标注识别。
[0029]
本发明的有益效果在于:本发明联邦主动学习模型训练方法及系统,将联邦学习和主动学习有机的结合在一起,在每个联邦学习训练轮次中,利用联邦学习得到的全局模型来指导本地对数据进行标注,提高标注效率,一定程度克服数据孤岛问题;在保护数据隐私的前提下能够充分利用各参与方的大量未标注数据进行模型的协同训练;在主动学习中提出基于损失预测和特征空间距离的混合采样策略,可以挑选出高质量样本,降低标注的成本。
附图说明
[0030]
图1为本发明联邦主动学习模型训练方法的流程框图;
[0031]
图2为本发明联邦主动学习模型训练系统的架构示意图;
[0032]
图3为本发明联邦主动学习模型训练方法及系统中样本采样策略示意图。
具体实施方式
[0033]
下面结合附图和具体实施例对本发明进行详细说明。本实施例以本发明技术方案为前提进行实施,给出了详细的实施方式和具体的操作过程,但本发明的保护范围不限于下述的实施例。
[0034]
如图1、2所示联邦主动学习模型训练方法的流程框图和系统框图,包括如下步骤:
[0035]
s100、各个参与方并行地从本地未标注数据集中随机挑选一部分样本请专家进行标注,将标注后的数据样本从本地数据集移动到已标注数据集d
l
中,剩下的未标注数据集为du,与此同时中央服务器随机生成全局模型参数w下发到各个参与方;
[0036]
具体的,每个参与方本地都会有两个数据集du和d
l
,du中保存未标注数据样本,d
l
中保存已标注数据样本,初始时d
l
为空数据集。
[0037]
s200、各个参与方并行地从中央服务器接收全局模型参数w加载到本地模型,利用d
l
对本地模型w
local
进行训练,利用本地模型w
local
中的特征嵌入作为输入,本地模型w
local
的损失作为标签对损失预测模型w
loss
进行训练;
[0038]
具体的,本地模型w
local
输入已标注数据集d
l
,采用获得的全局模型参数w,对标签特征进行学习训练。损失预测模型w
loss
的输入为已标注数据样本在本地模型w
local
中的多个潜在空间中的特征嵌入feature embedding,通过对数据样本的多个feature embedding进行平均池化和全连接层的操作得到多个特征向量,接着将多个特征向量拼接在一起经过全连接层得到损失预测模型w
loss
的输出,即数据样本的损失预测。损失预测模型w
loss
的训练方式采用梯度下降方式,采用均方误差损失函数,真实值为训练本地模型w
local
时的训练损失,更新w
loss
和w
local
参数的计算公式为:
[0039][0040][0041][0042]
其中t表示联邦学习第t轮训练,η为学习率,l
loss
为损失预测模型w
loss
训练集的损失函数,n为d
l
中的样本个数,predicted lossi为第i个样本的预测损失,train lossi为第i个样本在本地模型w
local
中的训练损失,l
local
为本地模型w
local
训练集的损失函数,根据任务的不同损失函数不同。
[0043]
此步骤的有益之处:在每轮联邦学习之中,本地训练之前本地模型会载入当前全局模型参数,用已标注数据样本在本地模型w
local
中的多个潜在空间中的特征嵌入feature embedding作为损失预测模型w
loss
的输入,并利用对本地模型w
local
在已标注数据样本上进行训练产生的损失值train lossi作为标签对损失预测模型w
loss
进行训练,接下来的步骤中会根据损失预测模型w
loss
推理未标注数据样本输出值作为样本采样依据,从而实现利用联邦学习模型学到的信息和本地已标注数据样本的信息来指导主动学习的样本采样。
[0044]
s300、各个参与方并行地利用训练后损失预测模型w
loss
对du中所有数据样本执行一次模型推理,得到当前所有未标注数据的预测损失值pl;
[0045]
s400、如图3所示,各个参与方并行地利用训练后本地模型w
local
对d
l
和du中所有数据样本执行一次模型推理,抽取当前所有数据的潜在空间中的特征嵌入feature embedding,并对所有已标注数据样本的feature embedding进行平均计算得到其质心,接着计算得到所有未标注数据样本的feature embedding到所有已标注数据样本的feature embedding的质心的欧几里得距离dist;
[0046]
具体的,第j个未标注数据样本的feature embedding到所有已标注数据样本的feature embedding的质心的计算公式为:
[0047][0048]
其中函数d(x,y)用于计算x和y的欧几里得距离,n表示d
l
中数据样本的个数,f(x)表示数据样本x的feature embedding,x
li
表示第i个已标注数据样本,x
uj
表述第j个未标注的数据样本。
[0049]
此步骤的有益之处:同s200一样,接下来的步骤中会根据计算得到的空间距离作
为样本采样依据,也利用了联邦学习模型学到的信息和本地已标注数据样本信息来指导主动学习的样本采样。
[0050]
s500、各个参与方并行地利用得到的pl和dist加权计算每一个未标注数据样本的rank值,根据rank的大小降序排序,挑选出前b个样本由专家进行标注,并从du移动到d
l

[0051]
具体的,某个未标注数据样本的rank值计算公式为:
[0052]
ranku=αplu βdistu[0053]
其中α和β为权重系数,α=1/2,β=1/2,plu为某个未标注数据样本的预测损失值,distu为某个未标注数据样本到所有已标注数据样本的feature embedding的质心的距离。
[0054]
具体的,前b个样本的取值为远小于本地未标注数据集中样本的个数,最多为初始未标注数据集中样本数量的百分之几,挑选出高质量样本数据。
[0055]
此步骤的有益之处:利用损失预测值和空间距离作为主动学习采样的量化指标,当未标注的数据样本的经过损失预测模型w
loss
推理得到的损失预测值pl越大,说明该样本包含的信息量越多,因为损失预测模型w
loss
是由已标注的数据样本训练的,若某样本的损失预测值越小,说明该样本和已标注的数据样本相似,对于模型的训练几乎没有帮助,同理,当未标注的数据样本的feature embedding与所有已标注数据样本的feature embedding的质心的欧几里得距离dist越大,说明该样本包含的信息量越多。本发明综合考量损失预测和空间距离两方面因素,可以有效地采样到信息性样本。
[0056]
s600、各个参与方并行地上传训练更新后本地模型参数;
[0057]
s700、中央服务器接收各个参与方上传的本地模型参数利用模型聚合算法更新全局模型参数w并下发到各个参与方;
[0058]
具体的,全局模型的聚合算法公式为:
[0059][0060]
其中m为参与方的个数,表示第i个参与方上传的本地模型参数。
[0061]
s800、迭代训练,重复s200-s700,直到满足迭代次数,获得最终的本地标注数据集和数据识别本地模型w
local

[0062]
在本实施方式中,如图2所示,基于上述方法实现的一种联邦学习模型训练系统,包括中央服务器以及若干个参与方设备,且所述参与方设备保存有大量的未标注数据集,此外每个参与方都会有相关的领域专家能够对数据样本进行高质量的标注工作,中央服务器中包括全局模型,全局模型根据上传的各个参与方的本地模型参数进行聚合更新;本地各个参与方设备各带本地数据识别模型,本地数据识别模型在中央服务器下发的模型参数和领域专家识别标注下,进行训练,用于对本地数据进行标注识别。
[0063]
以上所述实施例仅表达了本发明的几种实施方式,其描述较为具体和详细,但并不能因此而理解为对发明专利范围的限制。应当指出的是,对于本领域的普通技术人员来说,在不脱离本发明构思的前提下,还可以做出若干变形和改进,这些都属于本发明的保护范围。因此,本发明专利的保护范围应以所附权利要求为准。
再多了解一些

本文用于企业家、创业者技术爱好者查询,结果仅供参考。

发表评论 共有条评论
用户名: 密码:
验证码: 匿名发表

相关文献