一种残膜回收机防缠绕挑膜装置的制 一种秧草收获机用电力驱动行走机构

通过引入自监督复合任务训练生成对抗网络生成高质量图像的方法及装置

2022-05-27 01:40:35 来源:中国专利 TAG:


1.本发明属于计算机图像生成技术领域,尤其涉及一种通过引入自监督复合任务训练生成对抗网络生成高质量图像的方法及装置。


背景技术:

2.真实图像数据集是计算机视觉领域中用于训练网络的必不可少的工具,大量的真实图像能够帮助网络很好地学习到有用的特征表示,进而在后续的任务中发挥良好的性能。然而在真实图像数据集的制作过程中,由于图像采集设备的差异,采集到的原始图像还需要进行尺寸调整、分辨率统一等对齐操作,这通常需要耗费巨大的人力成本,这是导致网络训练成本高的一个重要原因,此外,对于一些现存量低、采集难度大的样本,数据集的制作成本与难度都更高,样本量的不足也将导致网络训练效果不佳。生成图像是一种通过训练良好的生成模型产生的与训练集中的真实图像近似的图像,它可以由生成模型直接从随机噪声中映射而来,通过使用生成模型产生的大量生成图像来扩充现有数据集是解决上述问题的方法之一,生成模型通过对现有数据集中真实图像的学习,生成逼真且具有多样性的生成图像,可以极大地减少人工收集、处理数据所带来的数据集制作成本。
3.生成对抗网络是近年来大热的一种生成模型,它由goodfellow等人在文章“generative adversarial networks.”(neuralips,2014)中提出。在生成对抗网络中,生成器负责接收随机噪声并生成图像,判别器负责接收真实图像样本和生成器产生的样本,并判断所接收的样本是否为真实图像。在网络的训练过程中,生成器与判别器在相互对抗中不断得到优化。然而,生成对抗网络存在着判别器“灾难性遗忘”、训练过程不稳定的缺点,甚至会导致模式坍塌问题。目前的一个解决方向是通过引入额外的辅助任务来完成生成对抗网络的自监督学习,使判别器学习到更通用、更稳定的特征,以提升训练过程的稳定性。然而现有的辅助任务通常为单一任务,这容易导致网络学习到的特征带有较明显的任务偏向性,例如gidaris等人在文章“unsupervised representation learning by predicting image rotations.”(iclr,2018)中提出的旋转任务,通过对输入图像进行随机旋转操作,并要求网络判断出所接收的图像对应的旋转角度,这一任务虽然能够有效帮助网络学习图像的结构特征,但是图像的色彩、纹理等特征却因为对判断旋转角度这一任务的帮助不大而易被网络忽略。现有的针对生成对抗网络提出的自监督辅助任务都存在任务单一、覆盖的特征不够全面的问题,对于网络学习特征的过程的引导具有较强的偏向性,这不利于网络学习通用且稳定的特征,也会影响后续生成的图像的质量。


技术实现要素:

4.本发明的主要目的在于克服现有技术的缺点与不足,提供一种通过引入自监督复合任务训练生成对抗网络生成高质量图像的方法及装置,通过设计并引入一个合理的复合任务,引导网络学习图像中更稳定、更通用的特征,从而提升网络的训练效果,进而提高最
终生成的图像的质量。
5.为了达到上述目的,本发明采用以下技术方案:
6.本发明一方面提供了通过引入自监督复合任务训练生成对抗网络生成高质量图像的方法,包括下述步骤:
7.准备训练数据集,所述训练数据集包括原始图像数据以及拼接图像数据;所述原始图像数据用于对抗训练分支的训练过程,所述拼接图像数据用于自监督复合任务分支的训练过程;
8.设计三个子任务来组成一个复合任务,所述复合任务用于构建自监督复合任务分支并为模型的训练提供监督信息,所述三个子任务分别是旋转预测任务、位置预测任务和共有特征提取任务,所述所述旋转预测任务用于正确判断出每张拼接图像中包含的图像块对应的签,所述位置预测任务用于正确判断出每张拼接图像中包含的图像块对应的标签,所述共有特征提取任务用于首先正确判断出每个图像块属于哪一张原始图像,然后提取出同源图像块之间的共有特征;
9.搭建模型,分别构建对抗训练分支和自监督复合任务分支,所述对抗训练分支包含一个局部判别器和一个生成器,所述自监督复合任务分支包含一个带有三个输出头的分类器;
10.对搭建的模型进行训练,得到训练好的生成器网络;所述训练具体为:将原始数据集作为对抗训练分支的输入,拼接图像数据作为自监督复合任务分支的输入,对两个分支中的网络进行训练,所述训练过程中,自监督复合任务分支负责为对抗训练分支中的局部判别器提供监督信息;
11.将待处理的图像输入到训练好的生成器网络进行图像生成。
12.作为优选的技术方案,所述拼接图像数据准备具体如下:
13.对于一个批次内的原始图像,从每张图像的左上、右上、左下、右下区域裁剪出4个具有一定重叠率ω的图像块,重叠率ω等于相邻图像块之间重叠部分的边长与原图像边长的比值;
14.将得到的一个批次的图像块进行随机打乱,并进行一次旋转变换,旋转角度从集合r={0
°
,90
°
,180
°
,270
°
}中随机选择;
15.使用双线性插值法对得到的一个批次的图像块的尺寸进行调整,使图像块的边长等于原始图像边长的一半;
16.将得到的图像块以4个为一组进行拼接,得到一批与原始图像尺寸一致的拼接图像,由此,完成了拼接图像数据的制作。
17.作为优选的技术方案,所述三个子任务的具体设计如下:
18.在旋转预测任务中,对于得到的一批拼接图像,每张图像包含的4个图像块各自对应了一个旋转角度,通过这一旋转角度为每个图像块赋予一个伪标签lr,lr∈{0
°
,90
°
,180
°
,270
°
};
19.在位置预测任务中,对于得到的一批拼接图像,每张图像包含的4个图像块在其所属原始图像中各自对应了一个固定的区域位置,通过这一位置信息为每个图像块赋予一个伪标签l
l
,l
l
∈{左上,右上,左下,右下};
20.在共有特征提取任务中,对于得到的一批拼接图像,每张图像包含的4个图像块各
自对应一张原始图像,将属于同一张原始图像的4个图像块定义为同源图像块,将同源图像块之间具有较高相似度的特征定义为共有特征。
21.作为优选的技术方案,搭建模型的具体步骤为:
22.在对抗训练分支中,构建局部判别器local-d和生成器g,所述局部判别器的网络结构被分割为两个部分,并在二者中间加入了一个特征分块模块;第一部分接收原始图像作为输入并提取图像特征,特征分块模块负责将第一部分输出的图像特征处理成图像块特征,第二部分接收图像块特征作为输入并产生局部判别器的最终输出,局部判别器local-d的任务是正确判断出分块后的特征是来自真实图像还是生成图像,此分支的损失函数记为l
adv
,与原始生成对抗网络中提出的对抗损失一致,其具体表达式为:
[0023][0024]
其中,x是采样自原始数据集的真实图像,p
data(x)
是真实数据的分布,z是从先验分布中采样的随机噪声,p
z(z)
是先验分布,d为局部判别器,g为生成器;
[0025]
在自监督复合任务分支,使用与对抗训练分支中局部判别器一致的网络架构来搭建分类器c,分类器网络同样被分为两部分,其中第一部分的网络架构与local-d的第一部分相同,并且二者共享网络权重,第二部分包含了分类器的三个输出头部,其中两个头部均由一个全连接层构成,分别负责输出旋转预测任务和位置预测任务的结果,第三个头部由包含一个隐藏层的多层感知机构成,负责输出共有特征提取任务的结果,此分支的总损失函数记为l
ct

[0026]
作为优选的技术方案,对于自监督复合任务分支采用多个损失函数进行组合优化,该自监督复合任务分支的总损失函数l
ct
定义为:
[0027]
l
ct
=l
rot
l
loc
l
cfe
[0028]
其中,l
rot
、l
loc
、l
cfe
分别代表旋转预测任务损失、位置预测任务损失、共有特征提取任务损失;
[0029]
记拼接图像中各图像块的真实旋转标签为l
r_gt
,真实位置标签为l
l_gt
,分类器为图像块预测的旋转标签为lr,预测的位置标签为l
l
,共享特征提取任务中多层感知机输出的一组向量记为v1,v2,

,v
{k.k=n
×
4}
,使用二元交叉熵计算旋转预测与位置预测的损失,使用余弦相似度计算不同块特征间的相似度,三项任务损失的计算公式如下:
[0030]
l
rot
=crossentropy(lr,l
r_gt
)
[0031]
l
loc
=crossentropy(l
l
,l
l_gt
)
[0032][0033][0034]
其中,τ为温度系数,n为训练批大小,ci表示第i个的图像块的同源图像块对应下标的集合,i为指示函数,当满足判断条件时函数值为1,否则为0。
[0035]
作为优选的技术方案,模型训练的具体步骤为:
[0036]
在对抗训练分支中,生成器g与局部判别器local-d交替迭代训练,局部判别器的输入是采样自原始图像数据集的一批图像,其训练目标是正确判断出输入图像中某一区域
内的图像的真实性;生成器的输入是随机噪声,其训练目标是输出尽量真实的、能够骗过局部判别器的生成图像;
[0037]
在自监督复合任务分支中,分类器的输入是一批拼接图像,三个输出头部分别输出三个子任务的结果,该分支通过总损失函数l
ct
对网络进行训练;
[0038]
对抗训练分支和自监督复合任务分支同时训练,两个分支在训练过程中通过共享局部判别器local-d与三头分类器c的第一部分的网络权重来建立联系,模型的总损失函数定义为:
[0039][0040]
其中,l
adv
(g,d)为对抗训练损失,l
ct
(c)为自监督复合任务损失,在模型训练过程中,局部判别器local-d与生成器g交替更新,与三头分类器c同时更新。
[0041]
作为优选的技术方案,所述将待处理的图像输入到训练好的生成器网络进行图像生成,具体为:
[0042]
将随机噪声输入到训练好的生成器网络中,经过前向传播就可以得到与训练集图像近似的高质量生成图像。
[0043]
本发明另一方面提供了通过引入自监督复合任务训练生成对抗网络生成高质量图像的系统,应用于所述的通过引入自监督复合任务训练生成对抗网络生成高质量图像的方法,包括数据集模块、复合任务模块、模型搭建模块、模型训练模块以及图像生成模块;
[0044]
所述数据集模块,用于准备训练数据集,所述训练数据集包括原始图像数据以及拼接图像数据;所述原始图像数据用于对抗训练分支的训练过程,所述拼接图像数据用于自监督复合任务分支的训练过程;
[0045]
所述复合任务模块,用于设计三个子任务来组成一个复合任务,所述复合任务用于构建自监督复合任务分支并为模型的训练提供监督信息,所述三个子任务分别是旋转预测任务、位置预测任务和共有特征提取任务,所述所述旋转预测任务用于正确判断出每张拼接图像中包含的图像块对应的签,所述位置预测任务用于正确判断出每张拼接图像中包含的图像块对应的标签,所述共有特征提取任务用于首先正确判断出每个图像块属于哪一张原始图像,然后提取出同源图像块之间的共有特征;
[0046]
所述模型搭建模块,用于分别构建对抗训练分支和自监督复合任务分支,所述对抗训练分支包含一个局部判别器和一个生成器,所述自监督复合任务分支包含一个带有三个输出头的分类器;
[0047]
所述模型训练模块,用于对搭建的模型进行训练,得到训练好的生成器网络;所述训练具体为:将原始数据集作为对抗训练分支的输入,拼接图像数据作为自监督复合任务分支的输入,对两个分支中的网络进行训练,所述训练过程中,自监督复合任务分支负责为对抗训练分支中的局部判别器提供监督信息;
[0048]
所述图像生成模块,用于将待处理的图像输入到训练好的生成器网络进行图像生成。
[0049]
本发明又一方面提供了一种电子设备,所述电子设备包括:
[0050]
至少一个处理器;以及,
[0051]
与所述至少一个处理器通信连接的存储器;其中,
[0052]
所述存储器存储有可被所述至少一个处理器执行的计算机程序指令,所述计算机
程序指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行所述的通过引入自监督复合任务训练生成对抗网络生成高质量图像的方法。
[0053]
本发明再一方面提供了一种计算机可读存储介质,存储有程序,所述程序被处理器执行时,实现所述的通过引入自监督复合任务训练生成对抗网络生成高质量图像的方法。
[0054]
本发明与现有技术相比,具有如下优点和有益效果:
[0055]
(1)、本发明针对生成对抗网络的自监督学习,提出了一种基于多层级信息的复合辅助任务,同时利用了图像内信息与图像间信息作为网络训练时的监督,稳定了训练过程,可以提高网络提取的特征的通用性,提升生成的图像的质量。
[0056]
(2)、本发明提出了局部判别器local-d,使网络在提取特征的过程中关注到更多的局部特征信息,并与自监督复合任务分支相配合,提升了整个模型自监督学习的效果。
[0057]
(3)、本发明提出的共有特征提取任务,将对比学习的思想引入到生成对抗网络的自监督学习中,在利用对比学习的优势提高了提取特征的质量的同时,将训练批大小保持在一个较小的数值,以较小的训练成本带来了较大的网络性能增益。
附图说明
[0058]
为了更清楚地说明本技术实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本技术的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
[0059]
图1为本发明实施例通过引入自监督复合任务训练生成对抗网络生成高质量图像的方法的整体流程图。
[0060]
图2为本发明实施例图像处理模块的流程图;其中包括了图像数据的三种状态,状态1为一个批次内的图像数据的初始状态,状态2为每张图像均被裁剪为4个图像块后的状态,状态3为所有图像块经过随机打乱、随机旋转、拼接后得到一批拼接图像数据的状态。
[0061]
图3为本发明实施例网络模型的整体结构图。
[0062]
图4为本发明实施例生成器网络的结构图。
[0063]
图5为本发明实施例生成器网络中反卷积模块的结构图。
[0064]
图6为本发明实施例提出的局部判别器网络的结构图。
[0065]
图7为本发明实施例局部判别器中卷积模块的结构图。
[0066]
图8为本发明实施例特征分块模块的示意图。
[0067]
图9为本发明实施例通过引入自监督复合任务训练生成对抗网络生成高质量图像的系统的方框图。
[0068]
图10为本发明实施例电子设备的结构图。
具体实施方式
[0069]
为了使本技术领域的人员更好地理解本技术方案,下面将结合本技术实施例中的附图,对本技术实施例中的技术方案进行清楚、完整地描述。显然,所描述的实施例仅仅是本技术一部分实施例,而不是全部的实施例。基于本技术中的实施例,本领域技术人员在没
有做出创造性劳动前提下所获得的所有其他实施例,都属于本技术保护的范围。
[0070]
在本技术中提及“实施例”意味着,结合实施例描述的特定特征、结构或特性可以包含在本技术的至少一个实施例中。在说明书中的各个位置出现该短语并不一定均是指相同的实施例,也不是与其它实施例互斥的独立的或备选的实施例。本领域技术人员显式地和隐式地理解的是,本技术所描述的实施例可以与其它实施例相结合。
[0071]
请参阅图1,本实施例提供了一种通过引入自监督复合任务训练生成对抗网络生成高质量图像的方法,包括下述步骤:
[0072]
s1、准备训练数据集,包括原始图像数据集和拼接图像数据集;
[0073]
进一步的,原始图像数据集用于作为对抗训练分支的输入,通过网络搜索引擎获取公开的cifar-10、celeba、imagenet32
×
32、stl-10等数据集作为本方法的原始图像数据集。
[0074]
进一步的,所述拼接图像数据集用于作为自监督复合任务分支的输入,通过对原始图像数据集进行一定的处理后得到,处理流程请参阅图2。首先,对于一个批次内的原始图像,从每张图像的左上、右上、左下、右下区域裁剪出4个具有一定边长重叠率ω=0.4的图像块,然后对它们进行随机打乱,并进行一次旋转变换,旋转角度从集合r={0
°
,90
°
,180
°
,270
°
}中随机选择,使用双线性插值法调整图像块的尺寸,使其边长等于原始图像边长的一半,最后将这些图像块以4个为一组进行拼接,得到一批与原始图像尺寸一致的拼接图像,即得到了所需的拼接图像数据集。
[0075]
s2、设计复合任务,所述复合任务用于为网络的训练过程构建有力的监督信号;
[0076]
进一步的,为了能够充分利用图像的信息,本实施例设计了一个包含三个子任务的复合任务,三个子任务分别为旋转预测任务、位置预测任务、共有特征提取任务,其中前两个任务是基于图像内信息进行设计的,第三个任务是基于图像间信息设计的。
[0077]
更进一步的,所述旋转预测任务利用图像块的旋转信息构造伪标签lr,lr∈{0
°
,90
°
,180
°
,270
°
},此任务要求网络正确判断出拼接图像中包含的4个图像块对应的旋转标签。网络只有较好地理解了图像块中的结构信息才能顺利完成此任务。
[0078]
更进一步的,位置预测任务利用图像块在原始图像中的位置信息构造伪标签l
l
,l
l
∈{左上,右上,左下,右下},此任务要求网络正确判断出拼接图像中包含的4个图像块对应的原始位置标签。具体来说,在制作拼接图像数据集时,通过在原始图像的左上、右上、左下、右下四个区域进行裁剪得到图像块,因此每个图像块都对应了一个在原始图像中的位置标签,网络只有较好地理解图像块以及原始图像的结构特征才能正确预测图像块的位置标签。
[0079]
更进一步的,所述共有特征提取任务对不同图像间的信息进行利用,要求网络提取出同源图像块之间的共有特征。定义从同一张原始图像裁剪出的4个图像块为同源图像块,它们之间具有高相似度的特征为共有特征,此任务的原理在于:4个同源图像块共同组成了一张语义完整的原始图像,且它们之间存在一定的重叠区域,因此同源图像块之间是存在一些共有特征的,这些共有特征能够帮助网络正确区分出同源图像块与非同源图像块,因此在完成此任务的过程中,网络对于共有特征的提取能力会不断得到增强,这有助于网络学习到图像中具有代表性的特征。
[0080]
s3、搭建网络模型;
[0081]
请参阅图3,整个网络模型构建过程包括四部分:使用resnet-50网络实现生成器,使用resnet-50网络实现局部判别器,构建三头分类器,设计各分支的损失函数以及网络的总损失函数。
[0082]
进一步的,搭建网络模型的具体步骤如下:
[0083]
s31、构建生成器网络:
[0084]
请参阅图4,构建的生成器网络的结构;对于输入的随机噪声,首先通过一个输出通道数为4096的全连接层进行学习,并调整所得向量的大小为4
×4×
256;然后经过三个连续的反卷积模块;之后进行一次批归一化、relu激活和3
×
3卷积(输出通道数为3,步长为1),最后经过sigmoid激活后,得到最终输出,大小为32
×
32
×
3。
[0085]
反卷积模块的结构请参阅图5,其包含的具体操作为:对输入x进行批归一化、relu激活、上采样、3
×
3卷积、relu激活、3
×
3卷积,得到x1;对输入x进行进行上采样、1
×
1卷积,得到x2,最终反卷积模块的输出为x1 x2。反卷积模块中涉及的卷积操作的输出通道数均为256,步长均为1。
[0086]
s32、构建局部判别器网络:
[0087]
请参阅图6提出的局部判别器网络的结构;所述局部判别器网络被分为两个部分,第一部分由多个卷积模块和relu激活操作构成,负责输出中间层特征;第二部分由特征分块模块和全连接层构成,负责处理中间层特征并产生最终的输出。输入图像的初始大小为32
×
32
×
3,首先经过第一个卷积模块,其中包含如下步骤:对输入x进行3
×
3卷积、relu激活、3
×
3卷积、1
×
1卷积(步长为2),得到x1,对输入x进行1
×
1卷积(步长为2)、1
×
1卷积(步长为1),得到x2,输出为x1 x2;之后经过三个连续的卷积模块;对卷积模块的输出进行一次relu激活后,将所得中间层特征输入到特征分块模块中进行处理,并输出分块特征;最后将分块特征输入到一个输出通道数为1的全连接层,得到最终输出结果。
[0088]
进一步的,三个连续的卷积模块的结构请参阅图7,具体包括如下步骤:对输入x进行重复两次的relu激活和3
×
3卷积,一次1
×
1卷积(步长为2),得到x1,对输入x进行1
×
1卷积(步长为1),1
×
1卷积(步长为2),得到x2,最终卷积模块的输出为x1 x2。卷积模块中涉及的所有卷积操作的输出通道数均为128,所有3
×
3卷积的步长均为1。需要注意的是,在后两个卷积模块中,舍去了步长为2的1
×
1卷积操作,因此最终经过relu激活后得到的中间层特征的大小为8
×8×
128。
[0089]
特征分块模块的结构如图8所示,对于输入的中间层特征f,首先进行平均池化(平均池化核大小为2
×
2,步长为2),然后对得到的特征在长宽维度进行2
×
2分块操作,得到分块特征f
p
,其大小为2
×2×
128,数量是输入的中间层特征f数量的四倍。
[0090]
s33、构建三头分类器网络:
[0091]
三头分类器网络的结构同样被分为两个部分,其中第一部分的结构与局部判别器的第一部分相同,并共享权重,同步更新参数,这一部分主要负责对输入的拼接图像提取中间层特征;第二部分包含了特征分块模块和三个输出头部。特征分块模块的结构请参阅图8,它负责将中间层特征处理成分块特征,并输出给三个头部。在三个输出头部中,有两个头部均由一个简单的全连接层构成,输出通道数均为4,分别负责预测图像块的旋转角度和原始位置标签;第三个头部由包含一个隐藏层的多层感知机构成,负责完成共有特征提取任务,隐藏层的输出通道数与输入的分块特征的通道数一致,使用relu激活函数,最后输出层
的输出通道数为64。使用多层感知机对特征进行一次额外的非线性变换的方法最早在chen等人的文章“a simple framework for contrastive learning of visual representations.”(icml,2020)中提出,它能够帮助网络学习到质量更高的特征表示。
[0092]
s4、设计损失函数:
[0093]
本实施例所提出模型的总损失函数l
final
由两个分支的损失函数组成,即:对抗训练分支的损失函数l
adv
和自监督复合任务分支的损失函数l
ct
。模型总损失函数的表达式如下:
[0094][0095]
其中,l
adv
的计算方式与经典生成对抗网络中的损失函数一致,该损失最早由goodfellow等人在文章“generative adversarial networks.”(neuralips,2014)中提出。l
adv
的具体表达式如下:
[0096][0097]
其中,x是采样自原始数据集的真实图像,p
data(x)
是真实数据的分布,z是从先验分布中采样的随机噪声,p
z(z)
是先验分布,d为局部判别器,g为生成器。
[0098]
l
ct
由三个子任务的损失组成,即:旋转预测任务损失l
rot
、位置预测任务损失l
loc
、共有特征提取任务损失l
cfe
,其定义为:
[0099]
l
ct
=l
rot
l
loc
l
cfe
[0100]
l
rot
=crossentropy(lr,l
r_gt
)
[0101]
l
loc
=crossentropy(l
l
,l
l_gt
)
[0102][0103][0104]
其中,l
r_gt
与l
l_gt
分别表示拼接图像中各图像块的真实旋转标签与真实位置标签,lr与l
l
分别表示分类器为图像块预测的旋转标签与位置标签,crossentropy(
·
)为二元交叉熵函数,v1,v2,

,v
{k.k=n
×
4}
表示一组由多层感知机输出的向量,sim(
·
)为余弦相似度函数,i为指示函数,当满足判断条件时函数值为1,否则为0,τ为温度系数,默认取值为0.3,n为训练批大小,默认取值为64,ci表示第i个的图像块的同源图像块对应下标的集合。
[0105]
s4、对构建的模型进行训练;
[0106]
将原始图像数据集输入到生成器和局部判别器进行交替训练,同时将拼接图像数据集输入到三头分类器中进行训练,使用自适应学习率的adam算法进行损失的优化,训练批大小为64,共训练30万次。
[0107]
进一步的,在训练过程中,对于对抗训练分支,将随机噪声输入到生成器中,得到生成图像,将生成图像输入到判别器中时,计算对抗训练损失并反向传播梯度,生成器调整参数进行优化,使其趋向于生成与真实图像更接近的生成图像;当将从原始数据集中采样的一批次真实图像输入到局部判别器中时,计算对抗训练损失并反向传播梯度,判别器调整参数进行优化,使其趋向于提高对生成图像与真实图像的区分能力;对于自监督复合任
务分支,将一批次拼接图像输入到三头分类器中时,分类器将同时进行复合任务中的三个子任务,计算复合任务的总损失并反向传播梯度,分类器调整参数进行优化,使其趋向于提取出更全面、更通用的图像特征以更好地完成复合任务,分类器与局部判别器参数共享,因此在训练过程中与局部判别器同时更新,同时也使局部判别器获得复合任务带来的特征提取能力。在模型的训练过程中,生成器与局部判别器形成一种互相对抗的关系,即生成器致力于生成更能够欺骗判别器的逼真图像,而判别器则致力于不断提高区分真实图像与生成图像的能力,当二者之间的对抗达到平衡时,生成器的图像生成能力与判别器的特征提取能力均达到一个较高的水平。当训练完成后,保存整个模型的参数。
[0108]
s5、图像生成;
[0109]
通过移除对抗训练分支中的局部判别器以及整个自监督复合任务分支,仅使用生成器网络即可完成图像生成:输入随机噪声至生成器中,经过前向传播即可得到与原始数据集近似的质量较高的图像。
[0110]
本实施例提供的通过引入自监督复合任务训练生成对抗网络生成高质量图像的方法,对图像内信息和图像间信息同时进行利用,构建了一个包含三个子任务的复合任务,引导网络学习图像中更稳定、更通用的特征,同时还构建了局部判别器来提高网络提取图像局部信息的能力,可显著提升网络的训练效果,提高最终生成的图像的质量。
[0111]
需要说明的是,对于前述的各方法实施例,为了简便描述,将其都表述为一系列的动作组合,但是本领域技术人员应该知悉,本发明并不受所描述的动作顺序的限制,因为依据本发明,某些步骤可以采用其它顺序或者同时进行。
[0112]
基于与上述实施例中的通过引入自监督复合任务训练生成对抗网络生成高质量图像的方法相同的思想,本发明还提供了通过引入自监督复合任务训练生成对抗网络生成高质量图像的系统,该系统可用于执行上述通过引入自监督复合任务训练生成对抗网络生成高质量图像的方法。为了便于说明,通过引入自监督复合任务训练生成对抗网络生成高质量图像的系统实施例的结构示意图中,仅仅示出了与本发明实施例相关的部分,本领域技术人员可以理解,图示结构并不构成对装置的限定,可以包括比图示更多或更少的部件,或者组合某些部件,或者不同的部件布置。
[0113]
请参阅图9,在本技术的另一个实施例中,提供了一种通过引入自监督复合任务训练生成对抗网络生成高质量图像的系统100,该系统包括包括数据集模块101、复合任务模块102、模型搭建模块103、模型训练模块104以及图像生成模块105;
[0114]
所述数据集模块101,用于准备训练数据集,所述训练数据集包括原始图像数据以及拼接图像数据;所述原始图像数据用于对抗训练分支的训练过程,所述拼接图像数据用于自监督复合任务分支的训练过程;
[0115]
所述复合任务模块102,用于设计三个子任务来组成一个复合任务,所述复合任务用于构建自监督复合任务分支并为模型的训练提供监督信息,所述三个子任务分别是旋转预测任务、位置预测任务和共有特征提取任务,所述所述旋转预测任务用于正确判断出每张拼接图像中包含的图像块对应的签,所述位置预测任务用于正确判断出每张拼接图像中包含的图像块对应的标签,所述共有特征提取任务用于首先正确判断出每个图像块属于哪一张原始图像,然后提取出同源图像块之间的共有特征;
[0116]
所述模型搭建模块103,用于分别构建对抗训练分支和自监督复合任务分支,所述
对抗训练分支包含一个局部判别器和一个生成器,所述自监督复合任务分支包含一个带有三个输出头的分类器;
[0117]
所述模型训练模块104,用于对搭建的模型进行训练,得到训练好的生成器网络;所述训练具体为:将原始数据集作为对抗训练分支的输入,拼接图像数据作为自监督复合任务分支的输入,对两个分支中的网络进行训练,所述训练过程中,自监督复合任务分支负责为对抗训练分支中的局部判别器提供监督信息;
[0118]
所述图像生成模块105,用于将待处理的图像输入到训练好的生成器网络进行图像生成。
[0119]
需要说明的是,本发明的通过引入自监督复合任务训练生成对抗网络生成高质量图像的系统与本发明的通过引入自监督复合任务训练生成对抗网络生成高质量图像的方法一一对应,在上述通过引入自监督复合任务训练生成对抗网络生成高质量图像的方法的实施例阐述的技术特征及其有益效果均适用于通过引入自监督复合任务训练生成对抗网络生成高质量图像的的实施例中,具体内容可参见本发明方法实施例中的叙述,此处不再赘述,特此声明。
[0120]
此外,上述实施例的通过引入自监督复合任务训练生成对抗网络生成高质量图像的系统的实施方式中,各程序模块的逻辑划分仅是举例说明,实际应用中可以根据需要,例如出于相应硬件的配置要求或者软件的实现的便利考虑,将上述功能分配由不同的程序模块完成,即将所述通过引入自监督复合任务训练生成对抗网络生成高质量图像的系统的内部结构划分成不同的程序模块,以完成以上描述的全部或者部分功能。
[0121]
请参阅图10,在一个实施例中,提供了一种实现通过引入自监督复合任务训练生成对抗网络生成高质量图像的方法的电子设备,所述电子设备200可以包括第一处理器201、第一存储器202和总线,还可以包括存储在所述第一存储器202中并可在所述第一处理器201上运行的计算机程序,如自监督复合任务训练生成对抗网络生成高质量图像程序203。
[0122]
其中,所述第一存储器202至少包括一种类型的可读存储介质,所述可读存储介质包括闪存、移动硬盘、多媒体卡、卡型存储器(例如:sd或dx存储器等)、磁性存储器、磁盘、光盘等。所述第一存储器202在一些实施例中可以是电子设备200的内部存储单元,例如该电子设备200的移动硬盘。所述第一存储器202在另一些实施例中也可以是电子设备200的外部存储设备,例如电子设备200上配备的插接式移动硬盘、智能存储卡(smart media card,smc)、安全数字(securedigital,sd)卡、闪存卡(flash card)等。进一步地,所述第一存储器202还可以既包括电子设备200的内部存储单元也包括外部存储设备。所述第一存储器202不仅可以用于存储安装于电子设备200的应用软件及各类数据,例如自监督复合任务训练生成对抗网络生成高质量图像程序203的代码等,还可以用于暂时地存储已经输出或者将要输出的数据。
[0123]
所述第一处理器201在一些实施例中可以由集成电路组成,例如可以由单个封装的集成电路所组成,也可以是由多个相同功能或不同功能封装的集成电路所组成,包括一个或者多个中央处理器(central processing unit,cpu)、微处理器、数字处理芯片、图形处理器及各种控制芯片的组合等。所述第一处理器201是所述电子设备的控制核心(control unit),利用各种接口和线路连接整个电子设备的各个部件,通过运行或执行存
储在所述第一存储器202内的程序或者模块(例如联邦学习防御程序等),以及调用存储在所述第一存储器202内的数据,以执行电子设备200的各种功能和处理数据。
[0124]
图10仅示出了具有部件的电子设备,本领域技术人员可以理解的是,图3示出的结构并不构成对所述电子设备200的限定,可以包括比图示更少或者更多的部件,或者组合某些部件,或者不同的部件布置。
[0125]
所述电子设备200中的所述第一存储器202存储的自监督复合任务训练生成对抗网络生成高质量图像程序203是多个指令的组合,在所述第一处理器201中运行时,可以实现:
[0126]
准备训练数据集,所述训练数据集包括原始图像数据以及拼接图像数据;所述原始图像数据用于对抗训练分支的训练过程,所述拼接图像数据用于自监督复合任务分支的训练过程;
[0127]
设计三个子任务来组成一个复合任务,所述复合任务用于构建自监督复合任务分支并为模型的训练提供监督信息,所述三个子任务分别是旋转预测任务、位置预测任务和共有特征提取任务,所述所述旋转预测任务用于正确判断出每张拼接图像中包含的图像块对应的签,所述位置预测任务用于正确判断出每张拼接图像中包含的图像块对应的标签,所述共有特征提取任务用于首先正确判断出每个图像块属于哪一张原始图像,然后提取出同源图像块之间的共有特征;
[0128]
搭建模型,分别构建对抗训练分支和自监督复合任务分支,所述对抗训练分支包含一个局部判别器和一个生成器,所述自监督复合任务分支包含一个带有三个输出头的分类器;
[0129]
对搭建的模型进行训练,得到训练好的生成器网络;所述训练具体为:将原始数据集作为对抗训练分支的输入,拼接图像数据作为自监督复合任务分支的输入,对两个分支中的网络进行训练,所述训练过程中,自监督复合任务分支负责为对抗训练分支中的局部判别器提供监督信息;
[0130]
将待处理的图像输入到训练好的生成器网络进行图像生成。
[0131]
进一步地,所述电子设备200集成的模块/单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个非易失性计算机可读取存储介质中。所述计算机可读介质可以包括:能够携带所述计算机程序代码的任何实体或装置、记录介质、u盘、移动硬盘、磁碟、光盘、计算机存储器、只读存储器(rom,read-only memory)。
[0132]
本领域普通技术人员可以理解实现上述实施例方法中的全部或部分流程,是可以通过计算机程序来指令相关的硬件来完成,所述的程序可存储于一非易失性计算机可读取存储介质中,该程序在执行时,可包括如上述各方法的实施例的流程。其中,本技术所提供的各实施例中所使用的对存储器、存储、数据库或其它介质的任何引用,均可包括非易失性和/或易失性存储器。非易失性存储器可包括只读存储器(rom)、可编程rom(prom)、电可编程rom(eprom)、电可擦除可编程rom(eeprom)或闪存。易失性存储器可包括随机存取存储器(ram)或者外部高速缓冲存储器。作为说明而非局限,ram以多种形式可得,诸如静态ram(sram)、动态ram(dram)、同步dram(sdram)、双数据率sdram(ddrsdram)、增强型sdram(esdram)、同步链路(synchlink)dram(sldram)、存储器总线(rambus)直接ram(rdram)、直接存储器总线动态ram(drdram)、以及存储器总线动态ram(rdram)等。
[0133]
以上实施例的各技术特征可以进行任意的组合,为使描述简洁,未对上述实施例中的各个技术特征所有可能的组合都进行描述,然而,只要这些技术特征的组合不存在矛盾,都应当认为是本说明书记载的范围。
[0134]
上述实施例为本发明较佳的实施方式,但本发明的实施方式并不受上述实施例的限制,其他的任何未背离本发明的精神实质与原理下所作的改变、修饰、替代、组合、简化,均应为等效的置换方式,都包含在本发明的保护范围之内。
再多了解一些

本文用于企业家、创业者技术爱好者查询,结果仅供参考。

发表评论 共有条评论
用户名: 密码:
验证码: 匿名发表

相关文献