一种残膜回收机防缠绕挑膜装置的制 一种秧草收获机用电力驱动行走机构

一种基于异质客户端感知的联邦主动学习方法与流程

2023-02-06 17:25:45 来源:中国专利 TAG:

1.本发明涉及一种机器学习领域中的联邦主动学习方法,尤其涉及一种基于客户端异质感知采样的联邦主动学习方法。


背景技术:

2.主动学习是通过识别信息量最大的无标签数据,并从人类专家那里查询其标签来学习一个模型,以解决机器学习中由于高昂的标注成本导致标签数据不足下难以训练高质量模型的问题。然而,由于人力和无标注语料库大小的限制,每个数据所有者单独执行主动学习往往不足以获得可接受的模型精度。而分布式的主动学习利用多个用户合作标注数据和训练模型解决了上述问题。但在这种分布式范式中,原始数据被直接暴露在用户间,导致用户对数据隐私的担忧。
3.联邦主动学习将主动学习扩展至有多个客户端和一个服务器的联邦学习中。具体来说,每个客户端标注自己的无标签数据,在本地利用标签数据进行模型训练,并通过与服务器之间关于模型参数而非原始数据的多轮通信来学习一个共享的全局模型,以此克服了数据大小和人力的限制,并减轻了对数据隐私的担忧。
4.当前的联邦主动学习是将传统主动学习方法以联邦的形式进行部署,例如:基于不确定度、多样性和预期模型变化的数据采集策略。当客户端数据服从独立同分布的情况下,这些方法能维持较好的表现;但是实际客户端数据往往是服从非独立同分布的,此时客户端的模型训练和采样容易收到其他人的干扰。这些现存的策略往往忽略异质客户端所引起的模型认知变化问题,从而挑选低信息量的样本,导致联邦主动学习的失败。


技术实现要素:

5.本发明的目的在于针对现有技术的不足,提供一种基于客户端异质感知采样的联邦主动学习方法,解决客户端数据服从非独立同分布时,联邦主动学习的性能优化,实现客户端标注预算有限下,高效的数据挑选和高质量的模型训练。
6.本发明所采用的技术方案如下:一种基于异质客户端感知的联邦主动学习方法,包括以下步骤:
7.(1)客户端在自身私有数据上进行本地模型训练;每完成一次本地训练会对每个无标签样本进行推断,并记录模型的预测的一致性结果,用于认知波动的捕获;
8.(2)客户端k完成本地训练后,上传本地更新至服务器;服务器按下式执行聚合获得新一轮的全局模型ωr,并下发给所有客户端;
[0009][0010]
其中,n表示所有客户端的总标签集大小,表示客户端k本地训练的标签集大小;
[0011]
(3)客户端收到新的全局模型开始下一轮本地训练前,客户端基于步骤(1)记录的信息统计连续轮内模型预测是否一致并计算累计变化量
[0012]
(4)客户端根据累计变化量的大小和标注预算贪心地从无标签集中挑选出数据并标注记为获得更新的标签集和无标签集公式如下:
[0013][0014][0015]
(5)采样完成后,客户端利用捕获的认知信息暂时得将模型过度自信或相对简单的零波动样本移到休眠集获得更新后的休眠集和无标签集;表示为如下公式:
[0016][0017][0018]
只有当无标签集的规模小于给定的唤醒阈值时,将从中随机唤醒ta部分冻结数据到无标签池中;
[0019]
(6)完整联邦主动学习重复步骤(1)至步骤(5),直至满足指定轮或者性能阈值;其中,在每次局部训练之前,除了第一轮,每个客户端k都会从无标签集中随机采样一个大小为的子集代替参与后续进程。
[0020]
本发明具有的有益效果是:
[0021]
1、本发明提出了一种新颖有效的基于客户端异质感知采样的联邦主动学习方法,其发现与传统的主动学习不同,由于异质客户之间的聚合操作导致模型认知频繁的发生变化。
[0022]
2、本发明利用模型的认知变化来衡量客户端的异质性,以指导高信息的样本的选择;为了避免浪费预算,本发明基于认知变化设计了一个对齐损失来校准模型的决策边界。
[0023]
3、本发明设计了一个带有子集采样的数据冻结和唤醒机制实现计算性能的优化,本质上是利用认知波动暂时地将对模型来说过于自信和相对的无标签的数据排除在推理之外。
[0024]
4、大量实验表明,在主动学习和联邦学习典型的图像分类数据集上,本发明与当前的基线算法相比有明显的改进,特别是当不同的客户端之间,数据分布和标注行为都变化时。
具体实施方式
[0025]
现结合具体实施和示例对本发明的技术方案作进一步说明。
[0026]
本发明具体实施例及其实施过程如下:
[0027]
步骤1:客户端在自身私有数据上进行本地模型训练;每完成一次本地训练会对每个无标签样本进行推断,并记录模型的预测的一致性结果,用于认知波动的捕获。具体操作为:
[0028]
(1)首先,用表示客户端k在其标签集上连续更新e轮后的模型检查点。相应地,表示e轮训练后样本xi的预测标签。其中,表示xi在最后一层softmax激活后的输出,其中c表示某一类别。
[0029]
(2)其次,对于每一个客户端k在大小为的无标签集上做一次推断,并记录预测结果当xi的连续两次推断结果不一致时,记本地模型产生了一次认知变化用一个e维的向量ev来记录e回合内每个样本的历史认知变化。因此,在第r个通信轮,客户端k对于样本xi的认知波动可以按如下公式计算:
[0030][0031][0032]
其中,是一个指标函数,而第一个
[0033]
步骤1中,首轮客户端本地训练的损失是基于交叉熵的分类损失;其余轮为增加了对齐损失项的新损失函数。其中利用对齐损失项校准模型决策边界,具体为:
[0034]
(1)第r轮,客户端k根据无标签样本的认知变化是否大于平均的认知变化将无标签集分为两类(类别j∈{0,1}),公式如下:
[0035][0036][0037][0038]
(2)我们将当前模型中样本xi的特征(softmax层之前的输出)表示为的特征(softmax层之前的输出)表示为因此,我们可以用余弦相似度cos(
·
)来计算计算样本xi在当前模型上的特征输出与在训练好的局部模型上的特征输出之间的距离d
loc
,以及与训练好的全局模型ω
r-1
中的特征输出的差距d
glo
,可以表示为如下公式:
[0039][0040][0041]
(3)最后,我们在下列公式中定义了对齐损失,将小ev样本的模型的决策边界与局部模型对齐,而将大ev样本的决策边界靠近为全局模型。
[0042][0043]
其中τ表示一个温度参数。此外,如果那么d
*
(xi)表示d
loc
(xi),否则表示dglo
(xi)。对于每次随机梯度下降sgd,我们从无标签集中随机抽取数据来计算对齐损失,批量大小等于当前训练的标签数据。因此,更新后的损失函数如下。
[0044][0045]
其中μ是一个超参数,用于控制对齐损失的权重,其中l
class
表示基础的cross-entropy分类损失。因此,客户端目标函数转换为如下:
[0046][0047]
步骤2:客户端k完成本地训练后,上传本地更新至服务器;服务器按下式执行聚合获得新一轮的全局模型ωr,并下发给所有客户端。
[0048][0049]
其中n表示所有客户端的总标签集大小,表示客户端k本地训练的标签集大小。
[0050]
步骤3:客户端收到新的全局模型开始下一轮本地训练前,客户端基于步骤(1)记录的信息统计连续轮内模型预测是否一致并计算累计变化量
[0051]
步骤4:客户端根据累计变化量的大小和标注预算贪心地从无标签集中挑选出数据并标注记为获得更新的标签集和无标签集公式如下:
[0052][0053][0054]
步骤5:采样完成后,客户端利用捕获的认知信息暂时得将模型过度自信或相对简单的零波动样本移到休眠集获得更新后的休眠集和无标签集。可以表示为如下公式:
[0055][0056][0057]
只有当无标签集的规模小于给定的唤醒阈值时,我们将从中随机唤醒ta部分冻结数据到无标签池中,如和
[0058]
零波动的产生有两个原因:其一,本地模型过于自信,始终给出相同的预测。其二,样本相对简单,不需要浪费注解预算。我们的方法让模型暂时不看这些过于自信的样本,直到未标注集用完,以减少时间消耗,缓解模型的过度自信。
[0059]
步骤6:完整联邦主动学习重复步骤(1)至步骤(5),直至满足指定轮或者性能阈值。其中,在每次局部训练之前,除了第一轮,每个客户端k都会从无标签集中随机采样一个大小为的子集代替参与后续进程。
再多了解一些

本文用于创业者技术爱好者查询,仅供学习研究,如用于商业用途,请联系技术所有人。

发表评论 共有条评论
用户名: 密码:
验证码: 匿名发表

相关文献