一种残膜回收机防缠绕挑膜装置的制 一种秧草收获机用电力驱动行走机构

一种基于知识蒸馏的胸部放射影像疾病分类模型轻量化方法

2022-12-13 22:46:42 来源:中国专利 TAG:


1.本发明涉及一种基于知识蒸馏的胸部放射影像疾病分类模型轻量化方法,属于计算机视觉技术领域。


背景技术:

2.胸部疾病是威胁人类健康的一大难题,全世界每年有数亿人受到胸部疾病的困扰,如果不及时进行治疗,将会对患者带来很大的影响,甚至危及患者的生命。计算机辅助诊断(computer aided diagnosis,cad)常被用于辅助放射科医生进行高效诊断,与人工阅片相比,计算机辅助诊断具有许多的优点,例如不受到主观的影响、自动提取并筛选视觉特征、随着学习数据量的增大,模型精度迅速提升等等。现有的计算机辅助诊断模型在医学图像的分类、分割和检测任务上自动完成且取得了不错的成绩,极大的解决了放射影像处理面对的问题,然而高性能的代价是网络规模的不断扩大,包括计算量的不断扩大以及参数量占用的内存不断扩大。因此,在使得模型的性能指标较高的情况下,如何有效降低参数量,并且提高运行效率变得十分重要。
3.为提高模型效率,在过去的几年中基于深度学习提出了许多解决方法,这些方法一般分为四类:(1)超分辨率网络设计,如he等人设计出了残差网络模型取名为resnet,通过跳跃连接的方式弥补了梯度消失或者梯度爆炸导致无法正常训练的后果。densenet将前面所有层的图像特征作为输入继续添加入后面的层中,最大化利用特征图属性。resnext引入新的维度,变相的减少了计算量。(2)网络剪枝,如han等人在几乎没有降低模型性能的基础上,统过范数排列将模型压缩了十倍以上。li等人将绝对值较小的卷积核删除,并且统计删除每层的卷积核后对模型性能的影响,将影响较大的卷积核的剪枝比例降低,影响较小的卷积核的剪枝比例增大,结果表明统一对模型进行剪枝不如对每层卷积单独考虑进行剪枝。(3)数据量化,如nvidia中的量化技术选用浮点数和整数的最小散度距离值作为是否量化的阈值。dfq模型通过权值均衡化和纠正偏差值能够量化至6-bit而保持和浮点网络接近的性能。(4)知识蒸馏,如zagoruyko等人将注意力机制加入到知识蒸馏模型中去,该方法使用注意力图共享使得学生网络能够关注到教师网络所关注的特征。减少学生网络的注意力图与教师网络的注意力图之间的差异,来让学生网络学习到教师网络的资源分配侧重。xu等人利用对抗性训练来学习损失,然后将教师网络的知识传递到学生网络中去。
4.而深度学习算法应用到医学图像检测及分类任务上时,经常会出现参数量大,内存占用多等问题,且在移动端或者嵌入式设备上难以运行。针对现有的胸部放射影像分类模型参数量过大,运算效率低、而直接将小网络运用于医学图像上进行训练结果不稳定,且效果欠佳等问题。本发明提出了一种基于知识蒸馏的胸部放射影像疾病分类模型,该模型使用知识蒸馏将大模型作为教师网络,小模型作为学生网络,将知识迁移到小模型上,既能提高小模型的精度和稳定性,也能减少疾病分类模型的参数量,提高运行效率。


技术实现要素:

5.本发明涉及一种基于知识蒸馏的胸部放射影像疾病分类模型轻量化方法,以用于解决现有的胸部放射影像分类模型参数量过大,运算效率低、而直接将小网络运用于医学图像上进行训练结果不稳定,且效果欠佳等问题;本发明能够提高模型效率、精度和稳定性。
6.本发明的技术方案是:一种基于知识蒸馏的胸部放射影像疾病分类模型轻量化方法,所述方法的具体包括如下:
7.step1、图卷积神经网络(gcn)模块采用预训练语言模型将医学影像标签转换为glove词嵌入表示,然后采用数据挖掘的方式构建标签关系图矩阵输入到gcn模块,经过两层的图卷积操作提取出疾病标签特征;
8.step2、医学影像从卷积神经网络(cnn)模块输入,经过卷积操作和最大池化操作之后提取出胸部医学影像的特征;
9.step3、把医学影像特征与疾病标签特征融合来预测多标签分类结果;
10.step4、选取resnet18网络模型作为学生网络,resnet18通过残差单元解决深度网络的性能退化问题;
11.step5、选择损失函数进行回归和分类,用教师网络的损失指标指导学生网络的损失指标,得到多标签分类模型学生网络。
12.进一步地,所述step1中的图卷积神经网络(gcn)模块具体包括如下:
13.将每种疾病标签表示为单独一个节点,gcn设置为2层,输入是特征表示矩阵h
l
∈rn×d(其中n是标签的种类,d是标签的词嵌入维度)和标签相关关系矩阵a∈rn×n,目标是在图上学习一个函数f(
·
,
·
),使新节点更新表示为h
l 1
∈rn×n,如gcn的每一层都表示为一个非线性激活函数为:h
l 1
=f(h
l
,a)。本发明的图卷积运算使用的f(
·
,
·
)表示形式为:h
l 1
=h(ah
lwl
),其中h(
·
)指的是非线性操作,本发明中使用leakyrelu激活函数,w
l
∈rd×d′
指的是要学习的转换矩阵。
14.进一步地,所述step1中的构建标签关系图矩阵具体包括如下:
15.采用数据挖掘的方式构建标签关系图矩阵输入到gcn模块,首先统计出所有疾病类别的总数,然后通过数据挖掘找出每种疾病发生的情况下其他疾病发生的数量,即以条件概率的形式构建关系矩阵,定义p(la|lb)表示la标签出现的情况下lb标签发生的概率,如可表示la为气胸(pneumothorax),lb为肺气肿(emphysema),假设气胸出现的情况下肺气肿出现的概率为0.3,而肺气肿出现的情况下气胸出现的概率为0.1。本发明所使用的医学影像数据集疾病标签类别为14种,所以最终构建的标签相关关系矩阵为14
×
14的二维矩阵。
16.进一步地,所述step2中的卷积神经网络(cnn)模块具体包括如下:
17.卷积神经网络(cnn)模块模型选取densenet网络模型,里面一共有四个模块,模块的命名方式为denseblock1到denseblock4,模块之间的区别在于每个块之间的卷积操作以及数量不同。每个denseblock块里面都包含了1*1和3*3的卷积核以及批归一化层,密集网络块之间还有进行下采样操作的过渡层,densenet-121共包含3个过渡层,为了能够顺利进行特征融合并且更好的得到纹路特征,本发明去掉densenet-121网络最后的全连接层,替换为最大池化层。
18.进一步地,所述step3中的融合方法具体包括如下:
19.本发明采取矩阵乘积的方式进行特征融合,如计算公式所示:式中表示总体特征,x为医学影像特征,y为疾病标签特征。然后将总体特征放入多标签分类损失函数求出loss,如计算公式所示:求出loss,如计算公式所示:其中δ(
·
)为sigmoid函数,c表示迭代次数。
20.进一步地,所述step4中的学生网络具体包括如下:
21.resnet18网络模型每两层就会出现一次残差学习,该网络模型分为五部分,分别是convolution1、conv2_x、conv3_x、conv4_x,conv5_x,最后连接了一个池化层。
22.进一步地,所述step5中的损失函数还包括如下:
23.为了使学生网络能够学习到soft target,使用知识蒸馏中的温度参数t来调节知识传递,定义softmax函数为:其中pi表示教师网络第i个输出的概率,xi、xj表示softmax的输入,t为温度系数,当温度增加,softmax的输出分布越来越平缓,信息熵会越来越大,学生网络能够更多的关注到负标签;为了学生网络能够更好地拟合教师网络的分类结果,定义总体损失函数为:loss=(1-a)h(label,y) αh(p,y)t2,其中α表示权重系数,h表示的是交叉熵,label为真实标签结果,y为学生网络标签结果,p表示教师网络总体概率。
24.本发明的有益效果是:使用学生网络来做预测比使用教师网络降低了百分之35的内存占用,并且提高了百分之34的运行速度。进一步的消融实验中,不使用知识蒸馏的情况下平均auc为0.756,使用教师网络指导后为0.817,提升了6个百分点,证明了知识蒸馏确实是有用的,本发明在使总体模型精度尽可能少下降的前提下,实现模型的压缩,极大地提高模型运行效率以及降低内存使用率。
附图说明
25.图1为总体模型及关键组件结构,(teacher model,student model);
具体实施方式
26.实施例1、如图1,一种基于知识蒸馏的胸部放射影像疾病分类模型轻量化方法,所述方法的具体步骤如下:
27.step1、图卷积神经网络(gcn)模块采用预训练语言模型将医学影像标签转换为glove词嵌入表示,然后采用数据挖掘的方式构建标签关系图矩阵输入到gcn模块,经过两层的图卷积操作提取出疾病标签特征;
28.step2、医学影像从卷积神经网络(cnn)模块输入,经过卷积操作和最大池化操作之后提取出胸部医学影像的特征;
29.step3、把医学影像特征与疾病标签特征融合来预测多标签分类结果;
30.step4、选取resnet18网络模型作为学生网络,resnet18通过残差单元解决深度网络的性能退化问题;
31.step5、选择损失函数进行回归和分类,用教师网络的损失指标指导学生网络的损失指标,得到多标签分类模型学生网络。
32.进一步地,所述step1中的图卷积神经网络(gcn)模块具体包括如下:
33.将每种疾病标签表示为单独一个节点,gcn设置为2层,输入是特征表示矩阵h
l
∈rn×d(其中n是标签的种类,d是标签的词嵌入维度)和标签相关关系矩阵a∈rn×n,目标是在图上学习一个函数f(
·
,
·
),使新节点更新表示为h
l 1
∈rn×n,如gcn的每一层都表示为一个非线性激活函数为:h
l 1
=f(h
l
,a)。本发明的图卷积运算使用的f(
·
,
·
)表示形式为:h
l 1
=h(ah
lwl
),其中h(
·
)指的是非线性操作,本发明中使用leakyrelu激活函数,w
l
∈rd×d′
指的是要学习的转换矩阵。
34.进一步地,所述step1中的构建标签关系图矩阵具体包括如下:
35.采用数据挖掘的方式构建标签关系图矩阵输入到gcn模块,首先统计出所有疾病类别的总数,然后通过数据挖掘找出每种疾病发生的情况下其他疾病发生的数量,即以条件概率的形式构建关系矩阵,定义p(la|lb)表示la标签出现的情况下lb标签发生的概率,如可表示la为气胸(pneumothorax),lb为肺气肿(emphysema),假设气胸出现的情况下肺气肿出现的概率为0.3,而肺气肿出现的情况下气胸出现的概率为0.1。本发明所使用的医学影像数据集疾病标签类别为14种,所以最终构建的标签相关关系矩阵为14
×
14的二维矩阵。
36.进一步地,所述step2中的卷积神经网络(cnn)模块具体包括如下:
37.卷积神经网络(cnn)模块模型选取densenet网络模型,里面一共有四个模块,模块的命名方式为denseblock1到denseblock4,模块之间的区别在于每个块之间的卷积操作以及数量不同。每个denseblock块里面都包含了1*1和3*3的卷积核以及批归一化层,密集网络块之间还有进行下采样操作的过渡层,densenet-121共包含3个过渡层,为了能够顺利进行特征融合并且更好的得到纹路特征,本发明去掉densenet-121网络最后的全连接层,替换为最大池化层。
38.进一步地,所述step3中的融合方法具体包括如下:
39.本发明采取矩阵乘积的方式进行特征融合,如计算公式所示:式中表示总体特征,x为医学影像特征,y为疾病标签特征。然后将总体特征放入多标签分类损失函数求出loss,如计算公式所示:求出loss,如计算公式所示:其中δ(
·
)为sigmoid函数,c表示迭代次数。
40.进一步地,所述step4中的学生网络具体包括如下:
41.resnet18网络模型每两层就会出现一次残差学习,该网络模型分为五部分,分别是convolution1、conv2_x、conv3_x、conv4_x,conv5_x,最后连接了一个池化层。
42.进一步地,所述step5中的损失函数还包括如下:
43.为了使学生网络能够学习到soft target,使用知识蒸馏中的温度参数t来调节知识传递,定义softmax函数为:其中pi表示教师网络第i个输出的概率,xi、xj表示softmax的输入,t为温度系数,当温度增加,softmax的输出分布越来越平缓,信息熵会越来越大,学生网络能够更多的关注到负标签;为了学生网络能够更好地拟合教师网络的分类结果,定义总体损失函数为:loss=(1-a)h(label,y) αh(p,y)t2,其中α表示权重系数,h表示的是交叉熵,label为真实标签结果,y为学生网络标签结果,p表示教师网络总体
概率。
44.进一步地,为了验证本发明的效果,由上述步骤训练好的模型输入由某国国立卫生研究院(national institutes of health,nih)整理并公开的大型多标签胸部x射线数据集chestx-ray14对模型进行测验。该试验的环境配置为gpu:nvidia rtx2080;内存:11gb;操作系统:ubuntu 18.04.3;机器学习框架:pytorch。
45.本模型基于图卷积神经网络的胸部放射影像疾病分类模型作为教师网络(teacher model),教师网络由图卷积神经网络(gcn)与卷积神经网络(cnn)模块构成,把运算速度快、内存占用率低的resnet18网络作为学生网络(student model),将教师网络和学生网络进行联合训练,用损失函数对教师和学生网络进行回归和分类,使用教师网络的损失指标指导学生网络的损失指标。预测结果为了客观的对实验进行评价,使用areaundertheroccurves(auc)作为评价指标。
46.利用知识蒸馏方法对胸部dr放射影像进行疾病分类,使用ml chest-gcn作为教师网络,测试生成的resnet18学生网络准确率,并且与教师网络做对照,教师网络和学生网络实验结果均为单独测试得出。实验结果对照如表1所示:
47.表1教师网络和学生网络对照实验结果
[0048][0049]
从表1结果表明,resnet18的疾病分类学生网络在教师网络ml-gcn的指导下,能够将auc平均值提高到0.817的成绩。虽然相较于教师网络下降了三个百分点,不过依旧超过了wang等人的0.738,yao等人的0.803。实验结果证明了知识蒸馏方法的确指导了学生网络进行学习。
[0050]
为了评估训练好的学生网络是否能够提升效率,接下来进行了学生网络的运行速度,以及内存占用情况的实验,运行速度以张数每秒为单位,将其与教师网络ml-gcn做对比,设置相同的实验环境,实验结果如表2所示:
[0051]
表2模型效率对照表
[0052][0053]
从表2结果表明,在同一个实验环境下,使用学生网络来做预测比使用教师网络降低了百分之35的内存占用,并且提高了百分之34的运行速度。
[0054]
进一步的本发明还进行了消融实验,单独训练resnet18网络模型,完整移除了知
识蒸馏环节,不加入教师网络指导,实验结果如表3所示中resnet18所示:
[0055]
表3为消融实验结果
[0056][0057]
如表3所示不使用知识蒸馏的情况下平均auc为0.756,使用教师网络指导后为0.817,提升了6个百分点,证明了知识蒸馏确实是有用的。
[0058]
上面结合附图对本发明的具体实施方式作了详细说明,但是本发明并不限于上述实施方式,在本领域普通技术人员所具备的知识范围内,还可以在不离本发明宗旨的前提下作出各种变化。
再多了解一些

本文用于创业者技术爱好者查询,仅供学习研究,如用于商业用途,请联系技术所有人。

发表评论 共有条评论
用户名: 密码:
验证码: 匿名发表

相关文献