一种残膜回收机防缠绕挑膜装置的制 一种秧草收获机用电力驱动行走机构

用于训练潜在变量模型的方法和系统与流程

2022-12-09 23:43:58 来源:中国专利 TAG:


1.本公开总体上涉及用于计算机学习的系统和方法,其可以提供改进的计算机性能、特征和用途。更特别地,本公开涉及用于改进性能的学习深度潜在变量模型的系统和方法。


背景技术:

2.深度生成模型在许多领域取得了巨大的成功,诸如图像生成、图像恢复、图像表示、图像解耦、异常检测等。这些模型通常包括简单和富有表现力的生成器网络,它们是潜在变量模型,假设每个观察到的示例是由潜在变量的低维向量生成的,并且潜在向量遵循非信息性先验分布,诸如高斯分布。由于高维视觉数据(例如图像)通常位于嵌入在高维空间中的低维流形上,因此对于无监督表示学习而言,学习视觉数据的潜在变量模型在计算机视觉领域非常重要。挑战主要来自对每个观察的潜在变量的推断,这通常依赖于马尔科夫链蒙特卡洛(markov chain monte carlo,mcmc)方法从难以解析的后验分布(即给定观察样本的潜在变量的条件分布)中抽取适当样本。由于潜在变量的后验分布是由高度非线性的深度神经网络参数化的,因此基于mcmc的推断可能会遇到非收敛和效率低下的问题,从而影响模型参数估计的准确性。
3.因此,需要以提高的效率学习深度潜在变量模型的系统和方法。


技术实现要素:

4.第一方面,提供一种用于训练潜在变量模型的计算机实现方法,包括:
5.通过短期马尔科夫链蒙特卡洛(mcmc)推断从潜在变量模型生成的多个观察示例中的每个的潜在向量,以获得推断的潜在向量群;
6.通过优化传输(ot)校正,将推断的潜在向量群移动到先验分布;以及基于校正后的潜在向量和相应的观察示例,通过梯度下降更新潜在变量模型的模型参数。
7.第二方面,提供一种用于训练潜在变量模型的系统,包括:
8.一个或多个处理器;以及
9.非暂时性计算机可读介质,包括一组或多组指令,当由一个或多个处理器中的至少一个执行时,所述一组或多组指令导致执行如第一方面所述的方法的步骤。
10.第三方面,提供一种包括一个或多个指令序列的非暂时性计算机可读介质,当由至少一个处理器执行时,所述指令序列引起如第一方面所述的用于训练潜在变量模型的方法的步骤。
11.第四方面,提供一种计算机程序产品,包括计算机程序,所述计算机程序在被处理器执行时使得所述处理器执行如第一方面所述的方法。
12.短期mcmc的实施例,诸如短期朗之万动力学,在本文中用作近似的基于流的推断引擎。非收敛短期朗之万动力学的输出分布中存在的偏差可以通过最优传输(ot)来校正,ot旨在以最小的传输成本将有限步mcmc产生的偏差分布变换为先验分布。实验结果验证了
ot校正对短期mcmc的有效性,并证明通过公开的策略训练的潜在变量模型在图像重建、图像生成和异常检测方面表现优于变分自动编码器。
附图说明
13.将参考本公开的实施例,其示例可以在附图中示出。这些图旨在说明性而非限制性的。尽管在这些实施例的上下文中一般地描述了本公开,但应当理解,其并非旨在将本公开的范围限制于这些特定实施例。图中的项目可能不是按比例的。
14.图1描绘了根据本公开的实施例的深度潜在变量模型的框图。
15.图2描述了训练深度潜在变量模型的不同方法。
16.图3描绘了根据本公开的实施例的长期和短期mcmc推断框架。
17.图4描绘了根据本公开的实施例的使用短期mcmc推断和最优传输校正来学习深度潜在变量模型的过程。
18.图5描绘了根据本公开的实施例的通过最优传输校正的输出分布的变化。
19.图6描绘了根据本公开的实施例的短期mcmc推断的过程。
20.图7描绘了根据本公开的实施例的最优传输校正的过程。
21.图8描绘了根据本公开的实施例的在不同迭代和先验分布处从边缘分布q
θ
(zk)采样的潜在代码的可视化。
22.图9描绘了根据本公开的实施例的由对来自数据集a的“0”类和“1”类的图像训练的不同模型的z的输出边缘分布。
23.图10a描绘了根据本公开的实施例的控制最优传输的百分比的超参数α在不同迭代上对最优传输成本的影响。
24.图10b描绘了根据本公开的实施例的控制最优传输的百分比的超参数α在不同迭代上对均方误差(mean squared error,mse)损失的影响。
25.图10c描绘了根据本公开的实施例的控制最优传输的百分比的超参数α在不同迭代上对frechet起始距离(frechet inception distance,fid)的影响。
26.图11描绘了根据本公开的实施例的计算设备/信息处理系统的简化框图。
具体实施方式
27.在以下描述中,出于解释的目的,阐述了具体细节以便提供对本公开的理解。然而,对于本领域的技术人员来说显而易见的是,可以在没有这些细节的情况下实施本公开。此外,本领域技术人员将认识到,下文描述的本公开的实施例可以以多种方式实现,诸如有形计算机可读介质上的过程、装置、系统、设备或方法。
28.图中所示的组件或模块是对本公开的示例性实施例的说明并且旨在避免混淆本公开。还应该理解,在整个讨论中,组件可以被描述为单独的功能单元,这些功能单元可以包括子单元,但是本领域的技术人员将认识到各种组件或其部分可以被划分为单独的组件或可以集成在一起,包括,例如,在单个系统或组件中。应该注意,本文讨论的功能或操作可以实现为组件。组件可以以软件、硬件或其组合来实现。
29.此外,图中的组件或系统之间的连接不旨在限于直接连接。相反,这些组件之间的数据可以由中间组件修正、重新格式化或以其他方式改变。此外,可以使用更多或更少的连
接。还应注意,术语“耦接”、“连接”、“通信耦接”、“接口”、“接入”或其任何派生词汇应理解为包括直接连接、通过一个或多个中间设备的间接连接和无线连接。还应注意,任何通信,诸如信号、响应、回复、确认、消息、查询等,都可以包括一个或多个信息交换。
30.在说明书中对“一个或多个实施例”、“优选实施例”、“实施例”、“一些实施例”等的引用意味着结合实施例描述的特定特征、结构、特性或功能被包括在本发明的至少一个实施例且可在多于一个实施例中。此外,在说明书的多个地方出现的上述短语不一定都指同一实施例或多个实施例。
31.在说明书的不同地方使用某些术语是为了说明,不应解释为限制。服务、功能或资源不限于单个服务、功能或资源;这些术语的使用可能是指一组相关的服务、功能或资源,这些服务、功能或资源可能是分布式的或集成的。术语“包括”、“包含”、“具有”和“含有”以及它们的任何变体应理解为开放式术语,并且以下任何列表都是示例,并不意味着限于所列项目。“层”可以包括一个或多个操作。词语“最优”、“优化”、“最佳化”等指的是结果或过程的改进,并且不要求指定的结果或过程已经达到“最优”或峰值状态。存储器、数据库、信息库、数据存储、表格、硬件、高速缓存等的使用在本文中可以用来指代可以输入或以其他方式记录信息的一个或多个系统组件。
32.在一个或多个实施例中,停止条件可以包括:(1)已经执行了设定次数的迭代;(2)已达到一定的处理时间;(3)收敛(例如,连续迭代之间的差小于第一阈值);(4)发散(例如,性能恶化);(5)已达到可接受的结果;以及(6)所有的数据都已经处理。
33.本领域技术人员应当认识到:(1)可以选择性地执行某些步骤;(2)步骤可以不限于此处设置的特定顺序;(3)某些步骤可以按不同的顺序执行;(4)某些步骤可以同时进行。
34.本文使用的任何标题仅用于组织目的,不应用于限制说明书或权利要求的范围。此专利文件中提及的每个参考文献/文件通过引用整体并入本文。
35.需要说明的是,本文所提供的任何实验和结果均以说明的方式提供,并且是在特定条件下使用一个或多个特定实施例执行的;因此,这些实验及其结果均不应用于限制本专利文件的公开范围。
36.a.概述
37.深度生成模型在诸如图像生成、图像恢复、图像表示、图像解耦、异常检测等许多领域都取得了巨大的成功。此类模型的典型代表主要包括简单且富有表现力的生成器网络,它们是潜在变量模型,假设每个观察到的示例是由潜在变量的低维向量生成的,并且潜在向量遵循非信息性先验分布,诸如高斯分布。图1描绘了根据本公开的实施例的深度潜在变量模型100的框图。该模型包括多个卷积层,例如第一层120、第二层130、第三层140和第四层150,以通过非线性变换g(z)将潜在向量110投影和重塑到观察到的样本i 160,例如图像。模型中的一个或多个层可以实现具有所需操作的步长的卷积。
38.由于高维视觉数据(例如图像)通常位于嵌入在高维空间中的低维流形上,因此对于无监督表示学习而言,学习视觉数据的潜在变量模型在计算机视觉领域非常重要。然而,由于g的非线性参数化,学习这样的模型具有挑战性。
39.图2描绘了训练深度潜在变量模型210的不同方式。基于马尔科夫链蒙特卡洛(mcmc)的最大似然估计(maximum likelihood estimate,mle)210的挑战主要来自于对每个观察的潜在变量的推断,这通常依赖于mcmc方法从难以分析的后验分布(即,给定观察后
的示例的潜在变量的条件分布)中抽取适当样本。由于潜在变量的后验分布是由高度非线性的深度神经网络参数化的,因此基于mcmc的推断可能会遇到非收敛和效率低下的问题,从而影响模型参数估计的准确性。
40.变分自动编码器(variational auto-encoder,vae)和生成对抗网络(generative adversarial network,gan)是目前训练深度潜在变量模型的流行方法。这两个模型通过招募额外的模型来辅助训练而训练生成器,并且在测试中将忽略它。为了避免从后验中进行低效的mcmc采样,变分推断通过经由易处理的网络近似难以处理的后验成为一种有吸引力的替代方案。尽管vae越来越流行和受欢迎,但它的缺点也越来越明显。首先,它通过外部前馈推断模型220参数化内部迭代推断过程。由于重新参数化而产生的这些额外参数必须与生成器网络的参数一起估计。其次,这种联合训练是通过最大化变分下限来完成的。因此,vae的准确性在很大程度上取决于作为真实后验分布的近似的推断模型的准确性。只有当推断和后验分布之间的kullback-leibler(kl)散度等于0时,变分推断才等效于期望的最大似然估计。此目标在实践中通常是不可行的。第三,在设计vae的推断模型时需要付出额外的努力,特别是对于具有复杂的潜在变量的依赖结构的生成器,例如,一些提出了具有多层潜在变量的自上而下生成器,一些提出了具有潜在变量时间序列的动态生成器。设计推断模型来推断上述模型的潜在变量并不是一项简单的任务。推断模型的任意设计不能保证性能。gan训练方法除了生成器之外还涉及鉴别器230,因此在训练期间具有两组参数。在训练过程中可能会发生模型崩溃。此外,很难为gan方法设计有效的推断模型。
41.在本公开中,完全放弃了重新参数化推断过程的想法。而是,公开了用于训练深度潜在变量模型的基于mcmc的推断的实施例。具体来说,短期mcmc的实施例,诸如短期朗之万动力学,用于在训练期间执行潜在向量的推断。然而,考虑到每次迭代中有限步朗之万动力学的收敛性可能是一个问题,最优传输(optimal transport,ot)的实施例被用于校正可能存在于这种短期mcmc中的偏差。ot可用于将任意概率分布变换为具有最小传输成本的期望分布。因此,ot成本可用于衡量两个概率分布之间的差。在本公开的一个或多个实施例中,短期mcmc被视为其参数来自潜在变量模型的学习的流模型。短期mcmc的偏差可以通过执行从由短期mcmc得到的结果分布到先验分布的最优传输来校正。这样的操作是为了最小化推断分布和先验分布之间的ot成本,其中流模型中的参数被更新而不是被优化。通过校正后的推断输出,可以更准确地更新潜在变量模型的参数。
42.图3描绘了根据本公开的实施例的长期和短期mcmc推断框架。传统的长期mcmc推断框架仅涉及推断步骤310和学习步骤320。如前所述,这种基于长期mcmc的推断可能会遇到非收敛和效率低下的问题,从而影响模型参数估计的准确性。而另一方面,本文件公开了包括推断步骤330、校正步骤340和学习步骤350的短期mcmc推断框架的实施例,每个步骤的细节在图4中示出。
43.图4描绘了根据本公开的实施例的使用短期mcmc推断和最优传输校正来学习深度潜在变量模型的过程。该过程迭代一轮或多轮以下三个步骤:(1)在推断步骤405中:通过使用朗之万动力学的短期mcmc推断从潜在变量模型生成的多个观察示例中的每一个的潜在向量以获得推断的潜在向量的群。朗之万动力学从后验分布采样。(2)在校正步骤410中:将所有推断出的潜在向量的群通过最优传输校正移动到先验分布。(3)在学习步骤415中:根据校正后的潜在向量和对应的观察示例,通过梯度下降来更新模型参数。
44.使用所公开的带有ot校正的短期mcmc推断有几个优点:(1)效率:使用短期mcmc,模型的学习和推断是有效的;(2)便利性:以短期mcmc为代表的近似推断模型是自动的,无需担心单独的推断模型的设计和训练。自下而上的推断和自上而下的生成由同一组参数控制;(3)准确性:最优传输校正了非收敛的短期mcmc推断的误差,从而提高了模型参数估计的准确性。
45.本专利公开的贡献至少包括以下内容:(1)公开了通过具有ot校正的非收敛短期mcmc推断来训练深度潜在变量模型的实施例;(2)扩展了半离散ot方法的实施例,以近似推断的潜在向量与从先验分布中抽取的样本之间的一对一映射;(3)在各种实验中提供了强有力的实证结果,以验证所公开的策略对训练深度潜在变量模型的有效性。
46.b.一些相关工作
47.1.变分推断
48.vae是一种流行的学习生成器网络的方法,它通过同时训练易于处理的推断网络来近似潜在变量的难以处理的后验分布。在vae中,需要为潜在变量设计推断模型,这在具有复杂架构的生成器网络中是一项有意义的任务。而在本专利文件中,所公开的方法不依赖额外的推断模型来辅助训练。它通过朗之万采样从后验分布进行推断,然后进行最优传输校正。
49.交替反向传播算法。生成器网络的最大似然学习,包括其动态版本,可以通过交替反向传播(alternating backpropagation,abp)算法实现,而无需借助推断模型。abp算法通过交替以下两个步骤来训练生成器模型:(1)推断步骤:通过朗之万采样从后验分布推断潜在变量,以及(2)学习步骤:基于训练数据和推断的潜在变量通过梯度下降更新模型参数。这两个步骤都在反向传播的帮助下计算梯度。abp算法已成功应用于显著性检测、零次学习、解耦表示学习等。
50.2.最优传输
51.最优传输(ot)用于计算两个度量之间的距离,并且能够将源分布推向目标分布。最近,ot已广泛用于生成模型,以帮助生成高质量的样本。例如,通过将gan模型中原有的kl散度替换为w1距离,一些人提出了wasserstein gan(wgan)模型,以实现更好的收敛性并生成更高质量的样本。一些人提出了将推断模型和后验分布之间的wasserstein距离最小化的wasserstein vae。除了wasserstein距离外,最优传输还用于将简单的均匀分布传输到由自动编码器提取的复杂潜在特征分布以生成图像。
52.c.深度潜在变量模型最大似然学习的实施例
53.假设i是d维观察数据示例,诸如图像。令z为连续潜在变量的d维向量。从传统的因子分析模型进行推广,生成器网络假设观察到的示例i是通过非线性变换i=g
θ
(z) ∈从潜在向量z生成的,其中g
θ
是自上而下的卷积神经网络(有时称为反卷积神经网络),其中参数θ包括网络中所有可训练的权重和偏差项,是观察误差,id和id分别是d维和d维单位矩阵,假设d<<d。生成器网络本质上可能是非线性潜在变量模型,它定义了(i,z)的联合分布,
54.p
θ
(i,z)=p
θ
(i|z)p(z)
ꢀꢀꢀ
(1)
55.其中假设先验分布且标准偏差σ取假设值。按照贝叶斯法则,可以得到边缘分布p
θ
(i)=∫p
θ
(i,z)dz和后验分布p
θ
(z|i)=p
θ
(i,z)/p
θ
(i)。
56.给定一组训练示例{ii,i=1,...,n}~p
data
(i),其中p
data
(i)是未知数据分布。p
θ
可以通过最大化训练样本的对数似然来训练:
[0057][0058]
当训练示例数量n足够大时,这相当于kl(p
data
||p
θ
)的最小化。
[0059]
在一个或多个实施例中,等式(2)中呈现的对数似然函数的最大化可以通过迭代的梯度上升算法来实现
[0060][0061]
其中γ
t
是取决于时间t的学习率,对数概率的梯度由下式给出:
[0062][0063]
为了计算等式(4)中的需要估计根据等式(1),联合分布的对数由下式给出:
[0064][0065]
其中常数项与z或θ无关,因此其中(z)可以通过反向传播有效地计算。
[0066]
d.短期mcmc推断的实施例
[0067]
1.长期朗之万动力学实施例
[0068]
为了使用等式(3)学习模型参数θ,关键是计算等式(4)中难以处理的期望项,这可以通过首先从p
θ
(i,z)中抽取样本,然后使用蒙特卡洛样本平均来近似它而实现。给定步长s>0和初始值z0,朗之万动力学,一种基于梯度的mcmc方法,可以通过递归计算从后验密度p
θ
(z|i)生成样本
[0069][0070]
在等式(6)中,k索引了朗之万动力学的时间步骤,是随机噪声扩散。此外,其中可以通过反向传播被高效计算。
[0071]
在一个或多个实施例中,k用于表示朗之万步骤数量。当s

0且k

∞时,无论z0的初始分布是什么,zk将会收敛到后验分布p
θ
(z|i),并且成为p
θ
(z|i)的适当样本。
[0072]
2.短期朗之万动力学实施例
[0073]
使用长期mcmc训练深度潜在变量模型可能不明智或不现实。在每次迭代中,运行
有限数量的朗之万动力学步骤来推断p
θ
(z|i)似乎是可行的。因此,短期k步朗之万动力学由下式给出:
[0074]
z0~p0(z)
[0075][0076]
在一个或多个实施例中,初始分布p0被假定为高斯分布。这种动力学可以被视为条件生成器,它在条件i下将随机噪声z0变换为目标分布。变换本身也可以被视为k层残差网络,其中每一层共享相同的参数θ并且具有噪声注入。κ
θ
用于表示k步mcmc转移核。给定i的zk的条件分布是:
[0077]qθ
(zk|i)=∫p0(z0)κ
θ
(zk|z0,i)dz0ꢀꢀꢀ
(8)
[0078]
zk对应的边缘分布为
[0079]qθ
(zk)=∫q
θ
(zk|i)p
data
(i)di
ꢀꢀꢀ
(9)
[0080]
如果mcmc收敛,q
θ
(zk)应该接近于先验分布p(z),否则,它们之间存在差距。
[0081]
等式(7)也称为噪声初始化的短期mcmc,其中对于参数更新的每一步,短期mcmc从噪声分布z0~p0(z)开始。如果短期mcmc是由上一次迭代得到的推断结果初始化的,则称为持久短期mcmc。
[0082]
尽管等式(8)中的短期mcmc推断效率很高,但它可能不会收敛到真实的后验分布p
θ
(z|i)。有些人将短期mcmc视为近似推断模型,并通过变分推断优化步长s,其中步长s经由网格搜索或梯度下降进行优化,使得短期mcmc qs(z|i)(这里s是学习参数)可能最好地近似后验分布p
θ
(z|i)。
[0083]
e.具有ot校正的mcmc推断的实施例
[0084]
在一个或多个实施例中,使用最优传输来校正短期推断结果的偏差。在一个或多个实施例中,不是最小化短期推断模型和真实后验之间的差,即kl(q
θ
(zk|i)|p
θ
(z|i)),而是使用ot来最小化由短期朗之万动力学推断的潜在变量的边缘分布q
θ
(zk)与先验分布p0(z)之间的传输成本。
[0085]
1.有偏差的短期mcmc的ot校正实施例
[0086]
在一个或多个实施例中,为了学习从潜在向量z生成观察到的图像i的自上而下潜在变量模型i=g
θ
(z),迭代以下三个步骤。
[0087]
(1)推断步骤:首先通过k步短期mcmc对每个观察到的图像ii推断潜在向量,即然后对于所有观察后的数据{ii}获得推断的潜在向量的群其中其中
[0088]
(2)校正步骤:ot用于将移动到所需的先验分布,以缩小它们之间由于非收敛推断而产生的差距。图5描绘了根据本公开的实施例的通过最优传输校正的输出分布的变化。如图5所示,ot以最小移动成本将有偏差的群510重塑到先验分布520。使用更正确的推断的潜在向量,后续的参数更新可以更准确。
[0089]
(3)学习步骤:给定观察到的图像及其相应的推断的潜在向量,θ由等式(3)和等式(4)更新。随着θ训练得越来越好,推断引擎q
θ
(zk)变得更加准确,ot所做的校正也变得更小。使用ot校正的公开策略的图示在前述图3中呈现。如图3所示,将所公开的使用具有ot校正304的短期mcmc的框架与使用传统的长期mcmc推断302的框架进行比较。
[0090]
在实践中,可以在推断步骤中使用噪声初始化的短期mcmc或持久的短期mcmc。在一个或多个实验中,选择后一个是为了快速收敛。对于校正阶段,从到{zi}学习一对一的ot映射,{zi}是从先验高斯分布中采样的群,与大小相同。在每次迭代中计算最优传输在实践中是耗时且不必要的。在一个或多个实施例中,为了使整个流水线更高效,可以在每l次迭代之后执行校正步骤。得到双射ot映射后,不直接通过配对数据来更新模型,而是使用ot结果和旧的结果的混合校正以避免由于的突然变化而导致学习不稳定,即
[0091][0092]
在等式(10)中,α∈[0,1]是超参数,用于控制用于校正的ot结果的百分比。因此,可以获得校正后的配对数据以更新模型参数θ。需要说明的是,当α=0时,可以认为本公开模型实施例退化为传统的abp模型。如果α设置为1,则短期输出完全用ot结果进行校正。适度的0<α<1通常有助于将边缘分布q
θ
(zk)逐渐拉到先验分布p(z)以确保平滑校正。方法1使用用于短期mcmc推断和ot校正的详细过程总结了学习策略实施例的整个流水线,分别在图6和图7中示出。
[0093]
方法1:具有ot校正的短期mcmc推断实施例
[0094]
[0095][0096]
方法2:具有ot校正的短期mcmc推断实施例
[0097][0098]
图6描绘了根据本公开的实施例的短期mcmc推断的过程。在步骤605中,从高斯先验分布中随机采样初始分布p0。在步骤610中,使用有限步朗之万动力学实施短期mcmc推断,以从从潜在变量模型g0生成的观察示例推断多个潜在向量。在每个朗之万步骤中,短期mcmc可以由随机噪声或在前一步中获得的推断结果初始化。
[0099]
图7描绘了根据本公开的实施例的最优传输校正的过程。在步骤705中,给定多个
推断的潜在向量和从先验随机采样的多个样本(例如,高斯分布n(0,id)),使用梯度下降优化迭代地优化包括将多个推断的潜在向量映射到多个样本的多个传输路径的双射ot映射。在步骤710中,当满足停止条件时,建立双射ot映躬以获得包括多个映射的潜在向量的ot结果停止条件可以是满足迭代次数或优化梯度变得小于预定阈值。在步骤715中,通过多个映射的潜在向量和多个推断的潜在向量的混合,获得多个ot校正的潜在向量,其中映射的潜在向量的百分比由0和1之间的超参数控制。之后,多个ot校正的潜在向量可用于在学习步骤中更新深度潜在变量的模型参数。
[0100]
需要注意的是,虽然方法2示出了根据adam方法的更新过程,其中β1=0.9且β2=0.5,参数β1(一阶矩估计的指数衰减率)和β2(二阶矩估计的指数衰减率)可以是其他值并且可以使用其他方法。这样的变化仍应在本专利文件的范围内。
[0101]
2.最优传输
[0102]
给定从q
θ
(zk)采样的潜在代码,即以及来自先验的随机生成的样本从到的一对一映射是通过最优传输计算的。特别地,在一个或多个实施例中,成本函数被设置为平方欧几里得距离因为它具有很好的几何意义,然后解决以下分配问题:
[0103][0104]
其中
[0105]
根据线性规划理论,π的每一行/列中只有一个非零元素。实际上,所有非零元素都应该等于1/n。因此,从到{zj}的映射可以定义为:如果π
ij
≠0,则当n很大时,直接用线性规划解决上述问题将是有问题的,因为计算复杂度非常高o(n
2.5
)。类似地,分配问题的经典匈牙利算法由于高计算复杂度o(n3),不能用于解决这个问题。用近似的ot求解器,例如sinkhorn算法也无法解决上述问题,因为这些求解器往往会给出密集的传输计划,因此不可能恢复ot映射。此外,近似算法不适用于n>20,000的大规模问题。因此,使用等式(11)的对偶问题。在一个或多个实施例中,半离散ot的原始对偶公式可以扩展到离散设置中的以下最小化问题:
[0106][0107]
上述问题是凸性的,因为它是n个超平面之和的最大值。因此,它可以通过梯度下降优化来解决。梯度由计算,其中#jj是jj中的元素个数。假设h
*
是e(h)的最优解,则h=h
*
(c,c,...,c)
t
也是最优解。为了省略移动,定义为利用梯度信息,能量e(h)可以通过adam梯度下降算法最小化。
[0108]
由于等式(12)是分配问题的对偶,在最优解h
*
下,很容易通过
从到{zj}重构一对一ot映射。在优化过程中,当梯度的范数小于预定阈值ε时,过程停止。理想情况下,如果ε=0,则映射t变为单射和满射,每个jj只包括一个元素,即对应的i。在这种情况下,ot映射t是明确定义的。实际上,ε通常设置为ε>0,因此t既不是单射也不是满射。在这种情况下,对于一些zjs,可能有一个或多个相应的而对于其他一些zjs,相应的可能不存在。为了消除歧义并重建一对一映射,有必要处理将为空或包括一个或多个元素的集合jj。因此近似的ot映射由下面给出:(1)如果jj中只有一个元素,即i,则(2)当jj包括多于一个元素时,随机选择i∈jj,并放弃其他元素,然后定义(3)将放弃的和与空jj对应的zjs分别从的域和范围中去除。以这种方式,可以建立新的近似于ot映射t的内射和满射映射
[0109]
应当注意,在所公开的ot方法的实施例中,先验分布不限于高斯分布。只要易于采样,实际上可以选择任何先验分布。此外,解决等式(12)中非光滑对偶问题的计算复杂度为在训练具有大量参数的复杂神经网络的背景下,用于优化ot问题的时间可以忽略不计。最后,从中移除的样本数不应大于nε。在一个或多个实验中,ε通常设置为ε=0.05。使用如此小的ε,可以获得ot映射的良好近似。
[0110]
e.实验结果
[0111]
需要说明的是,这些实验和结果是为了说明而提供的,是在特定条件下使用一个或多个特定实施例进行的;因此,这些实验及其结果均不应用于限制本专利文件的公开范围。
[0112]
在实验中,所公开模型的实施例就其是否可以(1)成功地校正由短期朗之万动力学推断的潜在向量的边缘分布q
θ
(zk),(2)学习从先验分布合成视觉上真实的图像的富有表现力的生成器,以及(3)成功执行异常检测进行测试。为了示出公开方法的性能,对各种数据集进行了实验。有关生成器架构的设计、模型超参数的选择以及每个数据集的优化方法的详细信息,请参见补充材料。此外,为了研究不同超参数的影响,主要使用数据集a,因为它具有简单性和代表性。为了量化模型的性能,采用均方误差(mse)和frechet起始距离(fid)得分来衡量重建和生成的图像的质量。fid分数是用于评估生成图像质量的指标。
[0113]
数据集:在一个或多个实验中使用各种数据集进行训练和/或测试。为了快速收敛,随机选择了一些样本数据,诸如图像数据。所有训练图像都被调整大小并缩放到[-1,1]的范围。
[0114]
模型架构:模型的架构如表1中所示,其中数据集a、数据集b和数据集c的潜在维度数分别设置为30、64、64。
[0115]
表格1.用于不同数据集的生成器的架构
[0116][0117]
优化:生成器的参数用xavier常态初始化,然后用adam优化器优化,其中β1=0.5和β2=0.99。对于所有实验,批量大小设置为2,000。在方法1中,l和k均设置为50。对于数据集a,超参数α设置为0.5,对于数据集b和数据集c,超参数α设置为0.3。设置数据集a、b和c的步长s分别为0.3、3.0、3.0。对于所有模型,∑也设置为σ=0.3。
[0118]
计算成本:由于短期mcmc和最优传输的参与,需要考虑整个流水线的运行时间。这里以包括多个大小为32
×
32
×
3的图像的数据集b为例。所公开模型的实施例在两个nvidia titanx gpu上进行训练。对于每次迭代,k=30的推断步骤大约需要124分钟,最优传输的校正步骤大约需要10分钟,l2=2的学习步骤大约需要5分钟。一般地,模型需要运行10~15次迭代,大约会消耗1天。
[0119]
1.潜在空间分析
[0120]
为了验证所提出的方法确实校正了潜在变量的短期边缘分布q
θ
(zk),数据集a的类别“0”和“1”被拾取。从这些类中,为了更好的可视化,将潜在空间维度设置为2来学习本公开模型的实施例。图8示出了公开模型的实施例在不同迭代中q
θ
(zk)的演变。在图中,迭代表示ot校正的次数。从图8可以清楚地看出,q
θ
(zk)由于ot校正逐渐向先验分布移动,并最终与
之匹配。图9分别示出了由vae模型、abp模型和所公开模型的实施例(示出为“当前”)推断的潜在向量的比较。由vae和abp模型推断出的潜在向量的分布与先验(高斯)分布相距很远,而公开的模型的边缘分布q
θ
(zk)看起来更接近于高斯分布。
[0121]
2.图像建模
[0122]
在一个或多个实验中,评估了重建和生成的图像的质量。利用良好学习的模型,q
θ
(zk)的边缘分布应该很好地匹配先验分布。在这种情况下,生成器将是从先验高斯分布到图像分布的概率变换,并且可以通过i=g
θ
(z)与从先验分布采样的潜在向量z来合成高质量的图像。此外,该模型可能对重建有用。在下文中,将公开模型的实施例与vae、其变体两阶段vae(2svae)和正则化自动编码器(regularized autoencoder,rae)进行比较。还与abp模型及其变体短期推断(sri)进行了比较,其生成器具有多层潜在变量。最后一个用于比较的模型是基于潜在空间能量的模型(1atent space energy-based model,lebm),它使用基于能量的短期mcmc来推断每个观察到的图像的潜在变量。
[0123]
给定具有从给定先验分布中采样的潜在向量的重构和生成图像,很明显生成的图像是真实的并且与训练数据集中的真实图像具有可比性。在表2中,mse用于测试重建图像的质量,并使用fid分数来量化生成图像的质量。从表中发现,公开方法的实施例(示出为“当前”列)在重建和生成任务中优于其他方法。
[0124]
表2.不同数据集的比较结果。mse和fid(越小越好)分别用于测试重建和生成图像的质量。
[0125][0126]
表3.auprc分数(越大越好)用于数据集a上的无监督异常检测。公开模型的实施例的结果针对方差在10次实验中取平均值。
[0127][0128]
3.异常检测
[0129]
异常检测是可以帮助评估公开模型的实施例的另一任务。使用从正常数据中学习良好的模型,可以通过以下检测异常数据:首先通过短期朗之万动力学从条件分布q
θ
(zk|i)中采样给定测试图像i的潜在代码z,再计算等式(5)中联合概率log p
θ
(i,z)的对数。基于此理论,正常图像的联合概率应该高,异常图像的联合概率应该低。
[0130]
在接下来的实验中,数据集a中的一个类被视为异常类,其他类则正常。该模型仅使用正常数据进行训练,然后使用正常和异常数据进行测试。为了评估性能,使用log p
θ
(i,z)作为决策函数来计算精确召回率曲线下面积(area under precision-recall curve,auprc)。在测试阶段,每个实验运行10次以获得均值和方差。在表3中,将公开方法的实施例(示出为“当前”列)与此任务中的相关模型进行了比较,包括vae、meg、bigan-σ、lebm和abp模型,其可以视为特殊情况无需ot校准。从表中可以看出,测试的方法实施例可以得到比其他方法更好的结果。
[0131]
4.潜在维度数量的影响
[0132]
本小节示出了相同架构下潜在空间的维数的影响。数据集b用于潜在空间的不同数量的维度,例如,分别为20、40和64。如表4所示,随着潜在维度的增加,在重建和生成方面都可以获得更好的结果。
[0133]
表4.提出的方法在数据集b上的性能,架构相同但潜在维度数不同。(对于mse和fid,越小越好。)
[0134][0135]
5.消融研究
[0136]
本小节探讨了提出的模型在等式(10)中引入的参数α的不同值、朗之万动力学的不同步长(等式(7)的s)、不同的朗之万步数(等式(7)中的k)以及学习步骤的不同迭代次数的情况下的性能,学习步骤旨在使用配对数据最大化等式(5)中的联合概率。
[0137]
α的影响。首先,研究等式(10)中α的影响,结果如图10a、10b和10c中所示。图10a示出了从到{zj}的ot成本,它作为q
θ
(zk)通过短期朗之万动力学与先验分布p(z)之间的距离。很明显,较大的α可以更快地将边缘分布q
θ
(zk)拉向先验分布。图10b建议:为了获得更小的mse损失,最好选择更小的α。如图10c中所示,使用中等值α获得最优fid,即α=0.5。因此,为了平衡ot成本、mse损失和fid,在以下实验中设置α=0.5。从曲线中还发现,随着算法的进行,边缘分布q
θ
(zk)越来越接近先验分布p0(z),重建的图像和生成的图像的质量也提高了。
[0138]
朗之万步长的影响。表5示出了不同朗之万步长(等式(7)中的s)下公开模型的实施例的性能。表中,“之前”表示模型在ot校正前使用,“之后”表示在ot校正后使用经训练的模型。s小,mse损失就会很小,但是fid比较大,这意味着生成的图像质量不是很好。当s很大时,例如最后一列中s=6e-2
,mse损失和fid都很大,这意味着甚至无法获得高质量的重建图像。在这种情况下,模型实际上并不能很好地收敛。只有通过适当的朗之万步长(在此实验中,s=3e-2
),才能在mse和fid之间获得良好的平衡,以满足重建和生成的结果。
[0139]
表5.朗之万动力学步长的影响。
[0140][0141]
朗之万步数的影响。等式(7)中的朗之万步数k是影响所提出的方法的性能的另一个关键因素。理论上,较大的k将导致更收敛的mcmc推断,从而有助于获得更准确的潜在变量。为了证明这一点,将k分别设为k=30,50,100,且其他参数保持不变。结果如表6中所示。事实上,较大的k会导致更好的结果。但是,较大的k也会线性增加整个流水线的运行时间。因此,为了在运行时间和性能之间取得良好的平衡,需要为不同的数据集选择合适的k。
[0142]
表6.朗之万步数k的影响。
[0143][0144]
学习步骤内迭代次数的影响。在方法1中,用l2表示的梯度上升的几次迭代实际上是在学习步骤中运行的,以通过配对数据最大化等式(5)中的联合概率。结果如表7中所示。从表中可以看出,通过增加l2,图像重建和生成可以获得更好的性能。
[0145]
表7.学习迭代次数的影响。
[0146][0147]
g.一些结论
[0148]
本文件公开了在训练深度潜在变量模型时使用ot来校正基于mcmc的短期推断的偏差的实施例。特别地,在一个或多个实施例中,通过此分布与先验分布之间的ot映射逐步校正短期朗之万动力学的潜在变量的边缘分布。这样,推断出的潜在向量的分布最终可以收敛到先验分布,从而提高后续参数学习的准确性。实验结果表明,所公开的训练方法实施例在图像重建、图像生成和异常检测等任务上的表现优于abp和vae模型。
[0149]
h.计算系统实施例
[0150]
在一个或多个实施例中,本专利文件的方面可以针对、可以包括或者可以在一个或多个信息处理系统(或计算系统)上实现。信息处理系统/计算系统可以包括任何工具或工具集合,可操作用于计算、推算、确定、分类、处理、发送、接收、检索、发起、路由、切换、存储、显示、通信、显示、检测、记录、复制、处理或利用任何形式的信息、情报或数据。例如,计算系统可以是或可以包括个人计算机(例如笔记本电脑)、平板计算机、移动设备(例如个人数字助理(pda)、智能手机、平板手机、平板电脑等)、智能手表、服务器(例如,刀片式服务器或机架式服务器)、网络存储设备、相机或任何其他合适的设备,并且可能在大小、形状、性能、功能和价格方面有所不同。计算系统可以包括随机存取存储器(ram)、诸如中央处理单元(cpu)或硬件或软件控制逻辑的一个或多个处理资源、只读存储器(rom)和/或其他类型的存储器。计算系统的附加组件可以包括一个或多个磁盘驱动器(例如,硬盘驱动器,固态驱动器,或两者兼有)、一个或多个用于与外部设备以及各种输入和输出(i/o)设备通信的网络端口,诸如键盘、鼠标、触摸屏、触控笔、麦克风、相机、触控板、显示器等。计算系统还可以包括一个或多个总线,该总线可操作以在各种硬件组件之间传输通信。
[0151]
图11描绘了根据本公开的实施例的信息处理系统(或计算系统)的简化框图。应当理解,系统1100所示的功能可以操作以支持计算系统的各种实施例——尽管应当理解计算系统可以被不同地配置并且包括不同的组件,包括比图11所描绘的更少或更多的组件。
[0152]
如图11中所示,计算系统1100包括一个或多个cpu 1101,其提供计算资源并控制计算机。cpu 1101可以用微处理器等来实现,并且还可以包括一个或多个图形处理单元(gpu)1102和/或用于数学计算的浮点协处理器。在一个或多个实施例中,一个或多个gpu 1102可以并入显示器控制器1109内,诸如一个或多个图形卡的一部分。系统1100还可以包括系统存储器1119,其可以包括ram、rom或两者。
[0153]
还可以提供多个控制器和外围设备,如图11所示。输入控制器1103表示到各种输入设备1104的接口。计算系统1100还可以包括存储控制器1107,用于与一个或多个存储设备1108接口,每个存储设备包括诸如磁带或磁盘之类的存储介质,或可用于记录操作系统、实用程序和应用的指令程序的光学介质,其可包括实现本公开的各个方面的程序的实施例。根据本公开存储设备1108还可用于存储已处理数据或要处理的数据。系统1100还可以包括显示器控制器1109,用于向显示器设备1111提供接口,显示器设备1111可以是阴极射
线管(crt)显示器、薄膜晶体管(tft)显示器、有机发光二极管、电致发光面板、等离子面板或任何其他类型的显示器。计算系统1100还可包括用于一个或多个外围设备1106的一个或多个外围控制器或接口1105。外围设备的示例可包括一个或多个打印机、扫描仪、输入设备、输出设备、传感器等。通信控制器1114可以与一个或多个通信设备1115接口,这使得系统1100能够通过各种网络中的任何一个连接到远程设备,包括互联网、云资源(例如,以太网云、以太网光纤通道(fcoe)/数据中心桥接(dcb)云等)、局域网(lan)、广域网(wan)、存储区域网络(san)或通过任何合适的电磁载波信号,包括红外信号。如所描绘的实施例中所示,计算系统1100包括一个或多个风扇或风扇托盘1118和一个或多个冷却子系统控制器或多个控制器1117,其监控系统1100(或其组件)的热温度并操作风扇/风扇托盘1118以帮助调节温度。
[0154]
在所示系统中,所有主要系统组件都可以连接到总线516,该总线可以表示多于一个的物理总线。然而,各种系统组件可能彼此物理接近,也可能不接近。例如,输入数据和/或输出数据可以从一个物理位置远程传输到另一个物理位置。此外,可以通过网络从远程位置(例如,服务器)访问实现本公开的各个方面的程序。这样的数据和/或程序可以通过多种机器可读介质中的任何一种来传递,包括例如:诸如硬盘、软盘和磁带之类的磁介质;诸如光盘(cd)和全息设备的光学介质;磁光介质;以及专门配置为存储或存储和执行程序代码的硬件设备,诸如专用集成电路(asic)、可编程逻辑器件(pld)、闪存设备、其他非易失性存储器(nvm)设备(诸如基于3d xpoint的设备)以及rom和ram设备。
[0155]
本公开的各方面可以被编码在一个或多个非暂时性计算机可读介质上,具有用于一个或多个处理器或处理单元的指令以导致执行步骤。应当注意,一个或多个非暂时性计算机可读介质应当包括易失性和/或非易失性存储器。应注意,替代实现是可能的,包括硬件实现或软件/硬件实现。硬件实现的功能可以使用asic、可编程阵列、数字信号处理电路等来实现。因此,任何权利要求中的“手段”术语旨在涵盖软件和硬件实现。类似地,本文使用的术语“计算机可读介质”包括其上包含有指令程序的软件和/或硬件,或其组合。考虑到这些实现替代方案,应当理解,附图和随附的描述提供了本领域技术人员编写程序代码(即软件)和/或组装电路(即硬件)所需的功能信息以执行所需的处理。
[0156]
应当注意,本公开的实施例进一步可以涉及具有在其上具有用于执行各种计算机实现的操作的计算机代码的非暂时性有形计算机可读介质的计算机产品。介质和计算机代码可以是为了本公开的目的而专门设计和构造的那些,或者它们可以是相关领域的技术人员已知或可获得的类型。有形计算机可读介质的示例包括例如:诸如硬盘、软盘和磁带之类的磁性介质;诸如cd和全息设备的光学介质;磁光介质;以及专门配置为存储或存储和执行程序代码的硬件设备,诸如asic、pld、闪存设备、其他非易失性存储器(nvm)设备(诸如基于3dxpoint的设备)以及rom和ram设备。计算机代码的示例包括诸如由编译器产生的机器代码,以及包含由计算机使用解释器执行的更高级别代码的文件。本公开的实施例可以全部或部分地实现为机器可执行指令,这些指令可以在由处理设备执行的程序模块中。程序模块的示例包括库、程序、例程、对象、组件和数据结构。在分布式计算环境中,程序模块可以物理地位于本地、远程或两者的设置中。
[0157]
本领域技术人员将认识到没有计算系统或编程语言对于本公开的实践是关键的。本领域技术人员还将认识到,上述多个元件可以在物理上和/或功能上分离成模块和/或子
模块或组合在一起。
[0158]
本领域技术人员将理解,前述示例和实施例是示例性的并且不限制本公开的范围。本领域技术人员在阅读说明书和研究附图后,对本领域技术人员显而易见的所有排列、增强、等价物、组合和改进旨在包括在本公开的真实精神和范围内。还应注意,任何权利要求的元件可以不同地布置,包括具有多个依赖关系、配置和组合。
再多了解一些

本文用于创业者技术爱好者查询,仅供学习研究,如用于商业用途,请联系技术所有人。

发表评论 共有条评论
用户名: 密码:
验证码: 匿名发表

相关文献