一种残膜回收机防缠绕挑膜装置的制 一种秧草收获机用电力驱动行走机构

一种面向长尾异构数据的联邦学习方法

2022-05-06 06:15:06 来源:中国专利 TAG:

loss去减弱不平衡带来的影响。但是这个方法的性能随着数据异构程度的加深而急剧下降。


技术实现要素:

8.为解决现有技术的不足,实现在满足用户隐私保护、数据安全的同时,提升联邦学习下模型性能,从而提高图像识别效率的目的,本发明采用如下的技术方案:
9.一种面向长尾异构数据的联邦学习方法,包括如下步骤:
10.s1,服务器端随机初始化全局模型w,并将模型参数下发至各个客户端,各个客户端利用收到的模型参数进行本地模型更新,并将更新后的本地模型参数上传至服务器端;
11.s2,服务器端对收到的本地模型参数后进行聚合,得到教师模型和学生模型;
12.s3,服务器端对教师模型进行校准,使教师模型在无偏知识上进行学习,以此教出好的学生模型;
13.s4,通过知识蒸馏,将教师模型的无偏知识传递给学生模型,随后将学生模型下发至各个客户端开始下一轮联邦训练。
14.进一步地,步骤s1中,服务器端初始化全局模型参数w,随机选择参与本轮训练的客户端集合s,并将模型参数广播给参与本轮训练的客户端集合s,s中的每个客户端,均利用收到的全局模型参数w和本地的数据,执行随机梯度下降(sgd),以更新本地模型,客户端k更新得到的本地模型参数为wk,待更新之后,各个客户端将其更新的模型参数发还给服务器端。
15.进一步地,步骤s2包括如下步骤:
16.s21,服务器端对本地模型参数进行平均加权,得到学生模型,计算公式如下:
[0017][0018]
φs(x)=φw(x)
ꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀ
(公式2)
[0019]
其中,|dk|表示第k个客户端拥有的图像数据量,|d|表示所有客户端拥有的图像数据总量,k表示客户端数量,x表示输入图像数据,φw(
·
)表示联邦平均模型的网络,φs(
·
)表示学生模型的网络;
[0020]
s22,服务器端对本地模型参数进行加权聚合,得到教师模型,计算公式如下:
[0021][0022]
其中,φ
t
(
·
)表示教师模型的网络,ek表示赋给客户端k的权重,表示第k个客户端的网络。
[0023]
进一步地,步骤s3中,由于本地模型是在具有不同分布的本地数据上进行训练的,每个本地模型在尾部类上的表现可能不同,因此我们要为在尾部类表现更好的本地模型分配更高的权重,然而,服务器端不知道哪些图像的类是尾部类,并且哪些客户端本地模型在上面表现良好,因此,我们不是给每个客户端一个固定的权重,相反,我们提出基于客户端的权重分配策略,以此来计算每个客户端本地模型的权重ek,最后将ek归一化使其总和等于1,即为最终权重,权重ek的计算公式如下:
[0024]
[0025]
其中,ae∈rc和be表示可被学习的网络参数,rc表示c维向量,t为转置符号,基于客户端的校准就像自注意力机制一样,根据模型的原始输出logits对本地模型计算权重,再将权重乘回原始输出logits。
[0026]
进一步地,步骤s3中,若没有一个客户端本地模型可以很好地处理尾类,那么加权集成得到的教师模型仍偏向于头部类,为解决该问题,我们提出基于类的原始输出logits校准策略,以进一步提升模型在尾部类的性能,设被校准后的模型输出logits为z
cl
,计算公式如下:
[0027]zcl
=az⊙
φ
t
(x) bzꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀ
(公式5)
[0028]
其中,az和bz表示可被学习的网络参数,

表示哈达玛积。
[0029]
进一步地,步骤s3中,上述对logits校准策略有效的前提是本地模型对输入的图像数据提取的表征信息足够好,若客户端本地模型对输入的图像数据的特征提取因为长尾分布而受到严重影响,那么仅仅只对输出logits做校准是不够的,因此,我们需要更新特征提取器来进一步提升模型性能,我们在服务器端利用额外的平衡有标签的图像,构成的平衡有标签数据集在全局模型w上进行微调,得到微调模型因为的数据分布是平衡的,所以微调模型可以获得无偏的特征提取器,然后,我们可以获得对于输入的图像数据为x的微调模型输出logits为其中z
ft
表示微调模型对x的输出,表示微调模型的网络。
[0030]
进一步地,微调模型其中,η表示学习率,表示损失函数,表示求导。
[0031]
进一步地,步骤s3中,z
cl
和z
ft
是从两个不同的层面去校准教师模型,z
cl
是对教师模型输出logits层面进行校准,其模型的特征提取器被固定,然而z
ft
是对特征提取器微调的结果,以此提升模型特征提取能力,为充分结合二者优势,我们提出通过一个校准门控网络对z
cl
和z
ft
做权衡,校准门控网络以集成特征作为输入,经由一个非线性层输出权重,使得每个样本根据自身的特征不同而获得不同的权重,权重计算公式如下:
[0032]
σ=sigmoid(u
t
v)
ꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀ
(公式6)
[0033]
其中,表示集成特征,表示第k个客户端的特征提取器,u∈rd表示可被学习的网络参数,rd表示d维向量,因此,通过校准门控网络的最终校准模型输出logits为z

,计算公式如下:
[0034]z′
=σz
cl
(1-σ)z
ft
ꢀꢀꢀꢀꢀꢀꢀ
(公式7)
[0035]
其中σ∈(0,1)用于权衡z
cl
和z
ft
两个模型输出logits。
[0036]
进一步地,集成校准的整个过程中,所有可被学习的参数都通过在上的交叉熵损失进行更新,损失函数如下:
[0037][0038]
其中,c表示类别数,yj表示输入图像数据的真实标签,j表示y中第j维的值,exp(
·
)表示以自然常数e为底的指数函数,z
′j表示最终校准z

中第j维的值,z
′i表示最终校准z

中第i维的值,z

是一个c维向量,z
′j和z
′i分别表示其中一维的值。
[0039]
进一步地,步骤s4中,通过知识蒸馏,将教师模型的无偏知识传递给学生模型,具体地,为更好地将教师模型的无偏知识教给学生模型,我们使用有标签数据训练和无标签数据蒸馏结合的方式训练学生模型,其损失函数如下:
[0040]
l

=(1-λ)l
ce
λl
kl
ꢀꢀꢀꢀꢀꢀ
(公式9)
[0041]
其中,l
ce
表示学生模型的模型输出logits和图像真实标签ground-truth之间的交叉熵损失,l
kl
表示教师模型和学生模型之间模型输出logits的相对熵(kl,kullback-leibler)散度,通过平衡有标签数据集计算l
ce
,并使用另一个无标签图像构成的无标签数据集计算l
kl
以进一步提高知识蒸馏性能,λ∈[0,1]表示超参数,对l
ce
和l
kl
进行权衡。
[0042]
本发明的优势和有益效果在于:
[0043]
本发明研究了联邦学习中异构数据和长尾分布的联合问题,充分利用客户端本地模型的多样性来处理异构数据问题,并提出了一种新的模型校准策略和门控网络来有效解决长尾问题,进一步提升了联邦学习下的模型性能。
附图说明
[0044]
图1是本发明的方法流程图。
[0045]
图2是本发明中客户端数据分布图。
具体实施方式
[0046]
以下结合附图对本发明的具体实施方式进行详细说明。应当理解的是,此处所描述的具体实施方式仅用于说明和解释本发明,并不用于限制本发明。
[0047]
如图1所示,一种面向长尾异构数据的联邦学习方法,包括以下步骤:
[0048]
步骤一、准备数据集,初始化网络,并将数据集分给各个客户端,进行模型更新。
[0049]
步骤1.1、使用的数据集为cifar-10、cifar-100和imagenet_lt。
[0050]
cifar-10数据集共有60000张彩色图像,这些图像尺寸为32*32,分为10个类:飞机(airplane)、汽车(automobile)、鸟类(bird)、猫(cat)、鹿(deer)、狗(dog)、蛙类(frog)、马(horse)、船(ship)和卡车(truck)。每类6000张图,这里面有50000张用于训练,构成了5个训练批,每一批10000张图;另外10000用于测试,单独构成一批。测试批的数据里,取自10类中的每一类,每一类随机取1000张。抽剩下的就随机排列组成了训练批。注意一个训练批中的各类图像并不一定数量相同,总的来看训练批,每一类都有5000张图。另外,在cifar-10中每类随机挑选出100张图片构成额外的平衡数据集使用cifar-100作为无标签数据进行知识蒸馏。
[0051]
cifar-100有100个类,每个类包含600个图像。每类各有500个训练图像和100个测试图像。cifar-100中的100个类被分成20个超类。每个图像都带有一个“精细”标签(它所属的类)和一个“粗糙”标签(它所属的超类)。在cifar-100中每类随机挑选出10张图片构成额外的平衡数据集使用下采样的imagenet(图像尺寸32*32)作为无标签数据进行知识蒸馏。
[0052]
imagenet-lt是大型图像分类数据集,是imagenet的长尾版本,它通过对服从
pareto分布的子集根据α=6进行采样。它包含1000个类别的115800张图像,最大和最小类别分别包含1280和5张图像。我们从平衡验证数据中获得数据集使用过采样的cifar100(图像尺寸224*224)作为无标签数据进行知识蒸馏。
[0053]
上述三个数据集根据迪利克雷分布中异构程度η=0.1分给不同客户端,作为其本地数据,在cifar-10上的数据分布图如图2所示。
[0054]
步骤1.2、搭建联邦学习环境并初始化网络。
[0055]
在cifar-10-lt和cifar-100-lt上使用resnet-8网络进行训练,在imagenet-lt使用resnet-50网络。我们的所有实验均由pytorch在两个nvidia geforce rtx 3080gpu上运行。一般情况下,我们设计20个客户端,共训练200轮,每轮选择40%的客户端参与联邦训练。对于客户端训练,批大小设置为128,学习率为0.1,sgd作为优化器。对于服务器端的全局模型训练,我们将校准epoch设置为100,蒸馏epoch设置为100,并且使用学习率为0.001的adam优化器进行知识蒸馏。
[0056]
步骤1.3、客户端模型更新。
[0057]
服务器端初始化全局模型参数w,随机选择参与本轮训练的客户端集合s并将模型参数广播给参与本轮训练的客户端集合s。s中的每个客户端均利用收到的全局模型参数w和本地的数据执行随机梯度下降(sgd)以更新它们的模型,设客户端k更新得到的模型参数为wk。待更新之后,各个客户端将其更新的模型参数发还给服务器端。
[0058]
步骤二、服务器端得到教师模型和学生模型,具体流程包括以下子步骤:
[0059]
步骤2.1、首先服务器端对收到的模型参数进行平均加权,得到学生模型,计算公式如下:
[0060][0061]
φs(x)=φw(x)
ꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀ
(公式2)
[0062]
步骤2.2、然后服务器端对从客户端得到的模型参数进行加权聚合,得到教师模型,计算公式如下:
[0063][0064]
其中,ek是赋给客户端k的权重。
[0065]
步骤三、服务器端对步骤d中得到的教师模型进行校准。
[0066]
由于全局数据呈现长尾分布,故各个客户端得到的模型也会偏向于头部类而在尾部类上表现很差,那么由各个客户端模型加权得到的教师模型亦是偏向于头部类而忽视尾部类。由于教师模型偏向于头部类,其包含的知识是极具偏向性的,这样会导致教出来的学生模型也具有偏向性,会严重损害模型性能。所以我们要对教师模型进行校准,获得无偏的教师模型,然后将无偏向性的知识教给学生模型。具体方法包含以下子步骤:
[0067]
步骤3.1、由于本地模型是在具有不同分布的本地数据上进行训练的,每个本地模型在尾部类上的表现可能不同,因此我们要为在尾部类表现更好的本地模型分配更高的权重。然而,服务器不知道哪些类是尾部类并且哪些本地模型在上面表现良好,因此,我们不是给每个客户端一个固定的权重,相反我们提出了基于客户端的权重分配策略,以此来计算对每个客户端本地模型的权重ek,并将ek归一化使其总和等于1,即为最终权重。计算公式如下:
[0068][0069]
其中,ae∈rc和be是可被学习的参数。基于客户端的校准就像自注意力机制一样,根据原始logits对本地模型计算权重,然后权重乘回原始logits。
[0070]
步骤3.2、若没有一个本地模型可以很好地处理尾类,加权集成得到的教师模型可能仍偏向于头部类。为解决这个问题,我们提出基于类的logits校准策略以进一步提升模型在尾部类的性能。设被校准后的模型输出logits为z
cl
,计算公式如下:
[0071]zcl
=az⊙
φ
t
(x) bzꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀ
(公式5)
[0072]
其中,az和bz是可被学习的网络参数,

表示哈达玛积。
[0073]
步骤3.3、上述对logits校准策略有效的前提是本地模型对输入数据提取的表征信息足够好,如果本地模型对数据的特征提取因为长尾分布而受到严重影响,那么仅仅只对logits做校准是不够的。因此,我们需要更新特征提取器来进一步提升模型性能。我们在服务器端利用额外的平衡有标签数据在全局模型w上进行微调,得到微调模型因为是平衡的,所以微调模型可以获得无偏的特征提取器。然后,我们可以获得对于输入x的微调logits为
[0074]
步骤3.4、通过上述步骤,可以看出z
cl
和z
ft
是从两个不同的层面去校准教师模型。z
cl
是对教师模型输出logits层面进行校准,其模型的特征提取器被固定,然而z
ft
是对特征提取器微调的结果,以此提升模型特征提取能力。为充分结合二者优势,我们提出了一个校准门控网络来对z
cl
和z
ft
做权衡。该门控网络以集成特征作为输入,经由一个非线性层输出权重,使得每个样本根据自身的特征不同而获得不同的权重。权重计算公式如下:
[0075]
σ=sigmoid(u
t
v)
ꢀꢀꢀꢀꢀꢀꢀ
(公式6)
[0076]
其中,是集成特征,u∈rd是可被学习的网络参数。因此,通过校准门控网络的最终校准logits为z

,计算公式如下:
[0077]z′
=σz
cl
(1-σ)z
ft
ꢀꢀꢀꢀꢀꢀꢀꢀ
(公式7)
[0078]
其中σ∈(0,1)用于权衡两个logits。
[0079]
步骤3.5、集成校准的整个过程中所有可被学习的参数都通过在上的交叉熵损失进行更新,损失函数如下:
[0080][0081]
步骤四、使用知识蒸馏将教师模型的无偏知识传递给学生模型。
[0082]
为了更好地将教师模型(即校准集成模型)的无偏知识教给学生模型(即全局模型),我们使用有标签数据训练和无标签数据蒸馏结合的方式训练学生模型,其损失函数包含两个部分:(1)l
ce
是学生模型的logits和ground-truth之间的交叉熵损失;(2)l
kl
是教师模型和学生模型之间logits的kullback-leibler(kl)散度。我们使用来计算l
ce
,并使用另一个无标签数据集来计算l
kl
以进一步提高知识蒸馏性能。最终损失函数由超参数λ∈[0,1]权衡:
[0083]
l

=(1-λ)l
ce
λl
kl
ꢀꢀꢀꢀꢀꢀꢀꢀ
(公式9)
[0084]
表1为本发明与其他几种联邦学习方法在cifar-10-lt和cifar-100-lt数据集上不平衡比例为100、50和10的精度(%)比对结果。表中加粗结果为各指标的最优结果。
[0085]
从表1的结果中可以看出,本发明方法可以解决在联邦学习中长尾分布和异构数据的联合问题,本发明方法在所有不平衡程度上均实现了最高的测试精度。
[0086]
表1
[0087][0088]
表2为本发明与几种联邦学习方法在imagenet-lt数据集上的精度(%)对比结果。表中加粗结果为各指标的最优结果。
[0089]
表2中分别对比了几种方法在三种分类上的精度:头部类(样本数量超过100)、中间类(样本数量在20与100之间)和尾部类(样本数量少于20)。与其他方法相比,本发明方法在取得了最好的结果。同时,本发明方法在尾部类上的准确率达到了15.91%,本发明方法解决了在联邦学习中长尾分布和异构数据的联合问题,在提升模型总体性能的同时亦极大提升了模型对尾部类的性能。
[0090]
表2
international conference on learning representations.),通过对高质量的全局模型进行采样并通过贝叶斯模型对其进行组合,从贝叶斯推理的角度出发实现了强大的聚合;
[0100]
fed-focal loss对应sarkar,d等人提出的方法(sarkar,d.;narang,a.;and rai,s.2020.fed-focal loss for imbalanced data classification in federated learning.arxiv preprint arxiv:2011.06283.);
[0101]
ratio loss对应wang,l等人提出的方法(wang,l.;xu,s.;wang,x.;and zhu,q.2021a.addressing class imbalance in federated learning.in aaai conference on artificial intelligence,10165

10173.),其实现了联邦学习中数据不平衡问题不透明化的监测,并且提出了一个全新的损失函数ratio loss去减弱不平衡带来的影响;
[0102]
crt、τ-norm、lws对应kang,b等人提出的方法(kang,b.;xie,s.;rohrbach,m.;yan,z.;gordo,a.;feng,j.;and kalantidis,y.2020.decoupling representation and classifier for long-tailed recognition.in international conference on learning representations.),表明数据不均衡并不影响学习输入数据高质量的表征,作者表明仅调整分类器也可以实现强大的长尾识别能力。
[0103]
以上实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述实施例所记载的技术方案进行修改,或者对其中部分或者全部技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本发明实施例技术方案的范围。
再多了解一些

本文用于企业家、创业者技术爱好者查询,结果仅供参考。

发表评论 共有条评论
用户名: 密码:
验证码: 匿名发表

相关文献