Magicode logo
Magicode
1
3 min read

アヤメの分類

はじめに

機械学習の初学の定番?のアヤメの分類をやってみました。
magicodeはipynbファイルをアップロードしてそのまま記事の載せられるようなので、試したかったというのが大きいのであまりアヤメの分類自体をやりこんではいません。
ライブラリのinstallやimportを必要になったタイミングで都度実施していてnotebookの作法的にもあまりよくないかもしれません。

sklearnインストール

今回使用するデータとモデルはscikit-learnのものを使うのでskleranをインストールします。 notebook上でpythonライブラリをinstallをするときは’!’をつけてpipコマンドを書くようです。
インストール済の場合は下記のようにRequirement already satisfiedと出ますが、なければインストールされます。

!pip install sklearn

Requirement already satisfied: sklearn in c:\users\mnbi\documents\magicode\.venv\lib\site-packages (0.0) Requirement already satisfied: scikit-learn in c:\users\mnbi\documents\magicode\.venv\lib\site-packages (from sklearn) (1.1.1) Requirement already satisfied: scipy>=1.3.2 in c:\users\mnbi\documents\magicode\.venv\lib\site-packages (from scikit-learn->sklearn) (1.8.1) Requirement already satisfied: numpy>=1.17.3 in c:\users\mnbi\documents\magicode\.venv\lib\site-packages (from scikit-learn->sklearn) (1.22.4) Requirement already satisfied: threadpoolctl>=2.0.0 in c:\users\mnbi\documents\magicode\.venv\lib\site-packages (from scikit-learn->sklearn) (3.1.0) Requirement already satisfied: joblib>=1.0.0 in c:\users\mnbi\documents\magicode\.venv\lib\site-packages (from scikit-learn->sklearn) (1.1.0)
WARNING: You are using pip version 22.0.4; however, version 22.1.2 is available. You should consider upgrading via the 'C:\Users\mnbi\Documents\magicode\.venv\Scripts\python.exe -m pip install --upgrade pip' command.

データ読み込み、確認

skleranから今回学習に使用するアヤメのデータを読み込みます。
読み込んだら、特徴量(sepal length,sepal width,petal length,petal width)とラベル(0,1,2の3種類)を確認します。
print(iris.target)の結果は表示されませんでした。magicodeの仕様でしょうか。。。

しっかりprintされました。

#scikit-learnからデータの読み込み
from sklearn import datasets
iris = datasets.load_iris()

# アヤメの分類に使用するデータの確認
print(iris.data)
print(iris.target)

[[5.1 3.5 1.4 0.2] [4.9 3. 1.4 0.2] [4.7 3.2 1.3 0.2] [4.6 3.1 1.5 0.2] [5. 3.6 1.4 0.2] [5.4 3.9 1.7 0.4] [4.6 3.4 1.4 0.3] [5. 3.4 1.5 0.2] [4.4 2.9 1.4 0.2] [4.9 3.1 1.5 0.1] [5.4 3.7 1.5 0.2] [4.8 3.4 1.6 0.2] [4.8 3. 1.4 0.1] [4.3 3. 1.1 0.1] [5.8 4. 1.2 0.2] [5.7 4.4 1.5 0.4] [5.4 3.9 1.3 0.4] [5.1 3.5 1.4 0.3] [5.7 3.8 1.7 0.3] [5.1 3.8 1.5 0.3] [5.4 3.4 1.7 0.2] [5.1 3.7 1.5 0.4] [4.6 3.6 1. 0.2] [5.1 3.3 1.7 0.5] [4.8 3.4 1.9 0.2] [5. 3. 1.6 0.2] [5. 3.4 1.6 0.4] [5.2 3.5 1.5 0.2] [5.2 3.4 1.4 0.2] [4.7 3.2 1.6 0.2] [4.8 3.1 1.6 0.2] [5.4 3.4 1.5 0.4] [5.2 4.1 1.5 0.1] [5.5 4.2 1.4 0.2] [4.9 3.1 1.5 0.2] [5. 3.2 1.2 0.2] [5.5 3.5 1.3 0.2] [4.9 3.6 1.4 0.1] [4.4 3. 1.3 0.2] [5.1 3.4 1.5 0.2] [5. 3.5 1.3 0.3] [4.5 2.3 1.3 0.3] [4.4 3.2 1.3 0.2] [5. 3.5 1.6 0.6] [5.1 3.8 1.9 0.4] [4.8 3. 1.4 0.3] [5.1 3.8 1.6 0.2] [4.6 3.2 1.4 0.2] [5.3 3.7 1.5 0.2] [5. 3.3 1.4 0.2] [7. 3.2 4.7 1.4] [6.4 3.2 4.5 1.5] [6.9 3.1 4.9 1.5] [5.5 2.3 4. 1.3] [6.5 2.8 4.6 1.5] [5.7 2.8 4.5 1.3] [6.3 3.3 4.7 1.6] [4.9 2.4 3.3 1. ] [6.6 2.9 4.6 1.3] [5.2 2.7 3.9 1.4] [5. 2. 3.5 1. ] [5.9 3. 4.2 1.5] [6. 2.2 4. 1. ] [6.1 2.9 4.7 1.4] [5.6 2.9 3.6 1.3] [6.7 3.1 4.4 1.4] [5.6 3. 4.5 1.5] [5.8 2.7 4.1 1. ] [6.2 2.2 4.5 1.5] [5.6 2.5 3.9 1.1] [5.9 3.2 4.8 1.8] [6.1 2.8 4. 1.3] [6.3 2.5 4.9 1.5] [6.1 2.8 4.7 1.2] [6.4 2.9 4.3 1.3] [6.6 3. 4.4 1.4] [6.8 2.8 4.8 1.4] [6.7 3. 5. 1.7] [6. 2.9 4.5 1.5] [5.7 2.6 3.5 1. ] [5.5 2.4 3.8 1.1] [5.5 2.4 3.7 1. ] [5.8 2.7 3.9 1.2] [6. 2.7 5.1 1.6] [5.4 3. 4.5 1.5] [6. 3.4 4.5 1.6] [6.7 3.1 4.7 1.5] [6.3 2.3 4.4 1.3] [5.6 3. 4.1 1.3] [5.5 2.5 4. 1.3] [5.5 2.6 4.4 1.2] [6.1 3. 4.6 1.4] [5.8 2.6 4. 1.2] [5. 2.3 3.3 1. ] [5.6 2.7 4.2 1.3] [5.7 3. 4.2 1.2] [5.7 2.9 4.2 1.3] [6.2 2.9 4.3 1.3] [5.1 2.5 3. 1.1] [5.7 2.8 4.1 1.3] [6.3 3.3 6. 2.5] [5.8 2.7 5.1 1.9] [7.1 3. 5.9 2.1] [6.3 2.9 5.6 1.8] [6.5 3. 5.8 2.2] [7.6 3. 6.6 2.1] [4.9 2.5 4.5 1.7] [7.3 2.9 6.3 1.8] [6.7 2.5 5.8 1.8] [7.2 3.6 6.1 2.5] [6.5 3.2 5.1 2. ] [6.4 2.7 5.3 1.9] [6.8 3. 5.5 2.1] [5.7 2.5 5. 2. ] [5.8 2.8 5.1 2.4] [6.4 3.2 5.3 2.3] [6.5 3. 5.5 1.8] [7.7 3.8 6.7 2.2] [7.7 2.6 6.9 2.3] [6. 2.2 5. 1.5] [6.9 3.2 5.7 2.3] [5.6 2.8 4.9 2. ] [7.7 2.8 6.7 2. ] [6.3 2.7 4.9 1.8] [6.7 3.3 5.7 2.1] [7.2 3.2 6. 1.8] [6.2 2.8 4.8 1.8] [6.1 3. 4.9 1.8] [6.4 2.8 5.6 2.1] [7.2 3. 5.8 1.6] [7.4 2.8 6.1 1.9] [7.9 3.8 6.4 2. ] [6.4 2.8 5.6 2.2] [6.3 2.8 5.1 1.5] [6.1 2.6 5.6 1.4] [7.7 3. 6.1 2.3] [6.3 3.4 5.6 2.4] [6.4 3.1 5.5 1.8] [6. 3. 4.8 1.8] [6.9 3.1 5.4 2.1] [6.7 3.1 5.6 2.4] [6.9 3.1 5.1 2.3] [5.8 2.7 5.1 1.9] [6.8 3.2 5.9 2.3] [6.7 3.3 5.7 2.5] [6.7 3. 5.2 2.3] [6.3 2.5 5. 1.9] [6.5 3. 5.2 2. ] [6.2 3.4 5.4 2.3] [5.9 3. 5.1 1.8]] [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]

学習実行

早速学習を実行します。
モデルはGradientBoostingClassifierを選択しました。
記事編集中に実行ボタン(▶)を押したらError出ちゃいましたが直し方わからずそのままです。

# アヤメの分類の学習

# 学習用データと検証用データに分割
from sklearn.model_selection import train_test_split as split
x_train, x_test, y_train, y_test = split(iris.data,iris.target,train_size=0.8,test_size=0.2)

# model作成
from sklearn.ensemble import GradientBoostingClassifier
model = GradientBoostingClassifier()

# 学習実行
model.fit(x_train, y_train)

GradientBoostingClassifier(ccp_alpha=0.0, criterion='friedman_mse', init=None, learning_rate=0.1, loss='deviance', max_depth=3, max_features=None, max_leaf_nodes=None, min_impurity_decrease=0.0, min_impurity_split=None, min_samples_leaf=1, min_samples_split=2, min_weight_fraction_leaf=0.0, n_estimators=100, n_iter_no_change=None, presort='deprecated', random_state=None, subsample=1.0, tol=0.0001, validation_fraction=0.1, verbose=0, warm_start=False)

モデルの評価

学習が終わったので、test用のデータを使って予測してみます。
予測実行後、正解率と混同行列で結果を確認しました。

# modelの評価
from sklearn.metrics import accuracy_score, confusion_matrix, ConfusionMatrixDisplay
# 評価の実行
y_pred = model.predict(x_test)

# 正解率
Accuracy = accuracy_score(y_test, y_pred)
print('正解率:', Accuracy)

# 
cm = confusion_matrix(y_test, y_pred)

正解率: 0.9666666666666667 [[10 0 0] [ 0 13 0] [ 0 1 6]]

分類問題でよく見るビジュアルでの表示も試します。 matplotlibを追加でインストールしました。

!pip install matplotlib
import matplotlib.pyplot as plt

disp = ConfusionMatrixDisplay(cm)
disp.plot()
plt.show()

Requirement already satisfied: matplotlib in c:\users\mnbi\documents\magicode\.venv\lib\site-packages (3.5.2) Requirement already satisfied: pillow>=6.2.0 in c:\users\mnbi\documents\magicode\.venv\lib\site-packages (from matplotlib) (9.1.1) Requirement already satisfied: fonttools>=4.22.0 in c:\users\mnbi\documents\magicode\.venv\lib\site-packages (from matplotlib) (4.33.3) Requirement already satisfied: cycler>=0.10 in c:\users\mnbi\documents\magicode\.venv\lib\site-packages (from matplotlib) (0.11.0) Requirement already satisfied: numpy>=1.17 in c:\users\mnbi\documents\magicode\.venv\lib\site-packages (from matplotlib) (1.22.4) Requirement already satisfied: pyparsing>=2.2.1 in c:\users\mnbi\documents\magicode\.venv\lib\site-packages (from matplotlib) (3.0.9) Requirement already satisfied: packaging>=20.0 in c:\users\mnbi\documents\magicode\.venv\lib\site-packages (from matplotlib) (21.3) Requirement already satisfied: kiwisolver>=1.0.1 in c:\users\mnbi\documents\magicode\.venv\lib\site-packages (from matplotlib) (1.4.2) Requirement already satisfied: python-dateutil>=2.7 in c:\users\mnbi\documents\magicode\.venv\lib\site-packages (from matplotlib) (2.8.2) Requirement already satisfied: six>=1.5 in c:\users\mnbi\documents\magicode\.venv\lib\site-packages (from python-dateutil>=2.7->matplotlib) (1.16.0)
WARNING: You are using pip version 22.0.4; however, version 22.1.2 is available. You should consider upgrading via the 'C:\Users\mnbi\Documents\magicode\.venv\Scripts\python.exe -m pip install --upgrade pip' command.
<Figure size 432x288 with 2 Axes>

Discussion

コメントにはログインが必要です。