一种残膜回收机防缠绕挑膜装置的制 一种秧草收获机用电力驱动行走机构

一种机器学习方法、装置和系统与流程

2022-06-01 14:56:36 来源:中国专利 TAG:


1.本技术涉及机器学习技术领域,尤其涉及一种机器学习方法、装置和系统。


背景技术:

2.随着人工智能(artificial intelligence,ai)技术的发展,神经网络与通信系统相结合的研究受到广泛关注。比如,在智能通信、物联网、车联网、智慧城市等场景中,都可以通过机器学习来提升网络的灵活性。
3.示例性的,以智能通信为例。通过机器学习,计算节点(如基站)可以根据各个本地节点(如手机)上传的数据,结合其中设置的神经网络模型进行机器学习。基站可以将通过机器学习获取的符合预设条件的模型下发给各个手机,以使得各个手机能够根据该模型,对各自的通信过程进行调整,实现智能通信。
4.可以看到,在基于神经网络与通信系统的结合,并进行机器学习的过程中,各个本地节点需要将其收集的数据全部上传给计算节点,以使得计算节点能够根据这些数据进行训练。根据场景的不同,被传输的数据量可能非常大。由于各个本地节点都需要将数据传输给计算节点,这就使得本地节点和计算节点之间的数据传输压力较大。此外,由于数据直接被本地节点传输给中心节点,其中可能会包括与用户隐私信息相关的内容,这也会使得信息隐私信息的不安全。


技术实现要素:

5.本技术实施例提供一种机器学习方法、装置和系统,可以显著降低机器学习过程中数据传输的数量,由此降低数据传输在整个学习过程中的耗时占比,从而有效提高机器学习的效率。
6.为了达到上述目的,本技术实施例采用如下技术方案:
7.第一方面,提供一种机器学习方法,方法应用于子节点,子节点中设置有二值化神经网络模型bnn,方法包括:子节点对采集获取的本地数据集,进行基于bnn的机器学习,获取与本地数据集对应的本地模型参数。子节点向中心节点发送第一消息,第一消息包括本地模型参数。
8.基于该方案,提供了一种将bnn与分布式的机器学习架构结合的方案。在该示例中,可以在本地进行基于bnn的二值化机器学习,由此获取的本地神经网络模型的参数可以为二值化的参数。将二值化的模型参数发送给中心节点,相比于直接将高精度的神经网络模型或者模型参数发送给中心节点,能够显著地降低对于数据传输带宽的要求,同时由于要传输数据显著变少,因此传输耗时也相应降低。可以理解的是,在进行数据传输过程中,子节点不会进行机器学习,因此,降低数据传输耗时就能够显著增加在整个学习过程中,进行机器学习的时间占比,由此提升学习效率。
9.在一种可能的设计中,第一消息中包括的本地模型参数是二值化的本地模型参数。基于该方案,给出了本地模型参数在被传输过程中的形式限定。在该示例中,被传输的
本地模型参数为二值化的参数,也就是包括 1和-1两种参数。显而易见的,每个参数对应到1比特带宽即可。而相较于传统的fl架构下的数据传输,由于要进行基于高精度的数据传输,而一个高精度元素对应到的传输带宽可能是16比特或更多,这样要传输一个包括多个高精度元素的模型参数矩阵时,就需要耗费更多的传输资源。因此,通过该示例提供的方案,能够有效地降低数据传输资源的需求,传输时间更短,由此提升系统的学习效率。
10.在一种可能的设计中,该方法还包括:子节点从中心节点接收融合参数,融合参数是中心节点根据本地模型参数融合获取的。子节点根据融合参数对本地模型参数进行更新,以获取更新后的本地模型参数。基于该方案,给出了一种子节点通过接收信息实现模型更新的方法。可以理解的是,在分布式的学习架构中,在中心节点可以进行数据的汇总处理,比如,可以获取各个子节点分别传输的模型参数,并对这些模型参数进行融合,由此获取与各个子节点对应情况都适配的具有较高适应性的模型参数(比如上述融合参数)。子节点可以从中心节点获取该融合参数,这样,就可以根据该融合参数,结合本地参数进行本地融合,由此实现对于本地神经网络模型的更新。
11.在一种可能的设计中,融合参数是二值化的模型参数,或者,融合参数是高精度的模型参数。基于该方案,给出了融合参数的不同形式。比如,在一些实现方式中,该融合参数可以是二值化的模型参数。又如,在另一些实现方式中,该融合参数可以是高精度的模型参数。可以理解的是,在本技术中,子节点可以将二值化的模型参数传输给中心节点。中心节点可以根据多个二值化的模型参数进行融合(比如根据权值进行平均加权),由此获取对应的融合参数。以融合处理为加权平均为例,显而易见的,通过融合处理之后的模型参数并非只包括 1以及-1的元素,也就是说,在加权平均之后获取的融合模型应当为高精度的模型参数。在本技术中,可以根据不同场景,选择在下发(或者下传)融合参数时,采用直接获取的高精度参数进行下传,以便于子节点能够根据该高精度融合参数更加准确地更新本地模型。或者采用对高精度参数进行二值化处理后的二值化融合参数进行下传,以便于数据下传的速率更快,子节点也可以更快地完成本地模型的更新。对于子节点而言,可以在接收二值化的融合参数时,根据与二值化融合参数对应的本地融合方法进行本地融合,在接收高精度的融合参数时,根据与高精度融合参数对应的本地融合方法进行本地融合。其具体执行方法请参考实施例中的说明,此处不再赘述。
12.在一种可能的设计中,第一消息还包括:与本地模型参数对应的准确度信息。其中,准确度信息是子节点根据本地模型参数,以及测试数据集,校验获取的。基于该方案,提供了一种第一消息的具体内容。在该示例中,子节点在完成一轮本地学习之后,可以将与该学习结果对应的准确度信息发送给中心节点。以便与中心节点可以根据各个子节点上报的准确度信息确定系统准确度,进而由此确定下发二值化融合参数还是高精度融合参数。需要说明的是,在本技术的另一些实现方式中,准确度信息还可以是子节点根据本地模型参数,以及验证数据集,校验获取的。
13.在一种可能的设计中,该方法还包括:子节点根据更新后的本地模型参数,基于bnn继续进行机器学习。基于该方案,提供了一种子节点根据融合参数更新本地模型之后的方法示例。在该示例中,子节点可以在更新本地模型参数之后,基于现有的数据集,或者结合新增的数据集对该本地模型继续进行第二轮或后续轮次的学习,并重复上述示例中的方法,直到学习结果收敛为止,完成机器学习。需要说明的是,在本技术的另一些实现方式中,
无论学习结果是否收敛,更新后的本地模型都可以用于对子节点当前的业务进行指导,比如预测数据走向等。
14.第二方面,提供一种机器学习方法,方法应用于中心节点,方法包括:中心节点接收分别来自n个子节点的n个第一消息,第一消息包括本地模型参数,本地模型参数是二值化的本地模型参数。n为大于或等于1的整数。中心节点对n个第一消息包括的本地模型参数进行融合,获取融合参数。中心节点向m个子节点发送第二消息,第二消息包括融合参数,m为大于或等于1的正整数。
15.基于该方案,中心节点可以接收来自多个子节点的本地模型参数,并基于这些本地模型参数进行融合,由此获取具有更强适应性的融合模型。中心节点可以将该融合模型下发给各个子节点,以便于子节点可以根据该融合模型进行本地融合完成一轮学习。相较于现有的fl框架中的分布式架构,在本示例中,中心节点接收到的来自子节点的模型参数可以是二值化的模型参数。可以理解的是,二值化的模型参数的数据量要显著的小于普通的模型参数(如高精度的模型参数)的数据量,因此该上传的过程更加高效。中心节点将各个本地模型参数融合后,获取的融合参数,能够适配各个子节点对应的数据集类型,因此具有更加准确和适配性更强的特征。需要说明的是,在本技术是一些实现方式中,n个子节点中有可能存在一些参考用子节点,这些子节点可以用于提供本地模型参数,但是并不需要来自中心节点的融合参数。而需要根据融合参数对本地模型进行更新的节点也可能不包括在n个子节点中。因此,在一些实现方式中,m个子节点中也可以包括n个子节点中没有的子节点,或者,m个子节点可以是n个子节点中第一部分。具体的m个子节点的确定,可以根据实际实施过程中灵活配置。
16.在一种可能的设计中,中心节点对n个本地模型参数进行融合,获取融合参数,包括:中心节点对n个本地模型参数,进行加权平均,获取融合参数。基于该方案,提供了一种获取融合参数的方案。在该示例中,中心节点可以通过简单的加权平均,处理n个本地参数模型。其中,加权平均中的权重可以根据本地参数模型在进行本地训练过程中,输入数据集的大小确定。中心节点可以从各个子节点获取进行本轮学习过程中使用数据集的大小,也可以从其他节点中获取各个子节点在进行本地学习过程中所使用的数据集的大小。当然,在本技术的其他一些实现方式中,中心节点还可以结合其他因素,调整权重。比如,对于一些使用较为频繁的子节点,其权重可以适当增加,而对于一些神经网络模型使用较少的子节点,其对应权重可以适当减小。
17.在一种可能的设计中,第二消息包括的融合参数是高精度的融合参数,或者,第二消息包括的融合参数是二值化的融合参数。基于该方案,提供了一种中心节点下发融合参数的方案示例。在该示例中,中心节点可以将融合处理之后获取的高精度参数直接通过第二消息下发给子节点。中心节点也可以将融合处理之后的高精度参数,通过二值化处理再通过第二消息下发给子节点。可以理解的是,在需要提升数据传输速率时,可以通过下发二值化融合参数实现,在需要提升准确度时,可以通过下发高精度融合参数实现。
18.在一种可能的设计中,在中心节点向m个子节点发送第二消息之前,方法还包括:中心节点根据n个第一消息,确定系统准确度信息,中心节点根据系统准确度信息,确定第一消息中包括的融合参数为高精度的融合参数或者二值化的融合参数。基于该方案,提供了一种中心节点调整下发高精度融合参数和二值化融合参数的机制。在该示例中,中心节
点可以根据系统准确度信息,确定下发高精度融合参数或者二值化融合参数。比如,在系统准确度较低时,可以下发二值化融合参数,又如,在系统准确度较高时,可以下发高精度融合参数。其中,系统准确度可以根据各个子节点的准确度确定,也可以是中心节点根据各个子节点的模型参数自发地进行校验确定的。
19.在一种可能的设计中,第一消息还包括:准确度信息。准确度信息与第一消息中包括的本地模型参数在对应子节点处校验获取的准确度对应。中心节点根据n个第一消息,确定系统准确度信息,包括:中心节点根据n个消息中包括的准确度信息,确定系统准确度信息。基于该方案,提供了一种中心节点确定系统准确度信息的方法。在该示例中,各个子节点可以将本轮学习获取的模型参数校验的准确度发送给中心节点,中心节点就可以根据各个子节点上传的准确度,确定系统准确度,进而据此调整下发融合参数的形式。
20.在一种可能的设计中,在系统准确度信息小于或等于第一阈值时,中心节点确定融合参数为二值化的融合参数。在系统准确度信息大于或等于第二阈值时,中心节点确定融合参数为高精度的融合参数。基于该方案,提供了一种具体的中心节点确定下发融合参数的形式的示例。在该示例中,中心节点在确定系统准确度小于或等于第一阈值时,则认为当前系统中的学习处于初步阶段,模型参数还有大量的调整空间,因此不需要进行精度较高的数据传输,此时应当数据传输效率的提升为主。因此,在系统准确度小于或等于第一阈值时,则中心节点可以下发二值化的融合参数,由此提升数据传输速率。对应的,中心节点在确定系统准确度大于或等于第一阈值时,则认为当前系统中的学习已经接近收敛,模型参数的调整空间较小,因此需要进行精度较高的数据传输。因此,在系统准确度大于或等于第一阈值时,则中心节点可以下发高精度的融合参数,由此提升模型参数的准确度。其中,第一阈值与第二阈值可以是预先设置的,在不同的实现方式中,第一阈值和第二阈值可以相同,也可以不同。
21.在一种可能的设计中,在中心节点向m个子节点发送第二消息,包括:在迭代轮数小于或等于第三阈值时,中心节点向m个子节点发送包括二值化的融合参数的第二消息。在迭代轮数大于或等于第四阈值时,中心节点向m个子节点发送包括高精度的融合参数的第二消息。基于该方案,提供了又一种中心节点确定下发融合参数的形式的机制。在该示例中,中心节点可以根据迭代轮数确定下发融合参数的形式。比如,在迭代轮数较少时,即小于或等于第三阈值时,则中心节点可以认为当前状态下应当以数据传输效率的提升为主,因此可以选择下发二值化的融合参数,以提升数据传输速率。对应的,在迭代轮数较多时,即大于或等于第四阈值时,则中心节点可以认为当前状态下应当以准确度为主,因此可以选择下发高精度的融合参数,以提升本地融合过程中的准确度。
22.在一种可能的设计中,中心节点通过广播,向m个子节点发送第二消息。基于该方案,提供了一种中心节点下发第二消息的方式。在该示例中,中心节点可以通过广播的形式进行第二消息的下发,而不需对各个子节点分别进行下发。可以理解的是对于各个子节点的数据下发内容相近,因此,可以通过广播的形式,同时将该数据下发给各个子节点。同时,由于传输的是二值化的融合参数或者高精度的融合参数,因此,广播的传输形式也不会对信息安全性有影响。
23.第三方面,提供一种机器学习装置,该装置可以应用于子节点,子节点中设置有二值化神经网络模型bnn,该装置包括:获取单元,用于对采集获取的本地数据集,进行基于
bnn的机器学习,获取与本地数据集对应的本地模型参数。发送单元,用于向中心节点发送第一消息,第一消息包括本地模型参数。
24.在一种可能的设计中,第一消息中包括的本地模型参数是二值化的本地模型参数。
25.在一种可能的设计中,该装置还包括:接收单元,用于从中心节点接收融合参数,融合参数是中心节点根据本地模型参数融合获取的。融合单元,用于对根据融合参数以及本地模型参数进行融合,以获取更新后的本地模型参数。
26.在一种可能的设计中,融合参数是二值化的模型参数,或者,融合参数是高精度的模型参数。
27.在一种可能的设计中,第一消息还包括:与本地模型参数对应的准确度信息。获取单元,还用于根据本地模型参数,以及测试数据集,校验获取准确度信息。
28.在一种可能的设计中,该装置还包括:学习单元,用于根据更新后的本地模型参数,基于bnn继续进行机器学习。
29.第四方面,提供一种机器学习装置,该装置应用于中心节点,装置包括:接收单元,用于接收分别来自n个子节点的n个第一消息,第一消息包括本地模型参数,本地模型参数是二值化的本地模型参数。n为大于或等于1的整数。融合单元,用于对n个第一消息包括的本地模型参数进行融合,获取融合参数。发送单元,用于向m个子节点发送第二消息,第二消息包括融合参数,m为大于或等于1的正整数。
30.在一种可能的设计中,融合单元,具体用于对n个本地模型参数,进行加权平均,获取融合参数。
31.在一种可能的设计中,第二消息包括的融合参数是高精度的融合参数,或者,第二消息包括的融合参数是二值化的融合参数。
32.在一种可能的设计中,该装置还包括:确定单元,用于根据n个第一消息,确定系统准确度信息,中心节点根据系统准确度信息,确定第一消息中包括的融合参数为高精度的融合参数或者二值化的融合参数。
33.在一种可能的设计中,第一消息还包括:准确度信息。准确度信息与第一消息中包括的本地模型参数在对应子节点处校验获取的准确度对应。确定单元,具体用于根据n个消息中包括的准确度信息,确定系统准确度信息。
34.在一种可能的设计中,确定单元,用于在系统准确度信息小于或等于第一阈值时,确定融合参数为二值化的融合参数。确定单元,还用于在系统准确度信息大于或等于第二阈值时,确定融合参数为高精度的融合参数。
35.在一种可能的设计中,发送单元,用于在迭代轮数小于或等于第三阈值时,向m个子节点发送包括二值化的融合参数的第二消息。发送单元,还用于在迭代轮数大于或等于第四阈值时,向m个子节点发送包括高精度的融合参数的第二消息。
36.在一种可能的设计中,发送单元,具体用于通过广播,向m个子节点发送第二消息。
37.第五方面,提供一种子节点,该子节点可以包括一个或多个处理器和一个或多个存储器。一个或多个存储器与一个或多个处理器耦合,一个或多个存储器存储有计算机指令。当一个或多个处理器执行计算机指令时,使得子节点执行如第一方面及其可能的设计中任一项所述的机器学习方法。
38.第六方面,提供一种中心节点,该中心节点可以包括一个或多个处理器和一个或多个存储器。一个或多个存储器与一个或多个处理器耦合,一个或多个存储器存储有计算机指令。当一个或多个处理器执行计算机指令时,使得中心节点执行如第二方面及其可能的设计中任一项所述的机器学习方法。
39.第七方面,提供一种机器学习系统,机器学习系统包括一个或多个第五方面提供的子节点,以及一个或多个如第六方面提供的中心节点。
40.第八方面,提供一种芯片系统,芯片系统包括接口电路和处理器;接口电路和处理器通过线路互联;接口电路用于从存储器接收信号,并向处理器发送信号,信号包括存储器中存储的计算机指令;当处理器执行计算机指令时,芯片系统执行如上述第一方面以及各种可能的设计中任一种所述的机器学习方法,或者,执行如上述第二方面以及各种可能的设计中任一种所述的机器学习方法。
41.第九方面,提供一种计算机可读存储介质,计算机可读存储介质包括计算机指令,当计算机指令运行时,执行如上述第一方面以及各种可能的设计中任一种所述的机器学习方法,或者,执行如上述第二方面以及各种可能的设计中任一种所述的机器学习方法。
42.第十方面,提供一种计算机程序产品,计算机程序产品中包括指令,当计算机程序产品在计算机上运行时,使得计算机可以根据指令执行如上述第一方面以及各种可能的设计中任一种所述的机器学习方法,或者,执行如上述第二方面以及各种可能的设计中任一种所述的机器学习方法。
43.应当理解的是,上述第三方面,第四方面,第五方面,第六方面,第七方面,第八方面,第九方面以及第十方面提供的技术方案,其技术特征均可对应到第一方面及其可能的设计中,或者第二方面及其可能的设计中提供的机器学习方法,因此能够达到的有益效果类似,此处不再赘述。
附图说明
44.图1为一种机器学习在通信过程中的实现示意图;
45.图2为一种fl架构的工作示意图;
46.图3为一种bnn与基于高精度参数的普通神经网络的对比示意图;
47.图4为本技术实施例提供的一种机器学习系统的组成;
48.图5为本技术实施例提供的又一种机器学习系统的组成;
49.图6为本技术实施例提供的一种机器学习系统的工作逻辑示意图;
50.图7为本技术实施例提供的又一种机器学习系统的工作逻辑示意图;
51.图8为本技术实施例提供的又一种机器学习系统的工作逻辑示意图;
52.图9为本技术实施例提供的一种机器学习方法的逻辑示意图;
53.图10为本技术实施例提供的一种仿真结果的对比示意图;
54.图11为本技术实施例提供的又一种仿真结果的对比示意图;
55.图12为本技术实施例提供的又一种仿真结果的对比示意图;
56.图13为本技术实施例提供的一种机器学习装置的组成示意图;
57.图14为本技术实施例提供的又一种机器学习装置的组成示意图;
58.图15为本技术实施例提供的一种子节点的组成示意图;
59.图16为本技术实施例提供的一种芯片系统的组成示意图;
60.图17为本技术实施例提供的一种中心节点的组成示意图;
61.图18为本技术实施例提供的又一种芯片系统的组成示意图。
具体实施方式
62.在本技术实施例中,“示例性的”或者“例如”等词用于表示作例子、例证或说明。本技术实施例中被描述为“示例性的”或者“例如”的任何实施例或设计方案不应被解释为比其它实施例或设计方案更优选或更具优势。确切而言,使用“示例性的”或者“例如”等词旨在以具体方式呈现相关概念。
63.在本技术的实施例中,术语“第一”、“第二”仅用于描述目的,而不能理解为指示或暗示相对重要性或者隐含指明所指示的技术特征的数量。由此,限定有“第一”、“第二”的特征可以明示或者隐含地包括一个或者更多个该特征。在本技术的描述中,除非另有说明,“多个”的含义是两个或两个以上。
64.本技术中术语“至少一个”的含义是指一个或多个,本技术中术语“多个”的含义是指两个或两个以上,例如,多个第二报文是指两个或两个以上的第二报文。
65.应理解,在本文中对各种所述示例的描述中所使用的术语只是为了描述特定示例,而并非旨在进行限制。如在对各种所述示例的描述和所附权利要求书中所使用的那样,单数形式“一个(“a”,“an”)”和“该”旨在也包括复数形式,除非上下文另外明确地指示。
66.还应理解,本文中所使用的术语“和/或”是指并且涵盖相关联的所列出的项目中的一个或多个项目的任何和全部可能的组合。术语“和/或”,是一种描述关联对象的关联关系,表示可以存在三种关系,例如,a和/或b,可以表示:单独存在a,同时存在a和b,单独存在b这三种情况。另外,本技术中的字符“/”,一般表示前后关联对象是一种“或”的关系。
67.还应理解,在本技术的各个实施例中,各个过程的序号的大小并不意味着执行顺序的先后,各过程的执行顺序应以其功能和内在逻辑确定,而不应对本技术实施例的实施过程构成任何限定。
68.应理解,根据a确定b并不意味着仅仅根据a确定b,还可以根据a和/或其它信息确定b。
69.还应理解,术语“包括”(也称“includes”、“including”、“comprises”和/或“comprising”)当在本说明书中使用时指定存在所陈述的特征、整数、步骤、操作、元素、和/或部件,但是并不排除存在或添加一个或多个其他特征、整数、步骤、操作、元素、部件、和/或其分组。
70.还应理解,术语“如果”可被解释为意指“当...时”(“when”或“upon”)或“响应于确定”或“响应于检测到”。类似地,根据上下文,短语“如果确定...”或“如果检测到[所陈述的条件或事件]”可被解释为意指“在确定...时”或“响应于确定...”或“在检测到[所陈述的条件或事件]时”或“响应于检测到[所陈述的条件或事件]”。
[0071]
应理解,说明书通篇中提到的“一个实施例”、“一实施例”、“一种可能的实现方式”意味着与实施例或实现方式有关的特定特征、结构或特性包括在本技术的至少一个实施例中。因此,在整个说明书各处出现的“在一个实施例中”或“在一实施例中”、“一种可能的实现方式”未必一定指相同的实施例。此外,这些特定的特征、结构或特性可以任意适合的方
式结合在一个或多个实施例中。
[0072]
还应理解,本技术实施例中提到的“连接”,可以是直接连接,也可以是间接连接,可以是有线连接,也可以是无线连接,也就是说,本技术实施例对设备之间的连接方式不作限定。
[0073]
以下对本技术实施例提供的方案进行详细说明。
[0074]
由于基于神经网络的机器学习与通信系统的耦合可以显著提升通信过程的灵活性,因此,神经网络在通信过程中的使用机制对于通信效果影响显著。
[0075]
目前的方案中,本地节点可以采集相关数据,并将这些数据分别上传给计算节点,以便计算节点基于这些数据进行学习,由此获取对应的训练模型。计算节点可以将训练模型下发给各个本地节点,以使得各个本地节点可以根据训练模型,对其在通信过程中的工作进行预测指导。
[0076]
示例性的,结合图1,为一种机器学习在通信过程中的实现示例。其中,以3个本地节点通过计算节点进行机器学习为例。如图1所示,本地节点1可以将采集获取的数据组成的数据集1上传给计算节点。类似的,本地节点2可以将采集获取的数据组成的数据集2上传给计算节点。本地节点3可以将采集获取的数据组成的数据集3上传给计算节点。计算节点可以根据这些数据集(如数据集1-数据集3)进行机器学习。比如,计算节点中可以预设有基础神经网络模型,根据数据集1-数据集3对基础神经网络模型进行迭代学习,优化基础神经网络模型的各个模型参数(如权重和偏置等),获取迭代收敛的模型参数,由此完成一轮机器学习。此后,计算节点可以将迭代收敛的模型参数下发给各个本地节点。比如,将模型参数发送给本地节点1,本地节点2以及本地节点3。需要说明的是,在本示例中,本地节点也可预设有与计算节点中基础神经网络模型相同类型的模型。本地节点可以根据接收到的模型参数,对本地维护的模型进行更新,由此获取经过机器学习后的训练模型。
[0077]
这样,本地节点就可以根据该训练模型,对其工作进行预测指导。比如,在涉及边缘计算的过程中,车联网,自动驾驶,以及对用户输入习惯的预测等场景下,都可以根据上述方案,采用对应的训练模型对相应的参数进行判断预测,由此大幅提升本地节点的工作性能。
[0078]
可以看到,如图1所示的方案中,各个本地节点都需要将其采集的数据集分别发送给计算节点。为了使得机器学习的结果足够精确,一般而言,计算节点所需要收集的数据集的数据量是非常大的,这就使得本地节点在向计算节点传输数据集的过程中,会对二者之间的通信链路造成很大的负担。另外,由于数据集是直接被发送给计算节点的,因此,在数据集中包括用户的一些隐私信息时,就会导致隐私信息被直接暴露在通信链路以及计算节点中,由此造成信息隐私性的隐患。
[0079]
为了解决上述问题,目前,可以采用具有分布式架构的联邦学习(federated learning,fl)架构进行机器学习和通信的结耦,降低数据传输量,同时对信息隐私起到适当的保护。
[0080]
在fl架构中,可以设置有中心节点(或称为中心服务器)以及子节点。其中,根据任务和数据分布不同,子节点数可以为几个到几千个不等。对于每一个子节点,本地都可以预设有神经网络模型。各个子节点中的神经网络模型相同。子节点可以获取对应的数据集在神经网络模型中进行学习。在学习过程中,每个子节点都进行若干次迭代(通常为遍历数据
集中所有数据各一次),然后将具有收敛的模型参数的本地模型上传到中心节点。中心节点会将所有发来的本地模型按照各子节点的数据量比例进行加权求均值(该过程也可称为训练模型的融合),由此获取融合后的训练模型。接着,中心节点可以把得到的训练模型下发给所有子节点,以便子节点可以根据融合后的训练模型继续根据新的数据集进行训练,或者,直接用于进行相关场景的计算以及预测。需要说明都是,在不同的实现中,子节点与中心节点之间的训练模型的传输,可以是直接传输训练模型的所有数据,也可以是只对训练模型的参数进行传输即可。
[0081]
示例性的,结合图2,为一种fl架构的工作示意。其中,以该架构中包括3个子节点(如本地节点1,本地节点2以及本地节点3),1个中心节点为例。对于本地节点1,可以采集相关数据,以获取数据集1。在本地节点1中可以存储有与数据集1对应的本地训练模型。本地节点1可以将数据集1输入到本地训练模型中,进行本地训练,由此就可以获取收敛后的本地训练模型参数1(如标识为w1)。类似的,在其他本地节点中,也可进行如上述本地节点1类似的处理,并获取对应的本地训练模型参数2(如标识为w2)和本地训练模型参数3(如标识为w3)。可以理解的是,由于本地训练模型参数的学习获取,与输入的数据集强相关。比如,在输入的数据集不同时,得到的本地训练模型参数也可能不同。3个本地节点可以分别将获取的本地训练模型参数(如w1-w3)发送给中心节点。中心节点可以对获取的w1-w3进行融合,进而得到融合后的训练模型参数(如w0)。中心节点可以将该融合后的训练模型参数分别下发给本地节点1、本地节点2以及本地节点3。各个本地节点就可以根据接收到的w0,对本地训练模型进行更新。
[0082]
可以看到,在fl架构中,数据集在本地就被处理,而不需要被发送给中心节点,这样就可以保证数据集的信息安全得以保证。同时,由于训练模型或者训练模型参数的数据量显著小于数据集本身的数据量,因此,通过传输训练模型,能够有效地降低子节点与中心节点之间的数据传输压力。
[0083]
但是,随着通信技术的发展,fl架构下的训练效率和数据传输量依然不能满足所有场景下的需求。以子节点为手机,中心节点为基站为例。手机和基站之间随时都进行着信息的交互。在fl架构下,虽然只需传输学习训练模型或者训练模型参数,但是,由于一个参数往往需要16比特(bit)或更高的传输带宽,而一组训练模型参数包括多个参数,因此,对于传输带宽的要求依然很高。进一步的,手机在进行训练模型(或训练模型参数)的传输过程中,不会继续进行本地训练,因此,由于训练模型(或训练模型参数)传输时间长就会导致整个fl架构的训练效率较低。
[0084]
为了解决上述问题,本技术实施例提供的方案,能够结合二值化的数据处理方案,以及分布式的神经网络学习方案,达到在提升机器学习效率的同时,降低对于数据传输压力的效果。
[0085]
首先,对本技术所涉及的二值化的数据处理方案进行说明。需要说明的是,本技术中,应用二值化的数据处理方案的神经网络也可称为二值化神经网络(binary neural network,bnn)。
[0086]
需要说明的是,数据在进行传输之前(如在子节点和中心节点之间的传输之前),发出数据的节点需要将数据进行量化才能传输。比如,以子节点将数据发送给中心节点为例。子节点可以将需要发送的数据量化为由0或1组成的序列,然后通过上行数据传输通道
传输该序列。可以理解的是,在数据量化后,其对应的数据与量化前的数据可能不完全对等。量化的过程中,一个需要传输的数据(如称为全精度的数据)对应到量化后的序列位宽越宽,则量化后获取的参数的精度越高。在本技术实施例中,可以将量化位宽较宽的量化数据称为高精度数据,或者高精度参数。在一些实现中,高精度数据或者高精度参数可以是指序列位宽大于或等于32比特的数据。
[0087]
不同于基于高精度参数的神经网络,在bnn中,神经网络的参数由 1和-1的二值化参数组成。在各个子节点中,进行神经网络的学习(或称为训练)过程中,用二值化的参数标识训练模型的参数。因此,在对bnn进行学习时,相较于基于未进行二值化处理的高精度参数的模型学习,能够达到有效减少神经网络用于推演时的计算量、使得学习过程加快收敛的效果。同时,bnn也可以减少储存神经网络参数所需的储存量、进而减少发送整个神经网络所需的通信量。
[0088]
示例性的,图3为一种bnn与基于高精度参数的普通神经网络的对比示意。如图3所示,在普通神经网络中,在进行3次迭代计算的情况下,分别对应的高精度参数可以为w1、w2以及w3。将该过程对应到bnn中,那么,在进行3次迭代计算的情况下,分别对应的二值化参数可以为wb1、wb2以及wb3。对于同样的迭代计算过程,以w1和wb1为例。w1和wb1的对应关系可以为,w1通过二值化转换,即可得到对应的wb1。比如,该二值化转换可以为:对于w1对应计算矩阵中的任意一个元素,如果该元素大于0,则对应wb1矩阵中对应位置的元素记为 1。对应的,如果该元素小于0,则对应wb1矩阵中对应位置的元素记为-1。而对于wb1,可以通过梯度累积获取对应的w1。可以理解的是,在一个典型的bnn学习过程中,可以使用二值化参数进行前向计算和梯度计算,并在对应的高精度参数上累积梯度。当高精度参数上累积了足够大的梯度时,二值化参数就会发生跳变。bnn通过迭代进行多次以上过程,逐步更新参数,并在足够多次迭代后最终收敛。因此,在使用一个学习好的bnn时,只需要使用二值化参数进行推演,最终的输出即为bnn的推演结果。
[0089]
以下结合实际场景,对bnn的学习进行举例说明。以图像分类问题和随机梯度下降方法为例。记bs为每次学习的批数据大小,二值化参数为w
ib
,高精度参数为wi。其中,下角标i表示的是用户index。在进行学习时,每次以不放回的方式从数据集(x1,y1),(x2,y2),

,(xn,yn)中抽取bs个数据,记为(x
′1,y
′1),(x
′2,y
′2),

,(x

bs
,y

bs
),n为数据集中的数据数量。然后,子节点可以计算在该loss计算公式中,l表示使用的神经网络的结构,w
ib
为当前的二值化参数,lossfunc(
·
,
·
)为神经网络的损失函数,loss为最后的损失值。在计算出loss后,子节点进行反向传播算出二值化参数的梯度值,即并将梯度累积到高精度参数上,wi←wi-ηgrad,其中η为学习率,该学习率可以是在进行计算之前预先设置的。最后,如果某个高精度参数在本次迭代中符号改变,则对应的二值化参数符号也翻转(比如,从 1翻转为-1)。在每次上述迭代结束后,从余下的数据集中继续抽取bs个数据(不足则全部抽取),重复上述迭代。直到全部数据均已被抽取过,无法继续抽取数据时,一轮学习结束。在二值神经网络学习过程中,需要多次重复这一过程,直到方法收敛。在本示例中,方法的收敛可以根据loss的计算结果与前1次或前几次的计算结果相对比确定。比如,在本次loss计算后,如果本次loss的结算结果与前3次的计算结果差值均在预设的范围之内,那么认为方法收敛。
[0090]
可以看到,bnn相较于普通的神经网络,具有快速收敛,数据传输量小的特点。但
是,目前尚未有将bnn使用在基于分布式机器学习系统的方案。而如果直接将bnn应用在现有的基于分布式学习的系统(如fl架构)中,由于所有数据传输都是以二值化的形式进行的,因此会导致整个系统的学习准确度过低而无法使用的情况。
[0091]
在本技术实施例提供的机器学习方法中,可以将上述bnn使用在诸如fl等基于分布式机器学习系统的学习框架下,并根据本技术实施例提供的机器学习方法,使得整个fl框架下的机器学习系统的数据传输压力得到显著缓解,同时结合不同场景的需求,提升整个机器学习系统下的学习效率。需要说明的是,在本技术中,各个子节点中的本地训练都可以进行基于二值化的神经网络模型进行本地训练。在不同的使用场景下,该神经网络模型可以灵活选取,比如,该神经网络模型可以具有全连接网络、卷积神经网络等网络结构。
[0092]
本技术实施例提供的机器学习方法,可以应用于包括3g/4g/5g/6g,或者卫星通信等无线通信系统中。
[0093]
其中,无线通信系统通常由小区组成,每个小区包含一个基站(base station,bs),基站向多个移动台(mobile station,ms)提供通信服务。其中基站包含bbu(baseband unit,中文:基带单元)和rru(remote radio unit,中文:远端射频单元)。bbu和rru可以放置在不同的地方,例如:rru拉远,放置于高话务量的区域,bbu放置于中心机房。bbu和rru也可以放置在同一机房。bbu和rru也可以为一个机架下的不同部件。
[0094]
需要说明的是,本发明方案提及的无线通信系统包括但不限于:窄带物联网系统(narrow band-internet of things,nb-iot)、全球移动通信系统(global system for mobile communications,gsm)、增强型数据速率gsm演进系统(enhanced data rate for gsm evolution,edge)、宽带码分多址系统(wideband code division multiple access,wcdma)、码分多址2000系统(code division multiple access,cdma2000)、时分同步码分多址系统(time division-synchronization code division multiple access,td-scdma),长期演进系统(long term evolution,lte)以及下一代5g移动通信系统的三大应用场景embb,urllc和emtc。
[0095]
在本示例中,基站是一种部署在无线接入网中为ms提供无线通信功能的装置。述基站可以包括各种形式的宏基站,微基站(也称为小站),中继站,接入点等。在采用不同的无线接入技术的系统中,具备基站功能的设备的名称可能会有所不同,例如,在lte系统中,称为演进的节点b(evolved nodeb,enb或者enodeb),在第三代(3rd generation,3g)系统中,称为节点b(node b)等。为方便描述,本技术所有实施例中,上述为ms提供无线通信功能的装置统称为网络设备或基站或bs。
[0096]
本发明方案中所涉及到的ms可以包括各种具有无线通信功能的手持设备、车载设备、可穿戴设备、计算设备或连接到无线调制解调器的其它处理设备。所述ms也可以称为终端(terminal),还ms可以是用户单元(subscriber unit)、蜂窝电话(cellular phone)、智能手机(smart phone)、无线数据卡、个人数字助理(personal digital assistant,pda)电脑、平板型电脑、无线调制解调器(modem)、手持设备(handset)、膝上型电脑(laptop computer)、机器类型通信(machine type communication,mtc)终端等。
[0097]
请参考图4,为本技术实施例提供的一种机器学习系统的组成。如图4所示,该机器学习系统中可以包括中心节点,以及多个子节点(如子节点1-子节点n)。其中,子节点也可称为本地节点。中心节点可以与各个子节点通过有线或无线的方式进行通信。比如,中心节
点可以接收各个子节点在上行传输通道中上传的本地训练结果。其中,该本地训练结果可以包括本地训练模型参数,或者本地训练模型本身。又如,中心节点可以将融合后的本地训练模型参数或者本地训练模型本身通过广播或者下行传输通道下发给各个子节点。为了便于对本技术实施例提供的方案进行说明,以下以在子节点和中心节点传输的数据对应于本地训练模型参数,以及融合训练模型参数为例。在不同的场景中,对于本地训练模型参数或者本地训练模型的传输,可以是通过传输各个参数对应的二值化参数实现的,也可以是通过直接传输各个参数对应的高精度参数实现的。需要说明的是,该本地训练模型参数可以包括参数可以为神经网络的权重、偏置等参数中第一项或多项,该本地训练模型参数也可以为它们对应的梯度。
[0098]
示例性的,以中心节点为基站,与该中心节点结耦的子节点为手机为例。如图5所示,该机器学习系统中可以包括一个基站,以及n个手机(如手机1-手机n)。n个手机与基站中可以预先存储有相同的基础训练模型。手机1可以采集对应场景的数据,由此形成对应的数据集1。手机1可以将该数据集1输入到基础训练模型中进行本地训练。可以理解的是,本技术实施例中,手机可以通过bnn进行本地训练。比如,手机1可以将该数据集1输入到基础训练模型中,按照高精度参数的机器学习方法,获取本地模型参数1。该本地模型参数1可以是高精度参数。在一些实现中,该本地模型参数1可以包括高精度的权重以及偏置。手机1可以基于高精度的权重和偏置进行反向推演,由此完成数据集1中一部分数据的学习。接着,手机1可以将对应的高精度的权重和偏置经过二值化转换,获取对应的二值化参数。在对于后续数据的学习过程中,手机1可以基于该二值化参数,进行训练学习。由于二值化参数的数据量显著的小于高精度的数据量,因此手机1可以快速地获取完成本地训练,以获取收敛后的权重和偏置。由于是基于二值化参数进行的本地训练,因此,手机1获取的权重和偏置结果可以为二值化的参数。
[0099]
类似于手机1中的处理,其他手机,如手机2-手机n也可分别进行类似的本地训练,以获得对应的二值化参数。本技术中,可以将各个手机在本地训练获取的二值化参数称为本地参数。比如,手机1经过1轮学习后获取的二值化的权重和偏置可以称为本地参数1。手机2经过1轮学习后获取的二值化的权重和偏置可以称为本地参数2。手机n经过1轮学习后获取的二值化的权重和偏置可以称为本地参数n。
[0100]
各个手机可以将其对应的本地参数分别发送给基站。基站可以将获取的n个本地参数(如本地参数1-本地参数n)进行融合,以获取归一的融合参数。接着,基站可以将该融合参数分别下发给各个手机,手机在接收到融合参数之后可以据此更新本地的基础训练模型,并进行下一轮学习或者直接用于实际场景中的数据预测。
[0101]
结合图4,以下以机器学习系统中包括3个子节点,对数据在子节点与中心节点中的处理与传输进行示例性的说明。
[0102]
请参考图6。在中心节点中可以设置有中心融合模块,在各个子节点中可以设置有学习模块和本地融合模块。
[0103]
以子节点1为例。子节点在执行本技术实施例提供的机器学习方法时,其中的学习模块可以用于对数据集进行本地训练,以获取二值化的本地参数。子节点1可以将该本地参数发送给中心节点。类似的,子节点2和子节点3也可以将其分别对应的本地参数发送给中心节点。中心节点中的中心融合模块,可以用于将接收到的所有本地参数进行融合,以获取
融合参数。接着,中心节点可以将该融合参数分别下发给子节点1-子节点3。在子节点1中,本地融合模块可以用于根据接收到的融合参数,对本地训练模型进行更新,以获取基于融合参数的本地训练模型。类似的,在子节点2中,本地融合模块可以用于根据接收到的融合参数,对本地训练模型进行更新,以获取基于融合参数的本地训练模型。在子节点3中,本地融合模块可以用于根据接收到的融合参数,对本地训练模型进行更新,以获取基于融合参数的本地训练模型。
[0104]
在一些实现方式中,该本地融合模块可以是不包括在子节点中,而是独立的模块。例如,结合图7,在本地融合模块独立于各个子节点时,则中心节点可以将融合参数只发送给本地融合模块,本地融合模块可以用于对本地训练模型进行更新,并将更新后的本地训练模型分别下发给子节点1-子节点3。由此,可以降低对子节点的性能要求,同时由于中心节点只需要将融合参数发送给本地融合模块,因此可以降低中心节点的信令开销。
[0105]
在本技术的另一些实现方式中,本地融合模块也可是设置在部分子节点中的。比如,结合图8。其中,子节点3中集成有本地融合模块,而子节点1和子节点2的本地融合模块可以独立于子节点而单独设置的。在该架构下,中心节点可以在获取融合参数后,分别将该融合参数发送给与子节点1和子节点2对应的本地融合模块,以及子节点3。这样,子节点1和子节点2对应的本地融合模块可以将基于融合参数更新的本地训练模型下发给子节点1和子节点2。而对于子节点3,可以根据接收到的融合参数,使用其中集成的本地融合模块更新本地训练模型,由此获取更新后的本地训练模型。
[0106]
理解的是,在本示例中示出如图6、图7以及图8的及其学习系统的组成仅为一种示例,在本技术的另一些实现中,该系统中也可以包括多个独立配置的本地融合模块。比如,以在系统中配置有5个子节点(如子节点1-子节点5),以及3个本地融合模块(如本地融合模块1-本地融合模块3)为例。在一些场景下,本地融合模块1可以为子节点1和子节点2提供本地融合服务,本地融合模块2可以为子节点3和子节点4提供本地融合服务,本地融合模块3可以为子节点5提供本地融合服务。在另一些场景下,本地融合模块与子节点的对应关系也可以进行重新配置,比如,本地融合模块1可以为子节点1和子节点3提供本地融合服务,本地融合模块2可以为子节点2、子节点5提供本地融合服务,本地融合模块3可以为子节点4提供本地融合服务。当然,在一些场景下,也可以通过3个本地融合模块中的一个或部分向子节点提供本地融合服务,比如本地融合模块1可以为子节点1和子节点3提供本地融合服务,本地融合模块2可以为子节点2、子节点5以及子节点4提供本地融合服务,本地融合模块3则可以处于休眠等不工作的状态。
[0107]
为了便于说明,以下以本地融合模块集成在子节点中为例。图9示出了本技术实施例提供的一种机器学习方法的逻辑示意。如图9所示,该方法可以包括:
[0108]
s901、子节点1进行本地学习。
[0109]
s902、子节点1获取本地参数。
[0110]
s903、子节点1将本地参数发送给中心节点。
[0111]
可以理解的是,结合上述说明,对于机器学习系统中的其他子节点,也可分别执行上述s901-s903,这样中心节点就可以获取n个本地参数。其中,在一些实现中,该本地参数可以为二值化的本地参数。
[0112]
s904、中心节点对n个本地参数进行融合。
[0113]
s905、中心节点获取融合参数。
[0114]
s906、中心节点将融合参数发送给子节点1。
[0115]
s907、子节点1根据融合参数,更新本地模型参数。
[0116]
可以理解的是,对于机器学习系统中的其他子节点,中心节点也可对应执行上述s906,而对应的子节点也可执行上述s907,以便对其本地训练模型进行更新。
[0117]
可以看到,本技术实施例提供的机器学习方法,由于子节点可以将二值化的本地参数发送给中心节点,而不需要将具有较大数据量的高精度本地参数发送给中心节点,因此能够显著地降低子节点与中心节点之间的通信压力。由于传输的数据量很少,因此耗时也相应降低,由此能够提升本地训练在整个学习时长的占比,由此提高学习效率。另外,在本技术的一些实现方式中,中心节点可以通过广播的方式将融合参数下发给各个子节点,这样中心节点就可以不需一一向各个子节点发送融合参数,由此节省中心节点的信令开销。
[0118]
需要说明的是,在本技术实施例中,为了能够适应不同场景中的学习需求,本技术实施例在如图9所示的方法中,还提供三种模式,使得机器学习系统可以根据对于学习效率以及学习准确度不同的场景下,选取对应的模式,达到快速收敛或者高准确度学习的效果。
[0119]
模式1:上传模型参数时采用二值化参数,中心节点模型下传时采用高精度参数。
[0120]
由于下传采用的高精度参数,因此,子节点可以更加准确地更新获取本地训练模型。而由于上传采用的二值化参数,因此整个学习效率依然显著高于现有的fl架构中上下行都采用高精度参数的方案。该模式可以应用于对于学习效率和学习准确度都有一定要求的场景中。
[0121]
模式2:上传、下传模型参数时均采用二值化参数。
[0122]
由于上传和下传的模型参数均为二值化参数。因此,系统的数据传输压力最小。该模式可以应用于对于学习效率要求较高的场景中。
[0123]
模式3:上传、下传模型参数时均采用高精度参数。
[0124]
由于下传采用的高精度参数,因此,子节点可以更加准确地更新获取本地训练模型。该模式可以应用于对于学习准确度要求较高的场景中。
[0125]
应当理解的是,结合图6和图9,在每一轮迭代中,子节点基于现有模型和本地数据训练本地模型并上传给中心服务器。对于每个子节点,(i=1,2,

,m,m为子节点个数)表示第i个子节点的全部二值化参数。中心服务器接收到来自所有子节点上传的本地参数后,执行中心模型融合方法得到中心端参数并将以广播的形式下发给所有子节点。子节点收到后,立即使用本地模型融合方法将收到的与本地高精度参数wi进行融合。融合后,子节点重新进行二值量化然后开始下一轮的本地训练。中心模型融合方法和本地模型融合方法根据参数传输模式(如模式1-模式3)的不同而有所差别。
[0126]
以下对采用各个模式时,中心融合以及本地融合的具体计算过程进行说明。
[0127]
1、对于模式1中的中心融合:
[0128]
中心服务器收到所有子节点上传的二值化参数后对其进行加权平均,得到中心端
的参数需要说明的是,当每个节点上的数据数量相等时,这一融合公式变为此时,的取值本身并没有实际意义。但在已知m的情况下,可以反映出对于某一个参数,有多少节点的高精度参数为正或负。例如,当节点数m=100时,从中心服务器收到的中,某一位置的参数值为-0.2就意味着在这一位置,有60个子节点高精度参数为负,上传了-1,其余40个子节点高精度参数为正,上传了 1。而当每个节点上数据数量不等时,由于系统节点众多,可以认为反映了所有节点的数据中,给出正或负累积梯度的比例,考虑到对于节点i,本地数据共占比因此对于每个节点,可以视为总节点数且每个节点数据量相等的情况。因为此时,对于每个节点的等效m

均不同,但并不影响具体实施细节。为了简化计算以便于说明融合公式推导过程,以下均以各子节点数据集大小一致为例进行说明。可以理解的是,在本技术的另一些实现方式中,各个子节点数据集的大小也可以不同,在数据集大小不一致时,其计算过程类似,此处不再赘述。
[0129]
2、对于模式1中的本地融合:
[0130]
本地融合模块收到后进行本地融合计算,因为所有子节点的高精度参数均为基于本地数据集累积而得,因而具有强相关性,又因为节点数众多,所以可以假设所有节点的高精度参数近似为同一正态分布的采样值。进一步地,假设这个正态分布的协方差矩阵为对角阵,即不同位置参数的取值无关。由此,可以假设每一个子节点参数w
i,j
~n(μj,σj),其中,w
i,j
表示第i个节点第j个参数对应的参数值。问题转化为,在上述已知条件下,如何根据w
i,j
和的值估计出μj的值,并用估计到的值作为对所有节点的w
k,j
(k=1,2,

,m)的均值的估计,替换原来的w
i,j
。注意到各节点本地高精度参数很有可能不同,在完全独立地估计时,结果也会不一样。在本问题的分析求解过程中,为简明起见,省去所有下标,所有符号均表示第i个节点的第j个参数对应的参数。
[0131]
由本模式下中心端的计算过程可以发现,所有w》0的子节点数为首先,假设本地w》0,记为除本节点外,其他所有节点中w》0的比例。则在剩余的m-1个节点中,共有(m-1)θ个节点w
k,j
》0。则相当于对某正态分布n(μ,σ)的一组m个观察样本中,得到(m-1)θ个未知具体值的正观察值,(m-1)(1-θ)个未知具体值的负观察值,以及一个已知的正观察值w,这个正态分布的均值μ就可以认为是所有子节点对于本地w》0观察值的均值,可以一定程度反映出这一点的全局累积梯度大小。因此,可以用最大似然原则求出最大似然值。
[0132]
问题构建为:
[0133]
[0134]
其中,为正态分布n(μ,σ)的概率密度函数,用φ(α)=p{x》α|x~n(0,1)}表示正态分布的上分位数函数,并对上式求对数后,可得等价问题:
[0135][0136]
进行参数替换,记y=μ/σ,x=w/σ,因为w取值与优化目标无关,所以问题可以再次转化为:
[0137][0138]
为了解决上述优化问题,首先对x进行优化求解,由于w》0的约束,必然有x》0。此时,将所有与x相关的项提取出来,可以得到如下子问题:
[0139][0140]
可以很容易通过求导解得将其带入原问题后,问题可等价转换为:
[0141][0142]
优化目标是一个关于y的先增后减函数。这一函数的取值只与θ和m有关,由于m是可以事先知道的,所以可以事先绘制出最优解与θ之间的关系曲线,用特定函数拟合此关系,保存在本地,降低求解优化问题的复杂度。本实施例使用对数函数对这一曲线进行最小二乘拟合,其中θ的取值范围为[0,1],令c》0以确保函数有定义。这个函数就可以作为在给定的m下,关于θ的近似函数表达式,保存在本地。进而得到w》0时原问题的解:
[0143][0144]
考虑w《0的情况,同理推导,可以得到原问题的解为:
[0145][0146]
在本地融合的过程中,对于每一个w,根据上述过程中求出的对数近似表达式计算出出然后利用计算出并计算最后使用替换w。其中,α》1表示放大倍数,为事先选取的参数,一般取值在1.5到2.5
之间,(出于系统稳定性考虑,为了使系统更稳定地收敛,应当确保随着计算深入,w始终倾向于远离0,但又不能过大。由于高精度参数的绝对大小意义不大,但相对大小很重要,因而可以将所有参数等比例放大再进行限幅)。clamp(
·
)表示限幅操作,x《a时clamp(x,a,b)=a,a≤x≤b时clamp(x,a,b)=x,x》b时clamp(x,a,b)=b。
[0147]
最终,本地融合的公式如下:
[0148][0149]
根据该本地融合公式,本地融合模块就可以根据融合参数获取对本地训练模型进行更新。
[0150]
3、对于模式2中的中心融合:
[0151]
中心服务器收到所有子节点上传的二值化参数后对其求和并取符号,得到中心端的参数若这一结果恰好为0,则随机下发-1或 1。此时,的意义为对于某一个参数,所有节点数据中较大比例得到二值化参数为正,则取 1,反之则取-1。此时,这一参数只反映一个大概的整体趋势,包含的信息量较少。
[0152]
4、对于模式2中的本地融合:
[0153]
子节点收到后进行本地融合计算,使用简单的线性融合方法:后进行本地融合计算,使用简单的线性融合方法:其中sign(
·
)表示取符号函数,β在0与1之间,为事先选定的参数。
[0154]
5、对于模式3中的中心融合:
[0155]
中心服务器收到所有子节点上传的高精度参数后对其求平均,得到中心端的参数中心服务器收到所有子节点上传的高精度参数后对其求平均,得到中心端的参数此时,的意义为所有子节点依据本地数据集计算出的加权平均累积梯度。
[0156]
6、对于模式3中的本地融合:
[0157]
子节点收到后直接用更新本地高精度参数,即
[0158]
此模式与传统fl框架下的融合方法基本相同,区别仅在于本地模型为bnn,训练前向计算和推演复杂度相比于普通神经网络更低。
[0159]
基于上述对1-6的说明,可以看到,对于不同的传输模式,可以采用对应的本地融合以及中心融合方法,以实现对应的学习计算顺利进行。
[0160]
如前述说明,在不同的实现场景中,可以基于如图9所示的方法中,选取上述模式1或者模式2或者模式3中的任意一种作为传输方式,以此获取对应的有益效果。需要说明的是,无论采用上述模式1或模式2或模式3,由于子节点本地进行本地训练时都采用的bnn进行,因此相比于现有的fl架构能够更快地迭代获取结果。
[0161]
在本技术的另一些实现方式中,还可以结合上述模式1和模式2,或者模式1和模式3,或者模式2和模式3,或者模式1和模式2以及模式3实现如图9所示的方法。
[0162]
示例性的,可以在开始学习时,采用模式2进行。这样,虽然不能获取准确度较高的
训练结果,但是能够使得前面几轮的学习快速收敛。可以理解的是,一般而言,一个学习过程需要多轮学习才能完成。而在前面几轮的学习过程中,由于模型中的参数大概率会在后续的学习过程中发生变化,因此,对于前面几轮学习的准确度要求并不高。而如果能够提升前面几轮学习的收敛速度,那么会对整体的学习效率提升有较大贡献。在准确度达到一定程度时,可以采用模式1继续学习,由此适当提升参数的准确度。在准确度提升到一定程度时,可以采用模式3继续学习,由此获取准确度最高的结果。
[0163]
作为一种示例,表1示出了一种准确度与模式选择的对应关系。
[0164]
表1
[0165][0166]
如表1所示,在准确度小于0.65时,那么中心节点确定继续采用模式2进行学习。在准确度处于[0.65,0.8]之间时,中心节点可以确定采用模式1进行学习。而在准确度大于或等于0.8时,则中心节点可以确定通过模式3进行学习。
[0167]
其中,准确度可以是中心节点根据各个子节点上传的准确度计算获取的。比如,中心节点可以根据计算获取系统的准确度,并根据表1所示的对应关系确定传输模式。需要说明的是,表1中的0.65/0.8均为一种阈值的设置示例,在本技术的另一些实现方式中,该阈值还可以是被设置为其他值,或者根据环境灵活调整的。
[0168]
需要说明的是,对于各个子节点,可以通过其中存储的测试集,对与本地参数对应的本地训练模型进行校验,由此获取对应的准确度,并将该准确度发送给中心节点。在本技术的另一些实现方式中,校验获取准确度的操作也可以是在中心节点处完成的。比如,在中心节点中可以存储有与本地节点对应的训练模型。中心节点可以在接收到本地节点发送的本地参数后,根据该本地参数更新训练模型,并基于更新的训练模型以及中心节点中存储的测试集,对该本地参数对应的准确度进行校验,由此即可获取对应的准确度。类似的,对于其他节点上传的本地参数,中心节点也可获取对应的准确度。由此,中心节点就可以基于各个本地参数对应的准确度,计算获取对应的系统的准确度,进而确定数据的传输模式。
[0169]
在本技术的另一些实现方式中,该准确度也可以是中心节点根据完成中心融合后的融合参数,以及训练模型,结合测试集或者验证数据集进行校验,获取的准确度。在不同的实现方式中,准确度的确定方法可以灵活确定,本技术实施例对此不作限制。
[0170]
以上说明中,是以通过准确度选取传输模式为例进行说明的,在本技术的另一些实现方式中,中心节点还可以根据其他方法确定传输模式。比如,中心节点可以根据迭代轮数n确定传输模式。表2示出了一种可能的迭代轮数n与传输模式的对应关系。
[0171]
表2
[0172]
迭代轮数n参数传输模式n≤5模式25《n≤50模式1
n》50模式3
[0173]
如表2所示,在迭代轮数在5以内时,中心节点可以确定当前学习中,提高收敛速度的需求更高,则可以采用模式2进行学习。在迭代轮数大于5轮且在50轮之内时,则中心节点可以确定当前学习中需要适当提高准确度,继而采用模式1进行学习。而在迭代轮数大于50时,则中心节点可以认为学习即将结束,需要以最高的准确度进行参数传输,即采用模式3进行学习。
[0174]
需要说明的是,在本技术的一些实现方式中,当三种模式发生切换时,中心节点可以指示各子节点调整参数传输模式。例如增加参数传输模式字段来指示,三种模式可用2比特来指示,如00指示下轮参数传输方式为模式1、01指示模式2、10指示模式3。参数传输模式字段可随中心融合模型一起下发,或用专用控制信道下发。
[0175]
通过上述方案说明,可以理解的是,在结合图4-图8所示的机器学习系统中,采用如图9所示的机器学习方法,能够显著地增加系统学习效率。同时结合模式1、模式2以及模式3的选择使用,能够结合当前学习过程中对于准确度和收敛速度的不同需求,自适应地选取对应的模式,以获取最优的学习结果。
[0176]
为了说明本技术实施例提供的方案能够达到的效果,以下结合仿真数据,对本技术所述方案所能够达到的效果进行示例性说明。
[0177]
以mnist手写数字识别为例,在仿真中,使用了由两层卷积层和两层全连接层组成的4层卷积神经网络,mnist数据集中的训练集被均匀分布在100个子节点上,每个节点共有600对数据,其中每类数据60对,测试集只保存在中心服务器上。最终的训练集相关结果为100个节点的结果的均值,测试集相关的结果为中心服务器基于本地参数二值化后的结果进行的。高精度参数采用32比特量化。
[0178]
使用4层卷积神经网络进行训练,具体结构为:3*3*16卷积层,归一化层,2*2最大池化层,tanh激活函数,3*3*16卷积层,归一化层,2*2最大池化层,tanh激活函数,784*100全连接层,归一化层,tanh激活函数,100*10全连接层,softmax激活函数,最后使用交叉熵损失函数。使用了adam梯度更新方法,学习率η的初始值为0.05,随后每30次迭代降低一次,依次为0.02,0.01,0.005,0.002。
[0179]
模式1的本地融合方法中,假设m=100,验证用对数函数拟合函数的曲线,绘制出曲线和使用插值得到的对数近似曲线,其结果如图10所示。可以看到,通过模式1所示的本地融合方法所绘制的近似曲线(表达式为),与真实绘制的真实曲线基本重合。因此,通过上述模式1的本地融合方法能够较好地模拟真实情况。
[0180]
图11示出了本发明应用在mnist手写数字识别数据集上时,测试集准确率随时间变化的曲线。其中,集中式训练,即将所有的数据收集至一个中心节点进行训练,此曲线作为基线进行对比。可以看出,在测试集准确率上,模式1、3均可以达到接近集中式的效果,模式3在训练效果上有微弱的优势,但模式1每次迭代所需的通信量很小,消耗的通信代价远小于模式3,且模式3表现不是很稳定。在模式2中,虽然最终性能较差,但由于所需的通信量极低,在训练的早期具有较强的竞争力,随着训练的深入,性能无法与前两种模式相比。总体来说,本发明中的模式1适合大多数实际情况,模式2适合在训练早期,或者通信资源非常
紧张、且对学习效果要求很低时应用,模式3适合后期对基本训练好的模型进行微调。
[0181]
以下(表3)给出使测试集准确率首次达到90%和95%所需的每个子节点计算量和通信量对比。其中,选择α=1.5和β=0.3时的结果作为模式1和模式2的结果,计算量以每个子节点进行一次前向计算和反向传播为1,则每次子节点训练需要进行10次前向计算和反向传播,需要计算量为10。
[0182]
表3
[0183][0184]
可以看到,系统总参数个数为82242,则系统工作在模式1时每次需要上传10.04kb、下传66.70kb数据;工作在模式2时每次需要上传、下传各10.04kb数据;模式2因为无法达到95%准确率,故数据空缺;工作在模式3时每次需要上传、下传各321.26kb数据。相比于普通联邦学习的训练模式(模式3),本发明中的模式1方法每次迭代所需的通信量大大减少,本方法大大地降低了分布式机器学习系统的通信量,进而可以大幅度减少分布式机器学习任务所需的总时间。
[0185]
在各个子节点中的数据集非独立同分布场景下,本技术实施例所提供的机器学习方法的仿真结果如下:
[0186]
同样以mnist手写数字识别为例,在仿真中,使用了由两层卷积层和两层全连接层组成的4层卷积神经网络,网络结构与2.4.1节相同,学习率η的初始值为0.02,随后每30次迭代降低一次,依次为0.01,0.005,0.002。mnist数据集中的训练集被不均匀分布在100个子节点上,具体分布方法为:首先将数据集按照类型分为10份,再将每份均分为100份,得到1000个子数据集,将这1000个子数据集随机分配给100个子节点,每个子节点被随机分配到10个子数据集。测试集只保存在中心服务器上。
[0187]
图12为本发明应用在非独立同分布mnist手写数字识别数据集上时,测试集准确率变化的曲线。其中混合模式中模式切换采用迭代轮数进行判断,初始使用模式2进行训练,5次迭代后改为模式1,再经过50次迭代后改为模式3。在这一情况下,因为数据集非独立同分布的特点,系统会有一定的性能损失。可以看出,混合模式最终也可以达到较好的效果。
[0188]
表4给出了三种模式和混合模式连续5次达到一定准确率所需的通信和计算代价。
[0189]
表4
[0190][0191]
结合表4的仿真结果,可以看出,单纯使用模式1和模式2,较难达到85%的准确度需求,混合模式可以达到,且通信开销与模式3相比更有优势。
[0192]
上述主要从子节点和中心节点的角度对本技术实施例提供的方案进行了介绍。为了实现上述功能,其包含了执行各个功能相应的硬件结构和/或软件模块。本领域技术人员应该很容易意识到,结合本文中所公开的实施例描述的各示例的单元及算法步骤,本技术能够以硬件或硬件和计算机软件的结合形式来实现。某个功能究竟以硬件还是计算机软件驱动硬件的方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本技术的范围。
[0193]
本技术实施例可以根据上述方法示例对其中涉及的设备进行功能模块的划分,例如,可以对应各个功能划分各个功能模块,也可以将两个或两个以上的功能集成在一个处理模块中。上述集成的模块既可以采用硬件的形式实现,也可以采用软件功能模块的形式实现。需要说明的是,本技术实施例中对模块的划分是示意性的,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式。
[0194]
请参考图13,为本技术实施例提供的一种机器学习装置1300,该装置可以应用于子节点,子节点中设置有二值化神经网络模型bnn,该装置包括:获取单元1301,用于对采集获取的本地数据集,进行基于bnn的机器学习,获取与本地数据集对应的本地模型参数。发送单元1302,用于向中心节点发送第一消息,第一消息包括本地模型参数。
[0195]
在一种可能的设计中,第一消息中包括的本地模型参数是二值化的本地模型参数。
[0196]
在一种可能的设计中,该装置还包括:接收单元1303,用于从中心节点接收融合参数,融合参数是中心节点根据本地模型参数融合获取的。融合单元1304,用于对融合参数以及本地模型参数进行融合,以获取更新后的本地模型参数。
[0197]
在一种可能的设计中,融合参数是二值化的模型参数,或者,融合参数是高精度的模型参数。
[0198]
在一种可能的设计中,第一消息还包括:与本地模型参数对应的准确度信息。获取单元1301,还用于根据本地模型参数,以及测试数据集,校验获取准确度信息。
[0199]
在一种可能的设计中,该装置还包括:学习单元1305,用于根据更新后的本地模型参数,基于bnn继续进行机器学习。
[0200]
需要说明的是,上述方法实施例涉及的各步骤的所有相关内容均可以援引到对应功能模块的功能描述,在此不再赘述。
[0201]
请参考图14,为本技术实施例提供的一种机器学习装置1400,该装置应用于中心节点,装置包括:接收单元1401,用于接收分别来自n个子节点的n个第一消息,第一消息包括本地模型参数,本地模型参数是二值化的本地模型参数。n为大于或等于1的整数。融合单元1402,用于对n个第一消息包括的本地模型参数进行融合,获取融合参数。发送单元1403,用于向m个子节点发送第二消息,第二消息包括融合参数,m为大于或等于1的正整数。m个子节点包括在n个子节点中。
[0202]
在一种可能的设计中,融合单元1402,具体用于对n个本地模型参数,进行加权平均,获取融合参数。
[0203]
在一种可能的设计中,第二消息包括的融合参数是高精度的融合参数,或者,第二消息包括的融合参数是二值化的融合参数。
[0204]
在一种可能的设计中,该装置还包括:确定单元1404,用于根据n个第一消息,确定系统准确度信息,中心节点根据系统准确度信息,确定第一消息中包括的融合参数为高精度的融合参数或者二值化的融合参数。
[0205]
在一种可能的设计中,第一消息还包括:准确度信息。准确度信息与第一消息中包括的本地模型参数在对应子节点处校验获取的准确度对应。确定单元1404,具体用于根据n个消息中包括的准确度信息,确定系统准确度信息。
[0206]
在一种可能的设计中,确定单元1404,用于在系统准确度信息小于或等于第一阈值时,确定融合参数为二值化的融合参数。确定单元1404,还用于在系统准确度信息大于或等于第二阈值时,确定融合参数为高精度的融合参数。
[0207]
在一种可能的设计中,发送单元1403,用于在迭代轮数小于或等于第三阈值时,向m个子节点发送包括二值化的融合参数的第二消息。发送单元1403,还用于在迭代轮数大于或等于第四阈值时,向m个子节点发送包括高精度的融合参数的第二消息。
[0208]
在一种可能的设计中,发送单元1403,具体用于通过广播,向m个子节点发送第二消息。
[0209]
需要说明的是,上述方法实施例涉及的各步骤的所有相关内容均可以援引到对应功能模块的功能描述,在此不再赘述。
[0210]
图15示出了的一种子节点1500的组成示意图。如图15所示,该子节点1500可以包括:处理器1501和存储器1502。该存储器1502用于存储计算机执行指令。示例性的,在一些实施例中,当该处理器1501执行该存储器1502存储的指令时,可以使得该子节点1500执行上述实施例中任一种所示的数据处理方法。
[0211]
需要说明的是,上述方法实施例涉及的各步骤的所有相关内容均可以援引到对应功能模块的功能描述,在此不再赘述。
[0212]
图16示出了的一种芯片系统1600的组成示意图。该芯片系统可以应用于本技术实施例涉及的任意一个子节点中。该芯片系统1600可以包括:处理器1601和通信接口1602,用于支持相关设备实现上述实施例中所涉及的功能。在一种可能的设计中,芯片系统还包括存储器,用于保存子节点必要的程序指令和数据。该芯片系统,可以由芯片构成,也可以包含芯片和其他分立器件。需要说明的是,在本技术的一些实现方式中,该通信接口1602也可
称为接口电路。
[0213]
需要说明的是,上述方法实施例涉及的各步骤的所有相关内容均可以援引到对应功能模块的功能描述,在此不再赘述。
[0214]
图17示出了的一种中心节点1700的组成示意图。如图17所示,该中心节点1700可以包括:处理器1701和存储器1702。该存储器1702用于存储计算机执行指令。示例性的,在一些实施例中,当该处理器1701执行该存储器1702存储的指令时,可以使得该中心节点1700执行上述实施例中任一种所示的数据处理方法。
[0215]
需要说明的是,上述方法实施例涉及的各步骤的所有相关内容均可以援引到对应功能模块的功能描述,在此不再赘述。
[0216]
图18示出了的一种芯片系统1800的组成示意图。该芯片系统可以应用于本技术实施例涉及的任意一个中心节点中。该芯片系统1800可以包括:处理器1801和通信接口1802,用于支持相关设备实现上述实施例中所涉及的功能。在一种可能的设计中,芯片系统还包括存储器,用于保存中心节点必要的程序指令和数据。该芯片系统,可以由芯片构成,也可以包含芯片和其他分立器件。需要说明的是,在本技术的一些实现方式中,该通信接口1802也可称为接口电路。
[0217]
需要说明的是,上述方法实施例涉及的各步骤的所有相关内容均可以援引到对应功能模块的功能描述,在此不再赘述。
[0218]
在上述实施例中的功能或动作或操作或步骤等,可以全部或部分地通过软件、硬件、固件或者其任意组合来实现。当使用软件程序实现时,可以全部或部分地以计算机程序产品的形式来实现。该计算机程序产品包括一个或多个计算机指令。在计算机上加载和执行计算机程序指令时,全部或部分地产生按照本技术实施例所述的流程或功能。所述计算机可以是通用计算机、专用计算机、计算机网络、或者其他可编程装置。所述计算机指令可以存储在计算机可读存储介质中,或者从一个计算机可读存储介质向另一个计算机可读存储介质传输,例如,所述计算机指令可以从一个网站站点、计算机、服务器或者数据中心通过有线(例如同轴电缆、光纤、数字用户线(digital subscriber line,dsl))或无线(例如红外、无线、微波等)方式向另一个网站站点、计算机、服务器或数据中心进行传输。所述计算机可读存储介质可以是计算机能够存取的任何可用介质或者是包括一个或多个可以用介质集成的服务器、数据中心等数据存储设备。所述可用介质可以是磁性介质(例如,软盘、硬盘、磁带),光介质(例如,dvd)、或者半导体介质(例如固态硬盘(solid state disk,ssd))等。
[0219]
尽管结合具体特征及其实施例对本技术进行了描述,显而易见的,在不脱离本技术的精神和范围的情况下,可对其进行各种修改和组合。相应地,本说明书和附图仅仅是所附权利要求所界定的本技术的示例性说明,且视为已覆盖本技术范围内的任意和所有修改、变化、组合或等同物。显然,本领域的技术人员可以对本技术进行各种改动和变型而不脱离本技术的精神和范围。这样,倘若本技术的这些修改和变型属于本技术权利要求及其等同技术的范围之内,则本技术也意图包括这些改动和变型在内。
再多了解一些

本文用于企业家、创业者技术爱好者查询,结果仅供参考。

发表评论 共有条评论
用户名: 密码:
验证码: 匿名发表

相关文献