一种残膜回收机防缠绕挑膜装置的制 一种秧草收获机用电力驱动行走机构

图像分类模型的训练方法、装置及存储介质与流程

2023-02-19 11:56:06 来源:中国专利 TAG:


1.本发明涉及图像处理与深度学习领域,尤其涉及一种图像分类模型的训练方法、装置及存储介质。


背景技术:

2.临床上,皮肤病的早期诊断除了一般的视觉筛查外,基于皮肤影像的皮肤病诊断是最常使用的一种诊疗手段。当前皮肤病的诊断方案主要基于医生目测和计算机辅助诊断为准,相比其他皮肤影像,通过皮肤镜通过消除皮肤表面的反射,可以将更深层次的肉眼无法辨别的病变特征可视化,和肉眼检查相比,可以将诊断敏感性提高10%至30%,在一定程度上降低了活检率,但基于人工识别的皮肤镜图像分析不仅耗时、费力,其获得的诊断结果也易受医生经验等主观因素影响。
3.而借助于计算机辅助诊断(computer-aided diagnosis,cad)系统实现基于皮肤镜图像的皮肤病自动识别分类,可以提高诊断的效率和准确率,并进行精准细致的分类,使得医生可以根据不同病变的特殊临床表现有针对性地制定最佳治疗方案,但在医院本地部署cad系统时,只能获取医院本地的数据,无法获取患者在其他医院的就诊信息,模型训练数据少,对于患者历史记录获取不完全。此外,目前对于大多数皮肤病分类、诊断的深度学习模型与方法都需要患者病灶图片、病史信息等相关隐私信息,若在互联网上部署cad系统,则容易泄露患者的隐私数据,数据隐私性、安全性较低。


技术实现要素:

4.本发明实施例提供一种图像分类模型的训练方法、装置及存储介质。
5.本发明实施例技术方案是这样实现的:
6.本发明实施例提供一种图像分类模型的训练方法,所述方法包括:
7.使用第一训练数据训练图像分类模型;其中,所述第一训练数据至少包括:第一样本图像;
8.通过训练得到的所述图像分类模型识别所述第一训练数据得到识别结果;
9.根据所述识别结果,确定出训练得到的所述图像分类模型识别错误的第二训练数据;
10.将至少部分所述第二训练数据输入至所述图像分类模型,继续训练所述图像分类模型。
11.上述方案中,所述根据所述识别结果,确定出训练得到的所述图像分类模型识别错误的第二训练数据,包括:
12.根据所述训练得到的所述图像分类模型,确定所述第一样本图像对应的识别结果的损失值;
13.将识别结果为错误且损失值大于第一预设值的第一样本图像确定为所述第二训练数据。
14.上述方案中,所述第一训练数据还包括:第一样本图像对应的用户信息。
15.上述方案中,所述使用第一训练数据训练图像分类模型,包括:
16.根据所述用户信息生成用户标签;
17.从所述第一样本图像中截取出目标区域和用户标签,将所述目标区域和所述用户标签输入至所述图像分类模型。
18.上述方案中,所述方法由第一设备执行,所述第一训练数据包括:来源所述第一设备的第一类训练数据,和/或,源自第二设备的第二类训练数据;
19.所述使用第一训练数据训练图像分类模型,包括:
20.使用明文的第一类训练数据训练所述图像分类模型;
21.和/或,
22.使用密文的所述第二类训练数据训练所述图像分类模型;其中,所述第二设备为不同于所述第一设备的任意设备。
23.上述方案中,所述方法还包括:
24.从第三设备接收加密密钥,其中,所述第三设备为分别与所述第一设备和所述第二设备连接的中间设备;
25.向所述第三设备请求所述第二类训练数据;
26.所述方法还包括:
27.将所述第二类训练数据输入到所述图像分类模型后,得到所述图像分类模型输出的预测标签;
28.利用所述加密密钥加密所述预测标签,并将加密后的所述预测标签发送给所述第三设备,其中,所述加密后的所述预测标签,用于供所述第三设备转发给所述第二设备并为所述第一设备确定损失值;
29.接收所述第三设备返回的所述第二设备提供的损失值;
30.基于所述损失值,判断是否继续训练所述图像分类模型。
31.上述方案中,所述将至少部分所述第二训练数据输入至所述图像分类模型,继续训练所述图像分类模型,包括:
32.在一批所述第一样本数据的训练次数未达到预定次数且存在所述第二训练数据时,继续利用所述第二训练数据训练所述图像分类模型;
33.所述方法还包括:
34.在一批所述第一样本数据的训练次数达到所述预定次数时,结束使用本批所述第一样本数据训练所述图像分类模型。
35.本发明实施例还提供一种图像分类模型的训练装置,所述装置包括:第一训练模块、识别模块、第一确定模块和第二训练模块;
36.所述第一训练模块,用于使用第一训练数据训练图像分类模型;其中,所述第一训练数据至少包括:第一样本图像;
37.所述识别模块,用于通过训练得到的所述图像分类模型识别所述第一训练数据得到识别结果;
38.所述第一确定模块,用于根据所述识别结果,确定出训练得到的所述图像分类模型识别错误的第二训练数据;
39.所述第二训练模块,用于将至少部分所述第二训练数据输入至所述图像分类模型,继续训练所述图像分类模型。
40.上述方案中,所述第一确定模块,具体用于:根据所述训练得到的所述图像分类模型,确定所述第一样本图像对应的识别结果的损失值;
41.将识别结果为错误且损失值大于第一预设值的第一样本图像确定为所述第二训练数据。
42.上述方案中,所述第一训练模块,具体用于:确定第一样本图像对应的用户信息。
43.上述方案中,所述第一训练模块,具体用于:根据所述用户信息生成用户标签;从所述第一样本图像中截取出目标区域和用户标签,将所述目标区域和所述用户标签输入至所述图像分类模型。
44.上述方案中,所述装置应用于第一设备,所述第一确定模块,具体用于:确定来源所述第一设备的第一类训练数据,和/或,源自第二设备的第二类训练数据;
45.使用明文的第一类训练数据训练所述图像分类模型;和/或,使用密文的所述第二类训练数据训练所述图像分类模型;其中,所述第二设备为不同于所述第一设备的任意设备。
46.上述方案中,所述装置还包括:接收模块、请求模块、预测模块、加密模块和判断模块;
47.所述接收模块,用于从第三设备接收加密密钥,其中,所述第三设备为分别与所述第一设备和所述第二设备连接的中间设备;
48.所述请求模块,用于向所述第三设备请求所述第二类训练数据;
49.所述预测模块,用于将所述第二类训练数据输入到所述图像分类模型后,得到所述图像分类模型输出的预测标签;
50.所述加密模块,用于利用所述加密密钥加密所述预测标签,并将加密后的所述预测标签发送给所述第三设备,其中,所述加密后的所述预测标签,用于供所述第三设备转发给所述第二设备并为所述第一设备确定损失值;
51.所述接收模块,用于接收所述第三设备返回的所述第二设备提供的损失值;
52.所述判断模块,用于基于所述损失值,判断是否继续训练所述图像分类模型。
53.上述方案中,所述第二训练模块,具体用于:
54.在一批所述第一样本数据的训练次数未达到预定次数且存在所述第二训练数据时,继续利用所述第二训练数据训练所述图像分类模型;
55.所述第二训练模块,还用于:在一批所述第一样本数据的训练次数达到所述预定次数时,结束使用本批所述第一样本数据训练所述图像分类模型。
56.本发明实施例还提供一种图像分类模型的训练装置,所述装置包括:处理器和用于存储能够在处理器上运行的计算机程序的存储器;
57.其中,所述处理器用于运行所述计算机程序时,执行上述任意一种图像分类模型的训练方法的步骤。
58.本发明实施例还提供了一种计算机存储介质,其特征在于,所述计算机存储介质存储有计算机可执行指令;所述计算机可执行指令被处理器执行后,能够实现上述一种图像分类模型的训练方法的步骤。
59.本实施例中,使用第一训练数据训练图像分类模型后,通过训练得到的所述图像分类模型对所述第一训练数据进行识别,确定出当前训练得到的所述图像分类模型识别错误的第二训练数据,并将至少部分所述第二训练数据输入至所述图像分类模型,继续训练所述图像分类模型,如此,在训练过程中,确定出当前图像分类模型识别不准确的图片,并对当前图像分类模型识别不准确的训练图片进行重复多次训练,进而加快了图像分类模型的收敛,提高了模型训练的训练速度。
附图说明
60.为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据提供的附图获得其他的附图。
61.图1为本发明提供的图像分类模型的训练方法的流程示意图;
62.图2为本发明实施例提供的图像分类模型的训练方法的流程示意图;
63.图3为本发明实施例提供的图像分类模型网络构建过程中的数据交互流程示意图;
64.图4为本发明实施例提供的确定第二类训练数据的方法的流程示意图;
65.图5为本发明实施例提供的一种图像分类模型的训练装置的结构示意图;
66.图6为本发明实施例提供的另一种图像分类模型的训练装置的结构示意图。
具体实施方式
67.为了使本发明的目的、技术方案和优点更加清楚,下面将结合附图对本发明作进一步地详细描述,所描述的实施例不应视为对本发明的限制,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其它实施例,都属于本发明保护的范围。
68.在不冲突的情况下,本技术中的实施例及实施例中的特征可以相互任意组合。在附图的流程图示出的步骤可以在诸如一组计算机可执行指令的计算机系统中执行。并且,虽然在流程图中示出了逻辑顺序,但是在某些情况下,可以以不同于此处的顺序执行所示出或描述的步骤。
69.除非另有定义,本文所使用的所有的技术和科学术语与属于本发明的技术领域的技术人员通常理解的含义相同。本文中所使用的术语只是为了描述本发明实施例的目的,不是旨在限制本发明。
70.对本发明实施例进行进一步详细说明之前,对本发明实施例中涉及的通信中的名词和术语进行说明。
71.本发明实施例提供一种图像分类模型的训练方法,图1为本发明提供的图像分类模型的训练方法的流程示意图;如图1所示,所述方法包括:
72.步骤s101:使用第一训练数据训练图像分类模型;其中,所述第一训练数据至少包括:第一样本图像;
73.在本实施例中,所述第一训练数据为影像图片,所述第一训练数据可以为存储于本地数据库中的患者影像图片;示例性地,所述第一样本图像可为:患者病灶组织的影像图
片,影像图片类型包括但不限于:电脑断层扫描(computerized tomography,ct)图像、数字减影血管造影(digital subtraction angiography,dsa)图像、核磁共振成像(nmri,nuclear magnetic resonance imaging)等。
74.进一步地,所述第一训练数据还包括:第一样本图像对应的用户信息。
75.具体地,所述第一样本图像对应的用户信息可以包括但不限于以下至少之一:
76.影像图像对应的患者年龄;
77.患者性别;
78.病史信息;
79.病症信息。
80.进一步地,所述使用第一训练数据训练图像分类模型,包括:根据所述用户信息生成用户标签;从所述第一样本图像中截取出目标区域和用户标签,将所述目标区域和所述用户标签输入至所述图像分类模型。
81.具体地,所述用户标签用于指示用户的基本属性;所述目标区域可以为患者病灶组织的影像图片中的病变区域,所述目标区域为标注过的病灶组织所在区域。
82.具体来说,通过图像分割方法,从所述第一样本图像中提取出标注了病灶标签的病灶组织所在区域,将提取出的标注了病灶标签的病灶组织所在区域和所述用户标签输入至所述图像分类模型中,对所述图像分类模型进行训练。
83.在本实施例中,通过样本图像对应的用户标签信息和目标区域同时作为所述图像分类模型的输入数据进行训练,区别于现有技术中仅使用图像作为图像分类模型的输入,实现了在图像分类模型的分类网络上加入患者的年龄、性别、病史信息等,引入多维度的训练数据对于模型进行训练,提高了分类的准确度。
84.步骤s102:通过训练得到的所述图像分类模型识别所述第一训练数据得到识别结果;
85.具体地,所述训练得到的所述图像分类模型为在根据一批所述第一训练数据对所述图像分类模型训练结束后,所获得的图像分类模型;所述第一训练数据为当前批次的所述第一训练数据;所述识别结果为所述训练得到的所述图像分类模型对所述当前批次的所述第一训练数据进行识别后输出的预测标签。
86.具体来说,在根据一批所述第一训练数据对所述图像分类模型训练结束后,将当前批次的所述第一训练数据输入至所述训练得到的所述的图像分类模型,进行识别,得到当前批次的所述第一训练数据对应的预测标签。
87.步骤s103:根据所述识别结果,确定出训练得到的所述图像分类模型识别错误的第二训练数据;
88.具体来说,将当前批次的所述第一训练数据对应的预测标签和所述第一训练数据标注的病灶标签进行对比,判断所述预测标签是否正确;若错误,则将错误的所述第一训练数据作为第二训练数据;此处的病灶标签可为训练数据的实际标签,可以是医生等人工标注的标签,也可以是其他模型标注的但是被认定正确的标签。
89.进一步地,所述根据所述识别结果,确定出训练得到的所述图像分类模型识别错误的第二训练数据,包括:根据所述训练得到的所述的图像分类模型,确定所述第一样本图像对应的识别结果的损失值;将识别结果为错误且损失值大于第一预设值的第一样本图像
确定为所述第二训练数据。
90.具体地,所述第一预设值可以所述图像分类模型自定义的损失值,也可以为用户定义的损失值,这里不做具体限定。
91.需要说明的是,所述第一预设值为不为0的正数值;在一些实施例中,所述第一预设值可以为:对当前批次的所述第一训练数据的损失值进行排序,将排序靠前的第n个所述当前批次的所述第一训练数据的损失值作为第一预设值,其中,n为大于零的正整数。
92.在另一些实施例中,所述第一预设值可以为:对当前批次所述第一训练数据的损失值中的最大值,与预设比例相乘后得到的预设值。
93.具体来说,通过所述训练得到的所述图像分类模型中损失函数,确定所述第一样本图像预测标签和所述第一样本图像标注的病灶标签之间的差异,确定出所述第一样本图像所对应的损失值。若图像分类模型识别得到的标签与所述第一样本图像原本的标签信息之间的差异越大,则所述损失值越大。
94.在一些实施例中,在确定当前批次的所述第一训练数据对应所述预测标签为错误时,将错误的预测标签对应所述第一训练数据作为第二训练数据;在另一些实施例中,对当前批次的所述第一训练数据的损失值进行排序,将排序靠前的第n个所述当前批次的所述第一训练数据的损失值作为第一预设值,其中,n为大于零的正整数;确定出损失值大于第一预设值的当前批次的所述第一训练数据;并判断损失值大于第一预设值的当前批次的所述第一训练数据的预测标签是否正确,将损失值大于第一预设值且预测标签为错误的当前批次的所述第一训练数据作为第二训练数据。
95.在另一些实施例中,确定出当前批次所述第一训练数据的损失值中的最大值,对当前批次所述第一训练数据的损失值中的最大值,按照预设比例运算后,得到第一预设值;确定出损失值大于第一预设值的当前批次的所述第一训练数据,并判断损失值大于第一预设值的当前批次的所述第一训练数据的预测标签是否正确,将损失值大于第一预设值且预测标签为错误的当前批次的所述第一训练数据作为第二训练数据。
96.步骤s104:将至少部分所述第二训练数据输入至所述图像分类模型,继续训练所述图像分类模型。
97.具体地,所述至少部分所述第二训练数据可以为用户选择的部分所述第二训练数据,也可以为表征同种病灶组织的所述第二训练数据,还可以为全部所述第二训练数据。
98.在本实施例中,使用第一训练数据训练图像分类模型后,通过训练得到的所述图像分类模型对所述第一训练数据进行识别,确定出训练得到的所述图像分类模型识别错误的第二训练数据,并将至少部分所述第二训练数据输入至所述图像分类模型,继续训练所述图像分类模型,如此,在训练过程中,确定出当前图像分类模型识别不准确的图片,并对当前图像分类模型识别不准确的训练图片进行重复多次训练,进而加快了图像分类模型的收敛,提高了模型训练的训练速度。
99.进一步地,在所述步骤s101中,所述第一训练数据包括:来源所述第一设备的第一类训练数据,和/或,源自第二设备的第二类训练数据;所述使用第一训练数据训练图像分类模型,包括:使用明文的第一类训练数据训练所述图像分类模型;和/或,使用密文的所述第二类训练数据训练所述图像分类模型;其中,所述第二设备为不同于所述第一设备的任意设备。
100.具体地,当所述第一训练数据为来源所述第一设备的训练数据时,所述第一类训练数据包括但不限于存储于第一设备本地数据库中的患者影像图片;所述明文的第一类训练数据为未经过加密的明文状态的所述第一类训练数据对应的用户信息、用户标签和目标区域,使用明文状态的第一类训练数据训练所述图像分类模型。
101.在第一设备自身的数据不够时,可以引入第二设备的数据,如此增加了图像分类模型的样本数,采用这种方式被训练出的图像分类模型具有泛化能力强和分类结果精确的特点。
102.在本公开实施例中,第二类训练数据是被加密后的训练数据,该训练数据被加密后是密文,但是加密之后仍然可以体现病灶特点等,如此一方面可以确保数据的安全性,另一方面还可以增加图像分类模型的训练样本数。
103.例如,以训练图像为病灶阻止样本为例,进行加密时仅对能够确定患者身份部分的图像内容进行加密,而病灶组织的成像区域并不加密,或者,调整病灶组织在图像中的位置,从而生成新的训练图像等。
104.当所述第一训练数据为源自第二设备的第二类训练数据时,所述第二类训练数据为经过加密处理的存储于第二设备本地数据库中的患者影像图片;所述密文状态的第二类训练数据为加密后的所述第一类训练数据对应的用户信息、用户标签和目标区域,使用加密后的第二类训练数据训练所述图像分类模型;这里,所述第二类训练数据对应的用户信息与所述第一类训练数据对应的用户信息相同,即所述第二类训练数据和所述第一类训练数据为同一患者在不同医院的患者影像图片;进行加密的加密方式可以为同态加密、rsa加密、密钥共享加密等,这里不作具体限定。
105.进一步地,当所述第一训练数据为源自第二设备的第二类训练数据时,所述方法还包括:从第三设备接收加密密钥,其中,所述第三设备为分别与所述第一设备和所述第二设备连接的中间设备;向所述第三设备请求所述第二类训练数据。
106.具体地,所述第二类训练数据为加密后的第三训练数据;所述第三训练数据为所述第一训练数据中的用户在第二设备中的训练数据;这里,所述第三训练数据至少包括:标注了病灶标签的第二样本图像和所述第二样本图像对应的用户信息;所述第二类训练数据中的病灶标签为掩码形式表示的病灶标签;所述掩码数据为对原病灶标签按位运算或逻辑运算得出新的用二进制表示的数据。
107.在本实施例中,通过将第二样本图像中标注的病灶标签用掩码数据来表示,可以提高训练数据在交互过程中的安全性,极大程度保障了病患的隐私信息。
108.这里,所述第三设备为联邦学习模型,所述联邦学习模型能有效帮助多个机构在满足用户隐私保护、数据安全和政府法规的要求下,进行数据不共享的联合建模,通过联邦系统中加密机制可以进行参数交换。
109.具体来说,所述向所述第三设备请求所述第二类训练数据具体为:
110.所述第一设备向所述第三设备请求加密密钥,通过所述加密密钥对所述第一类训练数据进行哈希加密,生成第一加密数据集合并发送给所述第三设备;所述第二设备也向所述第三设备请求所述加密密钥,通过所述加密密钥对所述第二设备中的样本数据进行哈希加密,生成第二加密数据集合并发送给所述第三设备;所述第二设备中的样本数据至少包括:所述第二设备中的样本图像和所述第二设备中的样本图像对应的用户信息。
111.通过所述第三设备对所述第一加密数据集合和所述第二加密数据集合做交集,确定所述第一加密数据集合和所述第二加密数据集合相同的数据,根据确定出的所述相同的数据,确定交集集合;所述第三设备将交集集合发送给所述第二设备后,所述第二设备根据所述交集集合,对所述第二设备中的样本数据进行顺序对比,确定所述第一训练数据中的用户在第二设备中的训练数据,即第三训练数据;所述第二设备通过所述加密密钥对所述第三训练数据再次加密,生成第二类训练数据。
112.具体地,在一些实施例中,所述相同的数据可以为加密后的第一样本图像和加密后的第二样本图像相同;将与第一样本图像相同的所述第二设备中的样本图像对应的第二加密数据确定为交集集合中的数据;
113.在另一些实施例中,所述相同的数据也可以为加密后的第一样本图像对应的用户信息和加密后第二样本图像对应的用户信息相同;将与第一样本图像对应的用户信息相同的所述第二设备中的样本图像对应的用户信息对应的第二加密数据确定为交集集合中的数据。
114.基于所述第一设备和所述第二设备使用的是相同的加密密钥,确定出的所述交集集合对应的明文数据也相同,由此,可以确定出所述第一设备和所述第二设备中同一患者的影像图片。
115.这里,所述第二加密集合是对所述第二设备中的样本图像以及所述第二设备中的样本图像对应的用户信息按照顺序进行哈希加密生成的;在本实施例中,通过对第一设备和第二设备进行哈希加密,基于哈希加密不可逆加密,不可被解密的特性,将医院数据加密处理后进行交互,且采用不可被解密的加密方式进行加密,保障了数据隐私性和安全性。
116.以下,以一个具体的实施例对本发明实施例中提供的所述图像分类模型的训练方法进行说明,图2为本发明实施例提供的图像分类模型的训练方法的流程示意图,图3为本发明实施例提供的图像分类模型网络构建过程中的数据交互流程示意图,图4为本发明实施例提供的确定第二类训练数据的方法的流程示意图;
117.以一个具体的实施例对本发明实施例中的图像分类模型网络构建过程中的数据交互进行说明,如图3所示,图3为本发明实施例提供的图像分类模型网络构建过程中的数据交互流程示意图;
118.对所述图像分类模型进行backbone网络建立,所述图像分类模型的backbone包含至少一个卷积层、至少一个全局最大池化层、至少一个全连接层、批次归一化和修正线性单元等组成,用于对皮肤病图像进行特征提取。将backbone输出的n维向量,与年龄、性别等信息拼接,组成一个n 2维的向量,把新的向量输入全连接层(fc)中,最后经过sigmoid函数分别得到该图片的损失值。
119.所述backbone网络可以分成具体的三部分,其中,网络的第一部分:
120.由输入的224*224的三通道的rgb图像,首先传入卷积核大小为3*3,步长为1;在该层中有64个卷积核进行采样,经过连续两层相同的卷积层采样之后,输出为224*224*64的特征图。
121.经过上一层的卷积层之后,传入下一层的最大池化层中,进行2*2的池化操作,输出为112*112*64的特征图。
122.类似的,接下来的卷积层与池化层与上一步类似,只是将conv3-64的卷积核替换
成conv3-128的卷积核,相应的,卷积层输出的特征图为112*112*128,池化层输出特征图为56*56*128。
123.进一步的,传入卷积核大小为3*3,步长为1;在该层中有256个卷积核进行采样,该阶段由连续3层构成,经过两层相同的卷积层采样之后,得到56*56*256的特征图。继续将上一层卷积层的输出结果传入到最大池化层中,得到28*28*256的特征图。
124.类似的,之后的卷积层与池化层与上述的操作类似,只是将conv3-256的卷积核替换成conv3-512的卷积核,相应的,卷积层输出的特征图为56*56*512,池化层输出的特征图为28*28*512。
125.最后将28*28*512传入conv3-1024,步长为2的卷积核中,得到28*28*1024大小的特征图,至此,网络的第一部分结束。
126.网络的第二部分:
127.对于网络的第二部分的输入,将其传入到conv3-1024,步长为1的卷积核构成的卷积层中处理,得到28*28*1024的特征图,即为图4中的输出1,该特征图继续传入下一层中,作为全连接层的一部分,输入第三部分网络中进行处理。
128.持续对上一层的结果进行下一层卷积,使之进一步得到后续的特征,经过conv3-512,步长分别为1和步长值分别为2的卷积层之后,得到28*28*512的特征图,即为图4中的输出2。将28*28*512的特征图继续传入之后的卷积层,作为第三部分网络中的一部分,传入第三部分网络进行处理。
129.进一步,继续经过conv3-256,步长分别为1和步长值分别为2的卷积层之后,得到28*28*256的特征图,上图输出3。其继续传入之后的卷积层,也作为网络第三部分中的一部分,传入网络的第三部分进行处理。
130.最后,经过conv3-128,步长分别为1和步长值分别为2的卷积,此时所得到两个28*28*128特征图,即为图4中的输出4与5。两个28*28*128特征图都作为第三部分网络的输入值;至此,网络的第二部分结束。
131.网络的第三部分:
132.由之前的网络的第二部分所得的特征图,与患者病例中的年龄病史等信息相组合,构成新的特征向量参数,经过全局平均池化之后,再传入到全连接层之中,经过sigmoid函数激活后,可以得到一个多维特征向量,用于确定多种分类疾病的损失值。
133.以下,以一个具体的实施例对本发明实施例中提供的所述图像分类模型的训练方法进行说明,图3为本发明实施例提供的图像分类模型的训练方法的流程示意图,如图3所示:
134.步骤1:图像分类模型的网络建立,并对训练数据进行图像预处理。
135.步骤2:使用明文的第一医院中的第一类训练数据训练所述图像分类模型。
136.步骤3:将当前批次的第一类训练数据输入至所述图像分类模型中进行识别,确定出在训练得到的所述图像分类模型识别过程中置信度低的训练图片。
137.具体地,确定每张训练图片的损失值和所述训练图片对应的识别结果,将置信度低的训练图片构成单独的一个图片序列,并将置信度低的训练图片重新加入到输入图像中,额外的对于该部分的图片进行训练,增强模型对于置信度较低的图片的学习。
138.这里,置信度较低的图片判定的标准为:
139.a)、在图像分类模型训练的过程中,每一次epoch结束后,将训练图片导入当前的图像分类模型进行预测,如果预测错误,则该训练图片被认为满足置信度低的第一条件。
140.b)、在图像分类模型训练的过程中,在每一次epoch中,每一张训练图片均有器对应的损失值,选择出前20张损失值较高的图片,与满足置信度低的第一条件的训练图片做交集,交集中的图片便可以认为为置信度低。
141.步骤4:将确定出的置信度低的训练图片,输入至当前所述图像分类模型,继续执行步骤2。
142.步骤5:在每一次epoch结束之后,使用密文的第二医院中的第二类训练数据训练所述图像分类模型。
143.具体地,如图4所示,图4为本发明实施例提供的确定第二类训练数据的方法的流程示意图;
144.所述第一医院中的所述第一类训练数据表示为id1,其中,id1={u1,u2,u3...ua};所述第二医院中的训练数据表示为id2,其中,id2={v1,v2,v3...vb};
145.这里,u1至少包括所述第一样本图像和所述第一样本图像所对应的用户信息;v1至少包括所述第二设备中的样本图像和所述第二设备中的样本图像所对应的用户信息。
146.对id1进行哈希序列化,得到数据集a:h(id1)={h(u1),h(u2),h(u3)...h(ua)},对id2进行哈希序列化,得到数据集b:h(id2)={h(v1),h(v2),h(v3)...h(vb)}
147.对h(id1)中的每一个序列均采用rsa加密,得到其中,rie代表数据集a中第i个元素的rsa私钥的e次方;类似的,对于h(id2)中的每一个序列也均采用rsa加密,得到由于哈希加密是不可逆加密,不可被解密,满足医院数据传输要求。
148.所述第二医院将id
’2发送至所述联邦学习模型中,所述第一医院将id
’1发送至所述联邦学习模型中;所述联邦学习模型对id
’2和id
’1做交集,确定出交集集合。因为所述第一医院和所述第二医院采用的是相同的rsa加密密钥,因此,对id
’2和id
’1做交集可以确定出h(id1)和h(id2)中相同的序列,这里数据加密时按照顺序进行加密的,在确定出h(id1)和h(id2)中相同的序列后,基于h(id2),将确定出的相同序列和数据id2进行对比,确定出在第一医院和第二医院中同一患者的就诊数据。
149.例如:第一医院id1中u1的所述第一样本图像所对应的用户信息和第二医院id2中v1的所述第二设备中的样本图像所对应的用户信息相同;第一医院id1中u3的所述第一样本图像所对应的用户信息和第二医院id2中v3的所述第二设备中的样本图像所对应的用户信息相同;则在联邦学习模型中,对id
’2和id
’1做交集后将得到id
’1∩id
’2={r
1e
h(u1)r
1e
h(v1),0,r
3e
h(u3)r
3e
h(v3)

,0},将id
’1∩id
’2和h(id2)={h(v1),h(v2),h(v3)...h(vb)}按照顺序进行对比,则可以确定出第三训练数据h={h(v1),h(v3)},所述第二设备通过联邦学习模型发送的加密密钥对第三训练数据进行再次加密,生成第二类训练数据;所述联邦学习模型将确定出的第二类训练数据发送给第一医院,第一医院对第二类训练数据进行识别,并更新权重;
150.具体地,第一医院初始化权重θa,第二医院对v1中医生标注的病灶标签进行掩码化,生成对v3中医生标注的病灶标签进行掩码化,生成通过加密密钥对h={h(v1),
h(v3)}进行再次加密生成第二类训练数据h’={r
1e
h(v1),r
3e
h(v3)},其中,加密后生成加密后生成第一医院通过图像分类模型对第二类训练数据中的r
1e
h(v1)进行识别,生成预测标签对第二类训练数据中的r
3e
h(v3)进行识别,生成预测标签通过加密密钥对进行加密后生成对进行加密后生成第一医院将和发送给第二医院,第二医院根据和计算出损失值将和进行对比,中间变量第二医院根据和计算出损失值将和进行对比,计算出中间变量将所有所述第二类训练数据对应的损失值和中间变量发送给联邦学习模型,由联邦学习模型对损失值和中间变量解密,生成和第一医院根据和计算出梯度加加
151.步骤6:确定图像分类模型的训练次数;
152.如果图像分类模型的训练次数没有达到预设的训练次数且存在置信度低的训练图片,则执行步骤2。如果达到预设的训练次数值,则执行步骤7。
153.步骤7:结束训练。
154.如图5所示,图5为本发明实施例提供的一种图像分类模型的训练装置的结构示意图,所述装置包括:第一训练模块501、识别模块502、第一确定模块503和第二训练模块504;
155.所述第一训练模块501,用于使用第一训练数据训练图像分类模型;其中,所述第一训练数据至少包括:第一样本图像;
156.所述识别模块502,用于所述识别模块,用于通过训练得到的所述图像分类模型识别所述第一训练数据得到识别结果;
157.所述第一确定模块503,用于根据所述识别结果,确定出训练得到的所述图像分类模型识别错误的第二训练数据;
158.所述第二训练模块504,用于将至少部分所述第二训练数据输入至所述图像分类模型,继续训练所述图像分类模型。
159.具体地,所述第一确定模块503,还用于:根据所述训练得到的所述图像分类模型,确定所述第一样本图像对应的识别结果的损失值;将识别结果为错误且损失值大于第一预设值的第一样本图像确定为所述第二训练数据。
160.具体地,所述第一训练模块501,还用于:确定第一样本图像对应的用户信息。
161.具体地,所述第一训练模块501,还用于:根据所述用户信息生成用户标签;从所述第一样本图像中截取出目标区域和用户标签,将所述目标区域和所述用户标签输入至所述图像分类模型。
162.具体地,所述装置应用于第一设备,所述第一确定模块501,还用于:确定来源所述
第一设备的第一类训练数据,和/或,源自第二设备的第二类训练数据;使用明文的第一类训练数据训练所述图像分类模型;和/或,使用密文的所述第二类训练数据训练所述图像分类模型;其中,所述第二设备为不同于所述第一设备的任意设备。
163.具体地,所述装置还包括:接收模块505、请求模块506、预测模块507、加密模块508和判断模块509;
164.所述接收模块505,用于从第三设备接收加密密钥,其中,所述第三设备为分别与所述第一设备和所述第二设备连接的中间设备;
165.所述请求模块506,用于向所述第三设备请求所述第二类训练数据;
166.所述预测模块507,用于将所述第二类训练数据输入到所述图像分类模型后,得到所述图像分类模型输出的预测标签;
167.所述加密模块508,用于利用所述加密密钥加密所述预测标签,并将加密后的所述预测标签发送给所述第三设备,其中,所述加密后的所述预测标签,用于供所述第三设备转发给所述第二设备并为所述第一设备确定损失值;
168.所述接收模块505,还用于接收所述第三设备返回的所述第二设备提供的损失值;
169.所述判断模块509,用于基于所述损失值,判断是否继续训练所述图像分类模型。
170.具体地,所述第二训练模块504,还用于:在一批所述第一样本数据的训练次数未达到预定次数且存在所述第二训练数据时,继续利用所述第二训练数据训练所述图像分类模型;所述第二训练模块504,还用于:在一批所述第一样本数据的训练次数达到所述预定次数时,结束使用本批所述第一样本数据训练所述图像分类模型。
171.为实现本发明实施例的方法,本发明实施例提供另一种图像分类模型的训练装置,具体来说,如图6所示,图6为本发明实施例提供的另一种图像分类模型的训练装置的结构示意图;所述装置60包括处理器601和用于存储能够在处理器上运行的计算机程序的存储器602;
172.其中,所述处理器601用于运行所述计算机程序时,执行:使用第一训练数据训练图像分类模型;其中,所述第一训练数据至少包括:第一样本图像;通过训练得到的所述图像分类模型识别所述第一训练数据得到识别结果;根据所述识别结果,确定出训练得到的所述图像分类模型识别错误的第二训练数据;将至少部分所述第二训练数据输入至所述图像分类模型,继续训练所述图像分类模型。
173.在一实施例中,所述处理器601还用于运行所述计算机程序时,执行:根据所述训练得到的所述图像分类模型,确定所述第一样本图像对应的识别结果的损失值;将识别结果为错误且损失值大于第一预设值的第一样本图像确定为所述第二训练数据。
174.在一实施例中,所述处理器601还用于运行所述计算机程序时,执行:确定所述第一训练数据还包括的第一样本图像对应的用户信息。
175.在一实施例中,所述处理器601还用于运行所述计算机程序时,执行:根据所述用户信息生成用户标签;从所述第一样本图像中截取出目标区域和用户标签,将所述目标区域和所述用户标签输入至所述图像分类模型。
176.在一实施例中,所述处理器601还用于运行所述计算机程序时,执行:所述第一训练数据包括:来源所述第一设备的第一类训练数据,和/或,源自第二设备的第二类训练数据;所述使用第一训练数据训练图像分类模型,包括:使用明文的第一类训练数据训练所述
图像分类模型;和/或,使用密文的所述第二类训练数据训练所述图像分类模型;其中,所述第二设备为不同于所述第一设备的任意设备。
177.在一实施例中,所述处理器601还用于运行所述计算机程序时,执行:从第三设备接收加密密钥,其中,所述第三设备为分别与所述第一设备和所述第二设备连接的中间设备;向所述第三设备请求所述第二类训练数据;所述处理器601还用于运行所述计算机程序时,执行:将所述第二类训练数据输入到所述图像分类模型后,得到所述图像分类模型输出的预测标签;利用所述加密密钥加密所述预测标签,并将加密后的所述预测标签发送给所述第三设备,其中,所述加密后的所述预测标签,用于供所述第三设备转发给所述第二设备并为所述第一设备确定损失值;接收所述第三设备返回的所述第二设备提供的损失值;基于所述损失值,判断是否继续训练所述图像分类模型。
178.在一实施例中,所述处理器601还用于运行所述计算机程序时,执行:在一批所述第一样本数据的训练次数未达到预定次数且存在所述第二训练数据时,继续利用所述第二训练数据训练所述图像分类模型;
179.所述处理器601还用于运行所述计算机程序时,执行:在一批所述第一样本数据的训练次数达到所述预定次数时,结束使用本批所述第一样本数据训练所述图像分类模型。
180.需要说明的是:上述实施例提供的图像分类模型的训练装置与图像分类模型的训练方法实施例属于同一构思,其具体实现过程详见方法实施例,这里不再赘述。
181.当然,实际应用时,如图6所示,该装置60还可以包括:至少一个网络接口603。图像分类模型的训练装置60中的各个组件通过总线系统604耦合在一起。可理解,总线系统604用于实现这些组件之间的连接通信。总线系统604除包括数据总线之外,还包括电源总线、控制总线和状态信号总线。但是为了清楚说明起见,在图6中将各种总线都标为总线系统604。其中,所述处理器601的个数可以为至少一个。网络接口603用于图像分类模型的训练装置60与其他设备之间有线或无线方式的通信。
182.本发明实施例中的存储器602用于存储各种类型的数据以支持图像分类模型的训练装置60的操作。
183.上述本发明实施例揭示的方法可以应用于处理器601中,或者由处理器601实现。处理器601可能是一种集成电路芯片,具有信号的处理能力。在实现过程中,上述方法的各步骤可以通过处理器601中的硬件的集成逻辑电路或者软件形式的指令完成。上述的处理器601可以是通用处理器、数字信号处理器(dsp,digital signal processor),或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等。处理器601可以实现或者执行本发明实施例中的公开的各方法、步骤及逻辑框图。通用处理器可以是微处理器或者任何常规的处理器等。结合本发明实施例所公开的方法的步骤,可以直接体现为硬件译码处理器执行完成,或者用译码处理器中的硬件及软件模块组合执行完成。软件模块可以位于存储介质中,该存储介质位于存储器602,处理器601读取存储器602中的信息,结合其硬件完成前述方法的步骤。
184.在示例性实施例中,图像分类模型的训练装置60可以被一个或多个应用专用集成电路(asic,application specific integrated circuit)、dsp、可编程逻辑器件(pld,programmable logic device)、复杂可编程逻辑器件(cpld,complex programmable logic device)、现场可编程门阵列(fpga,field-programmable gate array)、通用处理器、控制
器、微控制器(mcu,micro controller unit)、微处理器(microprocessor)、或其他电子元件实现,用于执行前述方法。
185.在示例性实施例中,本发明实施例还提供了一种计算机可读存储介质,例如包括计算机程序的存储器602,上述计算机程序可由图像分类模型的训练装置60的处理器601执行,以完成前述方法所述步骤。
186.具体地,本发明实施例还提供了一种计算机可读存储介质,其上存储有计算机程序,所述计算机程序被处理器运行时,执行:使用第一训练数据训练图像分类模型;其中,所述第一训练数据至少包括:第一样本图像;通过训练得到的所述图像分类模型识别所述第一训练数据得到识别结果;根据所述识别结果,确定出训练得到的所述图像分类模型识别错误的第二训练数据;将至少部分所述第二训练数据输入至所述图像分类模型,继续训练所述图像分类模型。
187.在一实施例中,所述计算机程序被处理器运行时,执行:根据所述训练得到的所述图像分类模型,确定所述第一样本图像对应的识别结果的损失值;将识别结果为错误且损失值大于第一预设值的第一样本图像确定为所述第二训练数据。
188.在一实施例中,所述计算机程序被处理器运行时,执行:所述第一训练数据还包括:第一样本图像对应的用户信息。
189.在一实施例中,所述计算机程序被处理器运行时,执行:根据所述用户信息生成用户标签;从所述第一样本图像中截取出目标区域和用户标签,将所述目标区域和所述用户标签输入至所述图像分类模型。
190.在一实施例中,所述计算机程序被处理器运行时,执行:所述第一训练数据包括:来源所述第一设备的第一类训练数据,和/或,源自第二设备的第二类训练数据;所述使用第一训练数据训练图像分类模型,包括:使用明文的第一类训练数据训练所述图像分类模型;和/或,使用密文的所述第二类训练数据训练所述图像分类模型;其中,所述第二设备为不同于所述第一设备的任意设备。
191.在一实施例中,所述计算机程序被处理器运行时,执行:从第三设备接收加密密钥,其中,所述第三设备为分别与所述第一设备和所述第二设备连接的中间设备;向所述第三设备请求所述第二类训练数据;所述计算机程序被处理器运行时,执行:将所述第二类训练数据输入到所述图像分类模型后,得到所述图像分类模型输出的预测标签;利用所述加密密钥加密所述预测标签,并将加密后的所述预测标签发送给所述第三设备,其中,所述加密后的所述预测标签,用于供所述第三设备转发给所述第二设备并为所述第一设备确定损失值;接收所述第三设备返回的所述第二设备提供的损失值;基于所述损失值,判断是否继续训练所述图像分类模型。
192.在一实施例中,所述计算机程序被处理器运行时,执行:在一批所述第一样本数据的训练次数未达到预定次数且存在所述第二训练数据时,继续利用所述第二训练数据训练所述图像分类模型;所述计算机程序被处理器运行时,执行:在一批所述第一样本数据的训练次数达到所述预定次数时,结束使用本批所述第一样本数据训练所述图像分类模型。
193.需要说明的是:本发明实施例提供的计算机可读存储介质可以是fram、rom、prom、eprom、eeprom、flash memory、磁表面存储器、光盘、或cd-rom等存储器;也可以是包括上述存储器之一或任意组合的各种设备。
194.以上所述,仅为本发明的实施例而已,并非用于限定本发明的保护范围。凡在本发明的精神和范围之内所作的任何修改、等同替换和改进等,均包含在本发明的保护范围之内。
再多了解一些

本文用于创业者技术爱好者查询,仅供学习研究,如用于商业用途,请联系技术所有人。

发表评论 共有条评论
用户名: 密码:
验证码: 匿名发表

相关文献