Magicode logo
Magicode
3

ECGでの心拍分類+転移学習

2 min read

論文を基に1-D CNN・転移学習・ECGデータ処理法を実装法を中心に学ぶ。

Abstract

ECG Heartbeat Classification: A Deep Transferable Representation

従来のECG解析に関する機械学習メソッドは異なるタスクごとに独立したものであったが、この論文ではタスクごとに知識を再利用できないかを検証することを目的としている。

まずMIT-BIHデータセットで不整脈分類タスク(5クラス)について学習させる。その後得られた学習済みNN(転移学習)に対してPTBデータセットで心筋梗塞2値分類タスクについて学習させた。

結果、平均Accuracyは不整脈分類タスクでは 93.4% で心筋梗塞2値分類タスクでは 95.9%と高精度な予測器を生成することができた。この結果から不整脈分類タスクの知識をうまく心筋梗塞2値分類タスクに転移させたことができたと主張している。

Method & Implementation

Preprocessing

以下のような手順でECG波形の前処理を行ってます

  1. ECGデータを10秒間ごとのWindowに分けてそのうちの一つのWindowをとってくる
  2. 0~1に正規化する
  3. 極大値をすべて見つける
  4. 0.9以上の極大値をR-peakとする
  5. RR間隔の平均値を求め、その間隔をWindow長のTとする
  6. それぞれのR-peakから1.2Tだけデータをとる
  7. あらかじめ指定したデータ長に満たないデータを0で埋める(Zero-padding)

これら前処理済みデータセットはすでにKaggleに公開されています。

https://www.kaggle.com/shayanfazeli/heartbeat

実際のデータに学習モデルを適応するためには入力データの形式を学習時のそれと合わせる必要があります。実際に得られるECGデータはサンプリング周波数も異なるので上の前処理に加えてそこも調節する必要があります。

学習データセットは125Hzなのでまずはその周波数に従ってリサンプリングしましょう。

チュートリアル用に360HzでサンプリングされたScipyのECGデータを使います。


png

まずは125Hzでリサンプリングします。


png

ほぼ一致していることが分かります。それでは1.からみていきましょう。

  1. ECGデータを10秒間ごとのWindowに分けてそのうちの一つのWindowをとってくる

png

  1. 0~1に正規化する

png

  1. 極大値をすべて見つける
  2. 0.9以上の極大値をR-peakとする

png

このデータではThreshold=0.9というのはあまりよくないようですが、論文の通りにいきましょう。

  1. RR間隔の平均値を求め、その間隔をWindow長のTとする

  1. それぞれのR-peakから1.2Tだけデータをとる
  2. あらかじめ指定したデータ長に満たないデータを0で埋める(Zero-padding)

Shape of the extracted beats data: (4, 187)


png


png

上手く心拍を抽出できていることがわかります。

最後に以上の処理をpreprocess関数にまとめてみましょう。


Shape of the extracted beats data: (104, 187)

png

ほんの一部の心拍が抽出されていることが分かります。この論文大丈夫か心配になってきましたね。

Model

論文のモデルは以下

https://github.com/CVxTz/ECG_Heartbeat_Classification

論文にのってるモデルを少し変えたやつ(Residual blockなしバージョン)

1D-Convolution layer

"all convolution layers are applying 1-D convolution through time and each have 32 kernels of size 5"

  • カーネル:入力にかける行列のこと、今回は1次元。32 Kernalsはカーネルの層数を意味するので**Kerasだったらfilters = 32**にあたる。
  • サイズ:カーネルのWindow長。今回は一次元。size 5は**Kerasだったらkernel_size = 5**にあたる。

二次元畳み込み層よりもパラメータ数はもちろん少ない。今回は入力が一次元なので一次元畳み込み層で自然。

畳み込み層ついて詳しくは https://towardsdatascience.com/types-of-convolution-kernels-simplified-f040cb307c37

Dataset

https://www.kaggle.com/shayanfazeli/heartbeat のデータセットを使います。前処理済み最高長187の心拍がCSVファイルで格納されています。188番目のカラムにはその心拍のラベル(心室期外収縮や心筋梗塞など)が整数クラスで入ってます。

MITBIHのAnnotationは以下のようになってます。

N,S,V,F,Qはそれぞれ0,1,2,3,4クラスに対応しています。実際にデータセットを見てみましょう。


Data shape: (87554, 188) All classes (shown in 188th column): [0. 1. 2. 3. 4.]


png

論文記載のアルゴリズムに従って前処理されていることが分かります。

Training the Arrhythmia Classifier

MITBIHデータセットでまずは学習します。github借りパくです。


Model: "functional_1" _________________________________________________________________ Layer (type) Output Shape Param #
================================================================= input_1 (InputLayer) [(None, 187, 1)] 0
_________________________________________________________________ conv1d (Conv1D) (None, 183, 16) 96
_________________________________________________________________ conv1d_1 (Conv1D) (None, 179, 16) 1296
_________________________________________________________________ max_pooling1d (MaxPooling1D) (None, 89, 16) 0
_________________________________________________________________ dropout (Dropout) (None, 89, 16) 0
_________________________________________________________________ conv1d_2 (Conv1D) (None, 87, 32) 1568
_________________________________________________________________ conv1d_3 (Conv1D) (None, 85, 32) 3104
_________________________________________________________________ max_pooling1d_1 (MaxPooling1 (None, 42, 32) 0
_________________________________________________________________ dropout_1 (Dropout) (None, 42, 32) 0
_________________________________________________________________ conv1d_4 (Conv1D) (None, 40, 32) 3104
_________________________________________________________________ conv1d_5 (Conv1D) (None, 38, 32) 3104
_________________________________________________________________ max_pooling1d_2 (MaxPooling1 (None, 19, 32) 0
_________________________________________________________________ dropout_2 (Dropout) (None, 19, 32) 0
_________________________________________________________________ conv1d_6 (Conv1D) (None, 17, 256) 24832
_________________________________________________________________ conv1d_7 (Conv1D) (None, 15, 256) 196864
_________________________________________________________________ global_max_pooling1d (Global (None, 256) 0
_________________________________________________________________ dropout_3 (Dropout) (None, 256) 0
_________________________________________________________________ dense_1 (Dense) (None, 64) 16448
_________________________________________________________________ dense_2 (Dense) (None, 64) 4160
_________________________________________________________________ dense_3_mitbih (Dense) (None, 5) 325
================================================================= Total params: 254,901 Trainable params: 254,901 Non-trainable params: 0 _________________________________________________________________ Epoch 1/1000

Epoch 00001: val_acc improved from -inf to 0.92005, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 11s - loss: 0.3817 - acc: 0.8841 - val_loss: 0.2877 - val_acc: 0.9201
Epoch 2/1000

Epoch 00002: val_acc improved from 0.92005 to 0.95420, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 11s - loss: 0.2364 - acc: 0.9326 - val_loss: 0.1660 - val_acc: 0.9542
Epoch 3/1000

Epoch 00003: val_acc improved from 0.95420 to 0.96551, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 11s - loss: 0.1691 - acc: 0.9537 - val_loss: 0.1357 - val_acc: 0.9655
Epoch 4/1000

Epoch 00004: val_acc improved from 0.96551 to 0.96962, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 11s - loss: 0.1409 - acc: 0.9620 - val_loss: 0.1146 - val_acc: 0.9696
Epoch 5/1000

Epoch 00005: val_acc improved from 0.96962 to 0.97270, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 11s - loss: 0.1234 - acc: 0.9672 - val_loss: 0.1008 - val_acc: 0.9727
Epoch 6/1000

Epoch 00006: val_acc improved from 0.97270 to 0.97487, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 12s - loss: 0.1114 - acc: 0.9697 - val_loss: 0.0901 - val_acc: 0.9749
Epoch 7/1000

Epoch 00007: val_acc improved from 0.97487 to 0.97602, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 11s - loss: 0.1013 - acc: 0.9723 - val_loss: 0.0847 - val_acc: 0.9760
Epoch 8/1000

Epoch 00008: val_acc did not improve from 0.97602
2463/2463 - 10s - loss: 0.0945 - acc: 0.9741 - val_loss: 0.0886 - val_acc: 0.9751
Epoch 9/1000

Epoch 00009: val_acc improved from 0.97602 to 0.97796, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 12s - loss: 0.0894 - acc: 0.9752 - val_loss: 0.0814 - val_acc: 0.9780
Epoch 10/1000

Epoch 00010: val_acc improved from 0.97796 to 0.97967, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 11s - loss: 0.0840 - acc: 0.9771 - val_loss: 0.0723 - val_acc: 0.9797
Epoch 11/1000

Epoch 00011: val_acc did not improve from 0.97967
2463/2463 - 10s - loss: 0.0795 - acc: 0.9775 - val_loss: 0.0743 - val_acc: 0.9788
Epoch 12/1000

Epoch 00012: val_acc improved from 0.97967 to 0.98184, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 11s - loss: 0.0772 - acc: 0.9776 - val_loss: 0.0666 - val_acc: 0.9818
Epoch 13/1000

Epoch 00013: val_acc did not improve from 0.98184
2463/2463 - 10s - loss: 0.0741 - acc: 0.9790 - val_loss: 0.0649 - val_acc: 0.9814
Epoch 14/1000

Epoch 00014: val_acc did not improve from 0.98184
2463/2463 - 11s - loss: 0.0702 - acc: 0.9798 - val_loss: 0.0660 - val_acc: 0.9802
Epoch 15/1000

Epoch 00015: val_acc improved from 0.98184 to 0.98241, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 12s - loss: 0.0711 - acc: 0.9795 - val_loss: 0.0708 - val_acc: 0.9824
Epoch 16/1000

Epoch 00016: val_acc did not improve from 0.98241
2463/2463 - 10s - loss: 0.0679 - acc: 0.9804 - val_loss: 0.0646 - val_acc: 0.9823
Epoch 17/1000

Epoch 00017: val_acc did not improve from 0.98241
2463/2463 - 10s - loss: 0.0656 - acc: 0.9808 - val_loss: 0.0651 - val_acc: 0.9812
Epoch 18/1000

Epoch 00018: val_acc improved from 0.98241 to 0.98344, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 12s - loss: 0.0641 - acc: 0.9813 - val_loss: 0.0602 - val_acc: 0.9834
Epoch 19/1000

Epoch 00019: val_acc did not improve from 0.98344
2463/2463 - 10s - loss: 0.0615 - acc: 0.9820 - val_loss: 0.0659 - val_acc: 0.9814
Epoch 20/1000

Epoch 00020: val_acc did not improve from 0.98344
2463/2463 - 10s - loss: 0.0615 - acc: 0.9820 - val_loss: 0.0592 - val_acc: 0.9828
Epoch 21/1000

Epoch 00021: val_acc did not improve from 0.98344

Epoch 00021: ReduceLROnPlateau reducing learning rate to 0.00010000000474974513.
2463/2463 - 11s - loss: 0.0599 - acc: 0.9828 - val_loss: 0.0545 - val_acc: 0.9831
Epoch 22/1000

Epoch 00022: val_acc improved from 0.98344 to 0.98561, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 12s - loss: 0.0449 - acc: 0.9867 - val_loss: 0.0477 - val_acc: 0.9856
Epoch 23/1000

Epoch 00023: val_acc improved from 0.98561 to 0.98584, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 11s - loss: 0.0408 - acc: 0.9879 - val_loss: 0.0453 - val_acc: 0.9858
Epoch 24/1000

Epoch 00024: val_acc improved from 0.98584 to 0.98595, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 12s - loss: 0.0390 - acc: 0.9882 - val_loss: 0.0439 - val_acc: 0.9860
Epoch 25/1000

Epoch 00025: val_acc improved from 0.98595 to 0.98721, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 12s - loss: 0.0358 - acc: 0.9891 - val_loss: 0.0426 - val_acc: 0.9872
Epoch 26/1000

Epoch 00026: val_acc did not improve from 0.98721
2463/2463 - 10s - loss: 0.0347 - acc: 0.9897 - val_loss: 0.0430 - val_acc: 0.9864
Epoch 27/1000

Epoch 00027: val_acc did not improve from 0.98721
2463/2463 - 10s - loss: 0.0351 - acc: 0.9889 - val_loss: 0.0438 - val_acc: 0.9869
Epoch 28/1000

Epoch 00028: val_acc did not improve from 0.98721

Epoch 00028: ReduceLROnPlateau reducing learning rate to 1.0000000474974514e-05.
2463/2463 - 11s - loss: 0.0338 - acc: 0.9897 - val_loss: 0.0425 - val_acc: 0.9862
Epoch 29/1000

Epoch 00029: val_acc did not improve from 0.98721
2463/2463 - 10s - loss: 0.0324 - acc: 0.9896 - val_loss: 0.0421 - val_acc: 0.9862
Epoch 30/1000

Epoch 00030: val_acc did not improve from 0.98721
2463/2463 - 10s - loss: 0.0324 - acc: 0.9897 - val_loss: 0.0422 - val_acc: 0.9866
Epoch 00030: early stopping
Test f1 score : 0.9158830356755775 
Test accuracy score : 0.9850630367257446 

結果

  • Test f1 score : 0.9158830356755775
  • Test accuracy score : 0.9850630367257446

Training the MI Predictor

先ほどのMITBIHデータセットで得られたNNを利用して心筋梗塞2値分類タスクについてPTBDBデータセットで学習します。

論文では不整脈分類タスクのNNの最後の2層のみFine-tuningしてましたが、今回使うGithubのほうでは最後の2層以外の重みを固定するということはしないで、一緒に学習しなおすということをして実際最終2層より以前の重みをフリーズして学習するよりもスコアが良かったのでそちらを紹介します。


Discussion

コメントにはログインが必要です。