Magicode logo
Magicode
0
12 min read

数行のコードで機械学習ができる PyCaret のチュートリアルを試してみた 回帰分析 その1

前回の記事に引き続き、 PyCaret のチュートリアルを試してみたいと思います。 今回は回帰分析の Regression Tutorial (REG101) - Level Beginnerに沿って試してみたいと思います。
日本語版もありましたので、こちらから引用しつつ実行していきたいと思います。

PyCaret とは

PyCaret is an open-source, low-code machine learning library in Python that automates machine learning workflows.

https://pycaret.org/

と公式に記載あるように、わずかなコード量で実装できる Python の機械学習のライブラリです。

回帰分析のチュートリアル

回帰分析とは

回帰分析とは、予測したいターゲットを複数の特徴量・予測因子から推定するための分析手法です。
機械学習における回帰の目的は、売上高、数量、温度などの連続値を予測することです。

このチュートリアルでは、以下のことを学びます。

  • Getting Data: PyCaret リポジトリからデータをインポートする方法。
  • Setting up Environment: PyCaretで実験を行い、回帰モデルの構築を開始する方法
  • Create Model: モデルを作成し、クロスバリデーションを行い、回帰メトリクスを評価する方法
  • Tune Model: 回帰モデルのハイパーパラメータを自動的にチューニングする方法
  • Plot Model: 様々なプロットを使用してモデルのパフォーマンスを分析する方法
  • Finalize Model: 実験の最後に最適なモデルを最終決定する方法
  • Predict Model: 新しい/未見のデータに対して予測を行う方法
  • Save / Load Model: 今後の使用のためにモデルを保存/ロードする方法

それでは、チュートリアルに沿って進めていきたいと思います。
Magicode 上で進めますが、 一部のグラフ表示ができない(plolty が magicode で動作する方法がわからない)ので、手元で試す場合には、 colab など別の環境での実施をおすすめします。

PyCaret のインストール

!pip install pycaret

Collecting pycaret
Downloading pycaret-2.3.10-py3-none-any.whl (320 kB) [?25l |█ | 10 kB 17.1 MB/s eta 0:00:01 |██ | 20 kB 13.4 MB/s eta 0:00:01 |███ | 30 kB 11.2 MB/s eta 0:00:01 |████ | 40 kB 13.0 MB/s eta 0:00:01 |█████▏ | 51 kB 7.0 MB/s eta 0:00:01 |██████▏ | 61 kB 8.1 MB/s eta 0:00:01
|███████▏ | 71 kB 7.3 MB/s eta 0:00:01 |████████▏ | 81 kB 6.6 MB/s eta 0:00:01 |█████████▏ | 92 kB 7.3 MB/s eta 0:00:01 |██████████▎ | 102 kB 6.6 MB/s eta 0:00:01 |███████████▎ | 112 kB 6.6 MB/s eta 0:00:01 |████████████▎ | 122 kB 6.6 MB/s eta 0:00:01 |█████████████▎ | 133 kB 6.6 MB/s eta 0:00:01 |██████████████▎ | 143 kB 6.6 MB/s eta 0:00:01 |███████████████▍ | 153 kB 6.6 MB/s eta 0:00:01 |████████████████▍ | 163 kB 6.6 MB/s eta 0:00:01 |█████████████████▍ | 174 kB 6.6 MB/s eta 0:00:01 |██████████████████▍ | 184 kB 6.6 MB/s eta 0:00:01 |███████████████████▍ | 194 kB 6.6 MB/s eta 0:00:01 |████████████████████▌ | 204 kB 6.6 MB/s eta 0:00:01 |█████████████████████▌ | 215 kB 6.6 MB/s eta 0:00:01 |██████████████████████▌ | 225 kB 6.6 MB/s eta 0:00:01 |███████████████████████▌ | 235 kB 6.6 MB/s eta 0:00:01 |████████████████████████▋ | 245 kB 6.6 MB/s eta 0:00:01 |█████████████████████████▋ | 256 kB 6.6 MB/s eta 0:00:01 |██████████████████████████▋ | 266 kB 6.6 MB/s eta 0:00:01 |███████████████████████████▋ | 276 kB 6.6 MB/s eta 0:00:01 |████████████████████████████▋ | 286 kB 6.6 MB/s eta 0:00:01 |█████████████████████████████▊ | 296 kB 6.6 MB/s eta 0:00:01 |██████████████████████████████▊ | 307 kB 6.6 MB/s eta 0:00:01 |███████████████████████████████▊| 317 kB 6.6 MB/s eta 0:00:01 |████████████████████████████████| 320 kB 6.6 MB/s [?25h
Requirement already satisfied: IPython in /srv/conda/envs/notebook/lib/python3.7/site-packages (from pycaret) (7.31.1)
Collecting scikit-learn==0.23.2 Downloading scikit_learn-0.23.2-cp37-cp37m-manylinux1_x86_64.whl (6.8 MB) [?25l | | 10 kB 33.7 MB/s eta 0:00:01 | | 20 kB 41.4 MB/s eta 0:00:01 |▏ | 30 kB 50.3 MB/s eta 0:00:01 |▏ | 40 kB 55.8 MB/s eta 0:00:01 |▎ | 51 kB 59.6 MB/s eta 0:00:01 |▎ | 61 kB 64.4 MB/s eta 0:00:01 |▍ | 71 kB 33.1 MB/s eta 0:00:01 |▍ | 81 kB 36.1 MB/s eta 0:00:01 |▍ | 92 kB 36.3 MB/s eta 0:00:01 |▌ | 102 kB 35.2 MB/s eta 0:00:01 |▌ | 112 kB 35.2 MB/s eta 0:00:01 |▋ | 122 kB 35.2 MB/s eta 0:00:01 |▋ | 133 kB 35.2 MB/s eta 0:00:01 |▊ | 143 kB 35.2 MB/s eta 0:00:01 |▊ | 153 kB 35.2 MB/s eta 0:00:01 |▊ | 163 kB 35.2 MB/s eta 0:00:01 |▉ | 174 kB 35.2 MB/s eta 0:00:01 |▉ | 184 kB 35.2 MB/s eta 0:00:01 |█ | 194 kB 35.2 MB/s eta 0:00:01 |█ | 204 kB 35.2 MB/s eta 0:00:01 |█ | 215 kB 35.2 MB/s eta 0:00:01 |█ | 225 kB 35.2 MB/s eta 0:00:01 |█ | 235 kB 35.2 MB/s eta 0:00:01 |█▏ | 245 kB 35.2 MB/s eta 0:00:01 |█▏ | 256 kB 35.2 MB/s eta 0:00:01 |█▎ | 266 kB 35.2 MB/s eta 0:00:01 |█▎ | 276 kB 35.2 MB/s eta 0:00:01 |█▍ | 286 kB 35.2 MB/s eta 0:00:01 |█▍ | 296 kB 35.2 MB/s eta 0:00:01 |█▍ | 307 kB 35.2 MB/s eta 0:00:01 |█▌ | 317 kB 35.2 MB/s eta 0:00:01 |█▌ | 327 kB 35.2 MB/s eta 0:00:01 |█▋ | 337 kB 35.2 MB/s eta 0:00:01 |█▋ | 348 kB 35.2 MB/s eta 0:00:01 |█▊ | 358 kB 35.2 MB/s eta 0:00:01 |█▊ | 368 kB 35.2 MB/s eta 0:00:01 |█▊ | 378 kB 35.2 MB/s eta 0:00:01 |█▉ | 389 kB 35.2 MB/s eta 0:00:01 |█▉ | 399 kB 35.2 MB/s eta 0:00:01 |██ | 409 kB 35.2 MB/s eta 0:00:01 |██ | 419 kB 35.2 MB/s eta 0:00:01 |██ | 430 kB 35.2 MB/s eta 0:00:01 |██ | 440 kB 35.2 MB/s eta 0:00:01 |██▏ | 450 kB 35.2 MB/s eta 0:00:01 |██▏ | 460 kB 35.2 MB/s eta 0:00:01 |██▏ | 471 kB 35.2 MB/s eta 0:00:01 |██▎ | 481 kB 35.2 MB/s eta 0:00:01 |██▎ | 491 kB 35.2 MB/s eta 0:00:01 |██▍ | 501 kB 35.2 MB/s eta 0:00:01 |██▍ | 512 kB 35.2 MB/s eta 0:00:01 |██▌ | 522 kB 35.2 MB/s eta 0:00:01 |██▌ | 532 kB 35.2 MB/s eta 0:00:01
|██▌ | 542 kB 35.2 MB/s eta 0:00:01 |██▋ | 552 kB 35.2 MB/s eta 0:00:01 |██▋ | 563 kB 35.2 MB/s eta 0:00:01 |██▊ | 573 kB 35.2 MB/s eta 0:00:01 |██▊ | 583 kB 35.2 MB/s eta 0:00:01 |██▉ | 593 kB 35.2 MB/s eta 0:00:01 |██▉ | 604 kB 35.2 MB/s eta 0:00:01 |██▉ | 614 kB 35.2 MB/s eta 0:00:01 |███ | 624 kB 35.2 MB/s eta 0:00:01 |███ | 634 kB 35.2 MB/s eta 0:00:01 |███ | 645 kB 35.2 MB/s eta 0:00:01 |███ | 655 kB 35.2 MB/s eta 0:00:01 |███▏ | 665 kB 35.2 MB/s eta 0:00:01 |███▏ | 675 kB 35.2 MB/s eta 0:00:01 |███▏ | 686 kB 35.2 MB/s eta 0:00:01 |███▎ | 696 kB 35.2 MB/s eta 0:00:01 |███▎ | 706 kB 35.2 MB/s eta 0:00:01 |███▍ | 716 kB 35.2 MB/s eta 0:00:01 |███▍ | 727 kB 35.2 MB/s eta 0:00:01 |███▌ | 737 kB 35.2 MB/s eta 0:00:01 |███▌ | 747 kB 35.2 MB/s eta 0:00:01 |███▌ | 757 kB 35.2 MB/s eta 0:00:01 |███▋ | 768 kB 35.2 MB/s eta 0:00:01 |███▋ | 778 kB 35.2 MB/s eta 0:00:01 |███▊ | 788 kB 35.2 MB/s eta 0:00:01 |███▊ | 798 kB 35.2 MB/s eta 0:00:01 |███▉ | 808 kB 35.2 MB/s eta 0:00:01 |███▉ | 819 kB 35.2 MB/s eta 0:00:01 |███▉ | 829 kB 35.2 MB/s eta 0:00:01 |████ | 839 kB 35.2 MB/s eta 0:00:01 |████ | 849 kB 35.2 MB/s eta 0:00:01 |████ | 860 kB 35.2 MB/s eta 0:00:01 |████ | 870 kB 35.2 MB/s eta 0:00:01 |████▏ | 880 kB 35.2 MB/s eta 0:00:01 |████▏ | 890 kB 35.2 MB/s eta 0:00:01 |████▎ | 901 kB 35.2 MB/s eta 0:00:01 |████▎ | 911 kB 35.2 MB/s eta 0:00:01 |████▎ | 921 kB 35.2 MB/s eta 0:00:01 |████▍ | 931 kB 35.2 MB/s eta 0:00:01 |████▍ | 942 kB 35.2 MB/s eta 0:00:01 |████▌ | 952 kB 35.2 MB/s eta 0:00:01 |████▌ | 962 kB 35.2 MB/s eta 0:00:01 |████▋ | 972 kB 35.2 MB/s eta 0:00:01 |████▋ | 983 kB 35.2 MB/s eta 0:00:01 |████▋ | 993 kB 35.2 MB/s eta 0:00:01 |████▊ | 1.0 MB 35.2 MB/s eta 0:00:01 |████▊ | 1.0 MB 35.2 MB/s eta 0:00:01 |████▉ | 1.0 MB 35.2 MB/s eta 0:00:01 |████▉ | 1.0 MB 35.2 MB/s eta 0:00:01 |█████ | 1.0 MB 35.2 MB/s eta 0:00:01 |█████ | 1.1 MB 35.2 MB/s eta 0:00:01 |█████ | 1.1 MB 35.2 MB/s eta 0:00:01 |█████ | 1.1 MB 35.2 MB/s eta 0:00:01 |█████ | 1.1 MB 35.2 MB/s eta 0:00:01 |█████▏ | 1.1 MB 35.2 MB/s eta 0:00:01 |█████▏ | 1.1 MB 35.2 MB/s eta 0:00:01 |█████▎ | 1.1 MB 35.2 MB/s eta 0:00:01 |█████▎ | 1.1 MB 35.2 MB/s eta 0:00:01 |█████▎ | 1.1 MB 35.2 MB/s eta 0:00:01 |█████▍ | 1.1 MB 35.2 MB/s eta 0:00:01 |█████▍ | 1.2 MB 35.2 MB/s eta 0:00:01 |█████▌ | 1.2 MB 35.2 MB/s eta 0:00:01 |█████▌ | 1.2 MB 35.2 MB/s eta 0:00:01 |█████▋ | 1.2 MB 35.2 MB/s eta 0:00:01 |█████▋ | 1.2 MB 35.2 MB/s eta 0:00:01 |█████▋ | 1.2 MB 35.2 MB/s eta 0:00:01 |█████▊ | 1.2 MB 35.2 MB/s eta 0:00:01 |█████▊ | 1.2 MB 35.2 MB/s eta 0:00:01 |█████▉ | 1.2 MB 35.2 MB/s eta 0:00:01 |█████▉ | 1.2 MB 35.2 MB/s eta 0:00:01 |██████ | 1.3 MB 35.2 MB/s eta 0:00:01 |██████ | 1.3 MB 35.2 MB/s eta 0:00:01 |██████ | 1.3 MB 35.2 MB/s eta 0:00:01 |██████ | 1.3 MB 35.2 MB/s eta 0:00:01 |██████ | 1.3 MB 35.2 MB/s eta 0:00:01 |██████▏ | 1.3 MB 35.2 MB/s eta 0:00:01 |██████▏ | 1.3 MB 35.2 MB/s eta 0:00:01 |██████▎ | 1.3 MB 35.2 MB/s eta 0:00:01 |██████▎ | 1.3 MB 35.2 MB/s eta 0:00:01 |██████▍ | 1.4 MB 35.2 MB/s eta 0:00:01 |██████▍ | 1.4 MB 35.2 MB/s eta 0:00:01 |██████▍ | 1.4 MB 35.2 MB/s eta 0:00:01 |██████▌ | 1.4 MB 35.2 MB/s eta 0:00:01
|██████▌ | 1.4 MB 35.2 MB/s eta 0:00:01 |██████▋ | 1.4 MB 35.2 MB/s eta 0:00:01 |██████▋ | 1.4 MB 35.2 MB/s eta 0:00:01 |██████▊ | 1.4 MB 35.2 MB/s eta 0:00:01 |██████▊ | 1.4 MB 35.2 MB/s eta 0:00:01 |██████▊ | 1.4 MB 35.2 MB/s eta 0:00:01 |██████▉ | 1.5 MB 35.2 MB/s eta 0:00:01 |██████▉ | 1.5 MB 35.2 MB/s eta 0:00:01 |███████ | 1.5 MB 35.2 MB/s eta 0:00:01 |███████ | 1.5 MB 35.2 MB/s eta 0:00:01 |███████ | 1.5 MB 35.2 MB/s eta 0:00:01 |███████ | 1.5 MB 35.2 MB/s eta 0:00:01 |███████ | 1.5 MB 35.2 MB/s eta 0:00:01 |███████▏ | 1.5 MB 35.2 MB/s eta 0:00:01 |███████▏ | 1.5 MB 35.2 MB/s eta 0:00:01 |███████▎ | 1.5 MB 35.2 MB/s eta 0:00:01 |███████▎ | 1.6 MB 35.2 MB/s eta 0:00:01 |███████▍ | 1.6 MB 35.2 MB/s eta 0:00:01 |███████▍ | 1.6 MB 35.2 MB/s eta 0:00:01 |███████▍ | 1.6 MB 35.2 MB/s eta 0:00:01 |███████▌ | 1.6 MB 35.2 MB/s eta 0:00:01 |███████▌ | 1.6 MB 35.2 MB/s eta 0:00:01 |███████▋ | 1.6 MB 35.2 MB/s eta 0:00:01 |███████▋ | 1.6 MB 35.2 MB/s eta 0:00:01 |███████▊ | 1.6 MB 35.2 MB/s eta 0:00:01 |███████▊ | 1.6 MB 35.2 MB/s eta 0:00:01 |███████▊ | 1.7 MB 35.2 MB/s eta 0:00:01 |███████▉ | 1.7 MB 35.2 MB/s eta 0:00:01 |███████▉ | 1.7 MB 35.2 MB/s eta 0:00:01 |████████ | 1.7 MB 35.2 MB/s eta 0:00:01 |████████ | 1.7 MB 35.2 MB/s eta 0:00:01 |████████ | 1.7 MB 35.2 MB/s eta 0:00:01 |████████ | 1.7 MB 35.2 MB/s eta 0:00:01 |████████▏ | 1.7 MB 35.2 MB/s eta 0:00:01 |████████▏ | 1.7 MB 35.2 MB/s eta 0:00:01 |████████▏ | 1.8 MB 35.2 MB/s eta 0:00:01 |████████▎ | 1.8 MB 35.2 MB/s eta 0:00:01 |████████▎ | 1.8 MB 35.2 MB/s eta 0:00:01 |████████▍ | 1.8 MB 35.2 MB/s eta 0:00:01 |████████▍ | 1.8 MB 35.2 MB/s eta 0:00:01 |████████▌ | 1.8 MB 35.2 MB/s eta 0:00:01 |████████▌ | 1.8 MB 35.2 MB/s eta 0:00:01 |████████▌ | 1.8 MB 35.2 MB/s eta 0:00:01 |████████▋ | 1.8 MB 35.2 MB/s eta 0:00:01 |████████▋ | 1.8 MB 35.2 MB/s eta 0:00:01 |████████▊ | 1.9 MB 35.2 MB/s eta 0:00:01 |████████▊ | 1.9 MB 35.2 MB/s eta 0:00:01 |████████▉ | 1.9 MB 35.2 MB/s eta 0:00:01 |████████▉ | 1.9 MB 35.2 MB/s eta 0:00:01 |████████▉ | 1.9 MB 35.2 MB/s eta 0:00:01 |█████████ | 1.9 MB 35.2 MB/s eta 0:00:01 |█████████ | 1.9 MB 35.2 MB/s eta 0:00:01 |█████████ | 1.9 MB 35.2 MB/s eta 0:00:01 |█████████ | 1.9 MB 35.2 MB/s eta 0:00:01 |█████████▏ | 1.9 MB 35.2 MB/s eta 0:00:01 |█████████▏ | 2.0 MB 35.2 MB/s eta 0:00:01 |█████████▏ | 2.0 MB 35.2 MB/s eta 0:00:01 |█████████▎ | 2.0 MB 35.2 MB/s eta 0:00:01 |█████████▎ | 2.0 MB 35.2 MB/s eta 0:00:01 |█████████▍ | 2.0 MB 35.2 MB/s eta 0:00:01 |█████████▍ | 2.0 MB 35.2 MB/s eta 0:00:01 |█████████▌ | 2.0 MB 35.2 MB/s eta 0:00:01 |█████████▌ | 2.0 MB 35.2 MB/s eta 0:00:01
|█████████▌ | 2.0 MB 35.2 MB/s eta 0:00:01 |█████████▋ | 2.0 MB 35.2 MB/s eta 0:00:01 |█████████▋ | 2.1 MB 35.2 MB/s eta 0:00:01 |█████████▊ | 2.1 MB 35.2 MB/s eta 0:00:01 |█████████▊ | 2.1 MB 35.2 MB/s eta 0:00:01 |█████████▉ | 2.1 MB 35.2 MB/s eta 0:00:01 |█████████▉ | 2.1 MB 35.2 MB/s eta 0:00:01 |██████████ | 2.1 MB 35.2 MB/s eta 0:00:01 |██████████ | 2.1 MB 35.2 MB/s eta 0:00:01 |██████████ | 2.1 MB 35.2 MB/s eta 0:00:01 |██████████ | 2.1 MB 35.2 MB/s eta 0:00:01 |██████████ | 2.2 MB 35.2 MB/s eta 0:00:01 |██████████▏ | 2.2 MB 35.2 MB/s eta 0:00:01 |██████████▏ | 2.2 MB 35.2 MB/s eta 0:00:01 |██████████▎ | 2.2 MB 35.2 MB/s eta 0:00:01 |██████████▎ | 2.2 MB 35.2 MB/s eta 0:00:01 |██████████▎ | 2.2 MB 35.2 MB/s eta 0:00:01 |██████████▍ | 2.2 MB 35.2 MB/s eta 0:00:01 |██████████▍ | 2.2 MB 35.2 MB/s eta 0:00:01 |██████████▌ | 2.2 MB 35.2 MB/s eta 0:00:01 |██████████▌ | 2.2 MB 35.2 MB/s eta 0:00:01 |██████████▋ | 2.3 MB 35.2 MB/s eta 0:00:01 |██████████▋ | 2.3 MB 35.2 MB/s eta 0:00:01 |██████████▋ | 2.3 MB 35.2 MB/s eta 0:00:01 |██████████▊ | 2.3 MB 35.2 MB/s eta 0:00:01 |██████████▊ | 2.3 MB 35.2 MB/s eta 0:00:01 |██████████▉ | 2.3 MB 35.2 MB/s eta 0:00:01 |██████████▉ | 2.3 MB 35.2 MB/s eta 0:00:01 |███████████ | 2.3 MB 35.2 MB/s eta 0:00:01 |███████████ | 2.3 MB 35.2 MB/s eta 0:00:01 |███████████ | 2.3 MB 35.2 MB/s eta 0:00:01 |███████████ | 2.4 MB 35.2 MB/s eta 0:00:01 |███████████ | 2.4 MB 35.2 MB/s eta 0:00:01 |███████████▏ | 2.4 MB 35.2 MB/s eta 0:00:01 |███████████▏ | 2.4 MB 35.2 MB/s eta 0:00:01 |███████████▎ | 2.4 MB 35.2 MB/s eta 0:00:01 |███████████▎ | 2.4 MB 35.2 MB/s eta 0:00:01 |███████████▎ | 2.4 MB 35.2 MB/s eta 0:00:01 |███████████▍ | 2.4 MB 35.2 MB/s eta 0:00:01 |███████████▍ | 2.4 MB 35.2 MB/s eta 0:00:01 |███████████▌ | 2.4 MB 35.2 MB/s eta 0:00:01 |███████████▌ | 2.5 MB 35.2 MB/s eta 0:00:01 |███████████▋ | 2.5 MB 35.2 MB/s eta 0:00:01 |███████████▋ | 2.5 MB 35.2 MB/s eta 0:00:01 |███████████▋ | 2.5 MB 35.2 MB/s eta 0:00:01 |███████████▊ | 2.5 MB 35.2 MB/s eta 0:00:01 |███████████▊ | 2.5 MB 35.2 MB/s eta 0:00:01 |███████████▉ | 2.5 MB 35.2 MB/s eta 0:00:01 |███████████▉ | 2.5 MB 35.2 MB/s eta 0:00:01 |████████████ | 2.5 MB 35.2 MB/s eta 0:00:01 |████████████ | 2.5 MB 35.2 MB/s eta 0:00:01 |████████████ | 2.6 MB 35.2 MB/s eta 0:00:01 |████████████ | 2.6 MB 35.2 MB/s eta 0:00:01 |████████████ | 2.6 MB 35.2 MB/s eta 0:00:01 |████████████▏ | 2.6 MB 35.2 MB/s eta 0:00:01 |████████████▏ | 2.6 MB too many strings

colab では、以下の様にインストールしました。

!pip install pycaret==2.3.10 markupsafe==2.0.1 pyyaml==5.4.1 -qq

後続の from pycaret.classification import * の部分でエラーが発生してしまったので、 stack overflowの記事を参考にしました。

チュートリアルで使用するデータセットの説明

このチュートリアルでは、"Sarah Gets a Diamond "というケーススタディに基づいたデータセットを使います。このケースは、Darden School of Business (University of Virginia)の1年生の意思決定分析のコースで紹介されました。データのベースとなっているのは、絶望的なロマンチストのMBA学生が、花嫁となるサラにふさわしいダイヤモンドを選ぶというケースです。データにはトレーニング用の6000レコードが含まれています。各カラムの簡単な説明は以下の通りです。

  • ID:** 各観測データ(ダイヤモンド)を一意に識別する。
  • Carat Weight:** ダイヤモンドの重量をメートル法で表したもの。1カラットは0.2グラムに相当し、ペーパークリップとほぼ同じ重さである。
  • Cut:**ダイヤモンドのカットを示す5つの値のうちの1つで、望ましい順に並べたもの(Signature-Ideal、Ideal、Very Good、Good、Fair)。
  • Color:**ダイヤモンドのカラーを示す6つの値のうち、望ましい順に1つを選択(D、E、F-無色、G、H、I-無色に近い)。
  • Clarity:**ダイヤモンドのクラリティを示す7つの値のうち、望ましい順に1つ(F - Flawless、IF - Internally Flawless、VVS1またはVVS2 - Very, Very Slightly Included、またはVS1またはVS2 - Very Slightly Included、SI1 - Slightly Included)。
  • Polish:**ダイヤモンドの研磨状態を示す4つの値のうちの1つ(ID - Ideal、EX - Excellent、VG - Very Good、G - Good)。
  • Symmetry:** ダイヤモンドのシンメトリーを示す4つの値のうちの1つ(ID - Ideal, EX - Excellent, VG - Very Good, G - Good)。
  • Report:** ダイヤモンドの品質を報告したグレーディング機関を示す2つの値「AGSL」または「GIA」のうちの1つ。
  • Price: ダイヤモンドの評価額を米ドルで表したもの Target Column です。

データの取得

get_data() を使いデータを取得します。

from pycaret.datasets import get_data
dataset = get_data('diamond')

Carat Weight Cut Color Clarity Polish Symmetry Report Price
0 1.10 Ideal H SI1 VG EX GIA 5169
1 0.83 Ideal H VS1 ID ID AGSL 3470
2 0.85 Ideal H SI1 EX EX GIA 3183
3 0.91 Ideal E SI1 VG VG GIA 4370
4 0.83 Ideal G SI1 EX EX GIA 3171

取得したデータセットをモデル作成用( data )と予測用( data_unseen )とに90:10に分けます。

data = dataset.sample(frac=0.9, random_state=786)
data_unseen = dataset.drop(data.index)

data.reset_index(drop=True, inplace=True)
data_unseen.reset_index(drop=True, inplace=True)

print('Data for Modeling: ' + str(data.shape))
print('Unseen Data For Predictions: ' + str(data_unseen.shape))

Data for Modeling: (5400, 8) Unseen Data For Predictions: (600, 8)

セットアップ

setup() を使ってセットアップをします。

実行すると、投入したデータの型のチェックを自動で実施してくれます。 問題なければ、カーソルをあわせ、 Enter を押します。

setup()関数はpycaretの環境を初期化し、モデリングやデプロイメントのためにデータを準備する変換パイプラインを作成します。

from pycaret.regression import *
exp_reg101 = setup(data = data, target = 'Price', session_id=123)

Description Value
0 session_id 123
1 Target Price
2 Original Data (5400, 8)
3 Missing Values False
4 Numeric Features 1
5 Categorical Features 6
6 Ordinal Features False
7 High Cardinality Features False
8 High Cardinality Method None
9 Transformed Train Set (3779, 28)
10 Transformed Test Set (1621, 28)
11 Shuffle Train-Test True
12 Stratify Train-Test False
13 Fold Generator KFold
14 Fold Number 10
15 CPU Jobs -1
16 Use GPU False
17 Log Experiment False
18 Experiment Name reg-default-name
19 USI 9e27
20 Imputation Type simple
21 Iterative Imputation Iteration None
22 Numeric Imputer mean
23 Iterative Imputation Numeric Model None
24 Categorical Imputer constant
25 Iterative Imputation Categorical Model None
26 Unknown Categoricals Handling least_frequent
27 Normalize False
28 Normalize Method None
29 Transformation False
30 Transformation Method None
31 PCA False
32 PCA Method None
33 PCA Components None
34 Ignore Low Variance False
35 Combine Rare Levels False
36 Rare Level Threshold None
37 Numeric Binning False
38 Remove Outliers False
39 Outliers Threshold None
40 Remove Multicollinearity False
41 Multicollinearity Threshold None
42 Remove Perfect Collinearity True
43 Clustering False
44 Clustering Iteration None
45 Polynomial Features False
46 Polynomial Degree None
47 Trignometry Features False
48 Polynomial Threshold None
49 Group Features False
50 Feature Selection False
51 Feature Selection Method classic
52 Features Selection Threshold None
53 Feature Interaction False
54 Feature Ratio False
55 Interaction Threshold None
56 Transform Target False
57 Transform Target Method box-cox

セットアップが正常に実行されると,いくつかの重要な情報を含む情報グリッドが表示されます。ほとんどの情報は、setup()の実行時に構築される前処理パイプラインに関連しています。これらの機能の大部分は、このチュートリアルの目的からは外れています。しかし、この段階で注意すべきいくつかの重要な点があります。

  • session_id : 後の再現性のために、すべての関数でシードとして配布される疑似乱数です。もしsession_idが渡されない場合は、自動的に乱数が生成され、すべての関数に配布されます。この実験では、後の再現性のために、session_idを123としています。
  • Original Data : データセットの元の形を表示します。 元のデータ :元のデータの形を表示します。
  • Missing Values : 元のデータに欠損値がある場合は、Trueと表示されます。この実験では、データセットに欠損値はありません。
  • Numeric Features : 数値として推定された特徴の数です。このデータセットでは、8つの特徴のうち1つが数値として推論されています。
  • Categorical Features : カテゴライズされた特徴の数。このデータセットでは、8つの特徴のうち6つがカテゴライズされています。
  • Transformed Train Set : 変換後のトレーニングセットの形状を表示します。元の形状である(5400, 8)が、変換後の訓練セットでは(3779, 28)に変換されていることに注目してください。カテゴリーエンコーディングにより、特徴量の数が28から8に増えています 。
  • Transformed Test Set : 変換されたテストセット/ホールドアウトセットの形状を表示します。test/hold-out setには1621個のサンプルがあります。この分割は、デフォルトの70/30に基づいていますが、セットアップのtrain_sizeパラメータで変更することができます。

欠損値のインプテーション(この場合、トレーニングデータには欠損値はありませんが、見たことのないデータのインプテーションが必要です)やカテゴリーエンコーディングなど、モデリングを行う上で必須となるいくつかのタスクが自動的に処理されていることに注目してください。

モデルの比較

setup の次は compare_models を使い、主要なモデルのパフォーマンスを比較していきます。

この関数は、モデルライブラリ内のすべてのモデルを学習し、k-foldクロスバリデーションを用いてスコアリングを行い、メトリクス評価を行います。出力は、フォールド(デフォルトでは10)間の平均MAE、MSE、RMSE、R2、RMSLE、MAPEを学習時間とともに示すスコアグリッドを表示します。

best = compare_models(exclude = ['ransac'])
#exclude パラメータ使用し、特定のモデル(ここでは RANSAC )を除外

Model MAE MSE RMSE R2 RMSLE MAPE TT (Sec)
et Extra Trees Regressor 762.0118 2763999.1585 1612.2410 0.9729 0.0817 0.0607 1.2540
rf Random Forest Regressor 760.6304 2929683.1860 1663.0148 0.9714 0.0818 0.0597 1.1470
lightgbm Light Gradient Boosting Machine 752.6446 3056347.8515 1687.9907 0.9711 0.0773 0.0567 0.0590
gbr Gradient Boosting Regressor 920.2913 3764303.9252 1901.1793 0.9633 0.1024 0.0770 0.1970
dt Decision Tree Regressor 1003.1237 5305620.3379 2228.7271 0.9476 0.1083 0.0775 0.0260
ridge Ridge Regression 2413.5700 14120502.3164 3726.1654 0.8621 0.6689 0.2875 0.0100
lasso Lasso Regression 2412.1922 14246798.1211 3744.2305 0.8608 0.6767 0.2866 0.0310
llar Lasso Least Angle Regression 2355.6152 14272019.9688 3745.3094 0.8607 0.6391 0.2728 0.0110
br Bayesian Ridge 2415.8031 14270771.8397 3746.9951 0.8606 0.6696 0.2873 0.0140
lr Linear Regression 2418.7036 14279370.2389 3748.9580 0.8604 0.6690 0.2879 0.2430
huber Huber Regressor 1936.1466 18599231.4697 4252.8758 0.8209 0.4333 0.1657 0.0650
par Passive Aggressive Regressor 1944.1634 19955672.9330 4400.2133 0.8083 0.4317 0.1594 0.0360
omp Orthogonal Matching Pursuit 2792.7313 23728653.8040 4829.3170 0.7678 0.5819 0.2654 0.0100
ada AdaBoost Regressor 4232.2217 25201423.0703 5012.4175 0.7467 0.5102 0.5970 0.1530
knn K Neighbors Regressor 2968.0750 29627913.0479 5421.7241 0.7051 0.3664 0.2730 0.0860
en Elastic Net 5029.5913 56399795.8780 7467.6598 0.4472 0.5369 0.5845 0.0110
dummy Dummy Regressor 7280.3308 101221941.4046 10032.1624 -0.0014 0.7606 0.8969 0.0110
lar Least Angle Regression 971102106.7611 17351772854221176832.0000 1317264711.5451 -216422608857.9012 1.9525 145144.5820 0.0150

クロスバリデーションを用いて20以上のモデルにてトレーニング、評価されました。
すごく便利ですね。

モデルの作成

今回は以下の2つを実行していきたいと思います。

  • AdaBoost Regressor ('anda')
  • 決定木 ('dt')

元のチュートリアルでは Light Gradient Boosting Machine ('lightgbm')も実行していますが、 magicode では固まってしまったので、この記事では実行は省略します。

PyCaretのモデルライブラリには25個のリグレッサーがあります。全てのレグレッサーのリストを見るには、docstringを確認するか、models関数を使ってライブラリを確認してください。

models()

Name Reference Turbo
ID
lr Linear Regression sklearn.linear_model._base.LinearRegression True
lasso Lasso Regression sklearn.linear_model._coordinate_descent.Lasso True
ridge Ridge Regression sklearn.linear_model._ridge.Ridge True
en Elastic Net sklearn.linear_model._coordinate_descent.Elast... True
lar Least Angle Regression sklearn.linear_model._least_angle.Lars True
llar Lasso Least Angle Regression sklearn.linear_model._least_angle.LassoLars True
omp Orthogonal Matching Pursuit sklearn.linear_model._omp.OrthogonalMatchingPu... True
br Bayesian Ridge sklearn.linear_model._bayes.BayesianRidge True
ard Automatic Relevance Determination sklearn.linear_model._bayes.ARDRegression False
par Passive Aggressive Regressor sklearn.linear_model._passive_aggressive.Passi... True
ransac Random Sample Consensus sklearn.linear_model._ransac.RANSACRegressor False
tr TheilSen Regressor sklearn.linear_model._theil_sen.TheilSenRegressor False
huber Huber Regressor sklearn.linear_model._huber.HuberRegressor True
kr Kernel Ridge sklearn.kernel_ridge.KernelRidge False
svm Support Vector Regression sklearn.svm._classes.SVR False
knn K Neighbors Regressor sklearn.neighbors._regression.KNeighborsRegressor True
dt Decision Tree Regressor sklearn.tree._classes.DecisionTreeRegressor True
rf Random Forest Regressor sklearn.ensemble._forest.RandomForestRegressor True
et Extra Trees Regressor sklearn.ensemble._forest.ExtraTreesRegressor True
ada AdaBoost Regressor sklearn.ensemble._weight_boosting.AdaBoostRegr... True
gbr Gradient Boosting Regressor sklearn.ensemble._gb.GradientBoostingRegressor True
mlp MLP Regressor sklearn.neural_network._multilayer_perceptron.... False
lightgbm Light Gradient Boosting Machine lightgbm.sklearn.LGBMRegressor True
dummy Dummy Regressor sklearn.dummy.DummyRegressor True

AdaBoost Regressor

ada = create_model('ada')

MAE MSE RMSE R2 RMSLE MAPE
Fold
0 4101.8809 23013830.0177 4797.2732 0.7473 0.4758 0.5470
1 4251.5693 29296751.6657 5412.6474 0.7755 0.4940 0.5702
2 4047.8474 22291660.1785 4721.4045 0.7955 0.5068 0.5871
3 4298.3867 23482783.6839 4845.9038 0.7409 0.5089 0.5960
4 3888.5584 24461807.7242 4945.8880 0.6949 0.4764 0.5461
5 4566.4889 29733914.8752 5452.8813 0.7462 0.5462 0.6598
6 4628.7271 27841092.1974 5276.4659 0.7384 0.5549 0.6676
7 4316.4317 25979752.0083 5097.0336 0.6715 0.5034 0.5858
8 3931.2163 21097072.3513 4593.1549 0.7928 0.4858 0.5513
9 4291.1097 24815566.0009 4981.5225 0.7637 0.5495 0.6592
Mean 4232.2217 25201423.0703 5012.4175 0.7467 0.5102 0.5970
Std 233.2282 2804219.3826 277.6577 0.0375 0.0284 0.0457

Decision Tree

dt = create_model('dt')

MAE MSE RMSE R2 RMSLE MAPE
Fold
0 859.1907 2456840.0599 1567.4310 0.9730 0.1016 0.0727
1 1122.9409 9852564.2047 3138.8795 0.9245 0.1102 0.0758
2 911.3452 2803662.6885 1674.4141 0.9743 0.0988 0.0729
3 1002.5575 3926739.3726 1981.6002 0.9567 0.1049 0.0772
4 1167.8154 9751516.1909 3122.7418 0.8784 0.1226 0.0876
5 1047.7778 7833770.7037 2798.8874 0.9331 0.1128 0.0791
6 1010.0816 3989282.4802 1997.3188 0.9625 0.1106 0.0803
7 846.8085 2182534.9007 1477.3405 0.9724 0.0933 0.0709
8 1001.8451 4904945.0821 2214.7111 0.9518 0.1053 0.0734
9 1060.8742 5354347.6956 2313.9463 0.9490 0.1230 0.0847
Mean 1003.1237 5305620.3379 2228.7271 0.9476 0.1083 0.0775
Std 100.2165 2734194.7557 581.7181 0.0280 0.0091 0.0052

Light Gradient Boosting Machine

lightgbm = create_model('lightgbm')

モデルのチューニング

create_model関数を使ってモデルを作成すると,デフォルトのハイパーパラメータを使ってモデルを学習します.ハイパーパラメータを調整するには,tune_model関数を使用します.この関数は,あらかじめ定義された探索空間において,Random Grid Searchを用いてモデルのハイパーパラメータを自動的に調整します.出力では,MAE,MSE,RMSE,R2,RMSLE,MAPEをフォールドごとに示したスコアグリッドを印刷します.

AdaBoost Regressor

tuned_ada = tune_model(ada)

MAE MSE RMSE R2 RMSLE MAPE
Fold
0 2629.7158 16222922.0054 4027.7689 0.8219 0.2553 0.2244
1 2764.7250 25273189.9003 5027.2448 0.8063 0.2714 0.2357
2 2605.9909 16883405.3119 4108.9421 0.8451 0.2617 0.2352
3 2588.0395 14475338.1062 3804.6469 0.8403 0.2685 0.2271
4 2403.7173 13602075.2435 3688.0991 0.8303 0.2672 0.2223
5 2538.7416 20724600.2592 4552.4280 0.8231 0.2644 0.2260
6 2720.2195 19796302.1522 4449.3036 0.8140 0.2644 0.2280
7 2707.6016 17084596.1502 4133.3517 0.7839 0.2743 0.2475
8 2444.0262 16340453.5625 4042.3327 0.8395 0.2623 0.2199
9 2545.6132 19267454.7853 4389.4709 0.8165 0.2680 0.2247
Mean 2594.8391 17967033.7477 4222.3589 0.8221 0.2657 0.2291
Std 111.1423 3238932.6224 372.4506 0.0174 0.0051 0.0078
print(tuned_ada)

Decision Tree

tuned_dt = tune_model(dt)

MAE MSE RMSE R2 RMSLE MAPE
Fold
0 1000.7122 2895159.1309 1701.5167 0.9682 0.1076 0.0828
1 1080.2841 6686388.0416 2585.8051 0.9488 0.1053 0.0814
2 1002.3163 3275429.6329 1809.8148 0.9700 0.1051 0.0812
3 1080.7850 4037154.5985 2009.2672 0.9555 0.1172 0.0870
4 1101.6333 7889520.5391 2808.8290 0.9016 0.1189 0.0842
5 1275.5901 11021312.1970 3319.8362 0.9059 0.1250 0.0895
6 1068.6534 4463866.3029 2112.7864 0.9581 0.1076 0.0809
7 975.9364 3271028.5175 1808.5985 0.9586 0.1099 0.0807
8 1101.9207 4441966.3616 2107.5973 0.9564 0.1114 0.0873
9 1065.1662 5192339.2748 2278.6705 0.9506 0.1224 0.0873
Mean 1075.2997 5317416.4597 2254.2722 0.9474 0.1130 0.0842
Std 79.0463 2416581.2427 485.4621 0.0227 0.0069 0.0031

Light Gradient Boosting Machine

カスタム検索グリッドを使用するには,tune_model関数でcustom_gridパラメータを渡します

import numpy as np
lgbm_params = {'num_leaves': np.arange(10,200,10),
                        'max_depth': [int(x) for x in np.linspace(10, 110, num = 11)],
                        'learning_rate': np.arange(0.1,1,0.1)
tuned_lightgbm = tune_model(lightgbm, custom_grid = lgbm_params)

デフォルトでは,tune_model は R2 を最適化しますが,optimize パラメータを用いてこれを変更することができます.例えば,tune_model(dt, optimize = 'MAE')は,最高のR2ではなく,最低のMAEとなる決定木回帰因子のハイパーパラメータを検索します.この例では、わかりやすくするために、デフォルトの指標である「R2」を使用しています。リグレッサーを評価するために適切なメトリクスを選択する方法は、このチュートリアルの範囲を超えていますが、もっと詳しく知りたい方は、click hereで回帰誤差メトリクスに関する理解を深めることができます。

今回は Decision Tree(決定木)のモデルを使い、チュートリアルを進めます。

モデルをプロット

モデルを完成させる前に,plot_model()関数を使って,残差プロット,予測誤差,特徴の重要度など,さまざまな側面から性能を分析することができます.この関数は、学習されたモデルオブジェクトを受け取り、テスト/ホールドアウトセットに基づいたプロットを返します。 10種類以上のプロットが用意されていますので、利用可能なプロットのリストは plot_model() のドキュメントをご覧ください。

Residual Plot 残差プロット

plot_model(dt)

findfont: Font family ['sans-serif'] not found. Falling back to DejaVu Sans.
findfont: Font family ['sans-serif'] not found. Falling back to DejaVu Sans.
findfont: Font family ['sans-serif'] not found. Falling back to DejaVu Sans.
<Figure size 576x396 with 2 Axes>

Prediction Error Plot 予測誤差のプロット

plot_model(tuned_dt, plot = 'error')

<Figure size 576x396 with 1 Axes>

Feature Importance Plot フィーチャー・インポータンス・プロット

plot_model(tuned_dt, plot='feature')

<Figure size 800x500 with 1 Axes>

モデルの性能 を分析するもう一つの方法は,evaluate_model()関数を使うことです.この関数は,与えられたモデルについて利用可能なすべてのプロットのためのユーザインタフェースを表示します.この関数は,内部的には plot_model() 関数を使用しています.

evaluate_model(tuned_lightgbm)

残念ながら magicode では、うまく表示できませんでしたが、 Colab では実施できました。

Predict on Test / Hold-out Sample

モデルを最終的に決定する前に、テスト/ホールドアウトセットを予測し、評価指標を確認することで、最終的なチェックを行うことが推奨されます。

predict_model(tuned_dt);

Model MAE MSE RMSE R2 RMSLE MAPE
0 Decision Tree Regressor 1078.2157 7456096.0505 2730.5853 0.9320 0.1148 0.0822

テスト/ホールドアウトの結果とCVの結果の間に大きな差がなく問題ないことがわかります。

モデルを最終的に決定し、未経験のデータ(最初に分けた10%のデータで、PyCaretに触れていないデータ)で予測することに進みます。

(TIP : create_model を使用する際に、CV結果の標準偏差を見ることは常に良いことです。)

Finalize Model for Deployment

モデルの最終決定は、実験の最後のステップです。PyCaretでの通常の機械学習のワークフローは、setup()から始まり、compare_models()で全てのモデルを比較し、ハイパーパラメータチューニング、アンサンブル、スタッキングなどのモデリング技術を実行するためのいくつかの候補モデルを(対象となる指標に基づいて)選びます。このワークフローにより、新しいデータや未知のデータの予測に使用するための最適なモデルを最終的に導き出すことができます。finalize_model()関数は、テスト/ホールドアウトサンプル(ここでは30%)を含む完全なデータセットにモデルをフィットさせます。この関数の目的は,モデルを実運用に投入する前に,完全なデータセットでモデルを訓練することです.

final_dt = finalize_model(tuned_dt)

/srv/conda/envs/notebook/lib/python3.7/site-packages/sklearn/tree/_classes.py:1254: FutureWarning: the classes_ attribute is to be deprecated from version 0.22 and will be removed in 0.24. warnings.warn(msg, FutureWarning) /srv/conda/envs/notebook/lib/python3.7/site-packages/sklearn/tree/_classes.py:1262: FutureWarning: the n_classes_ attribute is to be deprecated from version 0.22 and will be removed in 0.24. warnings.warn(msg, FutureWarning) /srv/conda/envs/notebook/lib/python3.7/site-packages/sklearn/tree/_classes.py:1254: FutureWarning: the classes_ attribute is to be deprecated from version 0.22 and will be removed in 0.24. warnings.warn(msg, FutureWarning) /srv/conda/envs/notebook/lib/python3.7/site-packages/sklearn/tree/_classes.py:1262: FutureWarning: the n_classes_ attribute is to be deprecated from version 0.22 and will be removed in 0.24. warnings.warn(msg, FutureWarning) /srv/conda/envs/notebook/lib/python3.7/site-packages/sklearn/tree/_classes.py:1254: FutureWarning: the classes_ attribute is to be deprecated from version 0.22 and will be removed in 0.24. warnings.warn(msg, FutureWarning) /srv/conda/envs/notebook/lib/python3.7/site-packages/sklearn/tree/_classes.py:1262: FutureWarning: the n_classes_ attribute is to be deprecated from version 0.22 and will be removed in 0.24. warnings.warn(msg, FutureWarning) /srv/conda/envs/notebook/lib/python3.7/site-packages/sklearn/tree/_classes.py:1254: FutureWarning: the classes_ attribute is to be deprecated from version 0.22 and will be removed in 0.24. warnings.warn(msg, FutureWarning) /srv/conda/envs/notebook/lib/python3.7/site-packages/sklearn/tree/_classes.py:1262: FutureWarning: the n_classes_ attribute is to be deprecated from version 0.22 and will be removed in 0.24. warnings.warn(msg, FutureWarning) /srv/conda/envs/notebook/lib/python3.7/site-packages/sklearn/tree/_classes.py:1254: FutureWarning: the classes_ attribute is to be deprecated from version 0.22 and will be removed in 0.24. warnings.warn(msg, FutureWarning) /srv/conda/envs/notebook/lib/python3.7/site-packages/sklearn/tree/_classes.py:1262: FutureWarning: the n_classes_ attribute is to be deprecated from version 0.22 and will be removed in 0.24. warnings.warn(msg, FutureWarning) /srv/conda/envs/notebook/lib/python3.7/site-packages/sklearn/tree/_classes.py:1254: FutureWarning: the classes_ attribute is to be deprecated from version 0.22 and will be removed in 0.24. warnings.warn(msg, FutureWarning) /srv/conda/envs/notebook/lib/python3.7/site-packages/sklearn/tree/_classes.py:1262: FutureWarning: the n_classes_ attribute is to be deprecated from version 0.22 and will be removed in 0.24. warnings.warn(msg, FutureWarning) /srv/conda/envs/notebook/lib/python3.7/site-packages/sklearn/tree/_classes.py:1254: FutureWarning: the classes_ attribute is to be deprecated from version 0.22 and will be removed in 0.24. warnings.warn(msg, FutureWarning) /srv/conda/envs/notebook/lib/python3.7/site-packages/sklearn/tree/_classes.py:1262: FutureWarning: the n_classes_ attribute is to be deprecated from version 0.22 and will be removed in 0.24. warnings.warn(msg, FutureWarning) /srv/conda/envs/notebook/lib/python3.7/site-packages/sklearn/tree/_classes.py:1254: FutureWarning: the classes_ attribute is to be deprecated from version 0.22 and will be removed in 0.24. warnings.warn(msg, FutureWarning) /srv/conda/envs/notebook/lib/python3.7/site-packages/sklearn/tree/_classes.py:1262: FutureWarning: the n_classes_ attribute is to be deprecated from version 0.22 and will be removed in 0.24. warnings.warn(msg, FutureWarning)
/srv/conda/envs/notebook/lib/python3.7/site-packages/sklearn/tree/_classes.py:1254: FutureWarning: the classes_ attribute is to be deprecated from version 0.22 and will be removed in 0.24. warnings.warn(msg, FutureWarning) /srv/conda/envs/notebook/lib/python3.7/site-packages/sklearn/tree/_classes.py:1262: FutureWarning: the n_classes_ attribute is to be deprecated from version 0.22 and will be removed in 0.24. warnings.warn(msg, FutureWarning) /srv/conda/envs/notebook/lib/python3.7/site-packages/sklearn/tree/_classes.py:1254: FutureWarning: the classes_ attribute is to be deprecated from version 0.22 and will be removed in 0.24. warnings.warn(msg, FutureWarning) /srv/conda/envs/notebook/lib/python3.7/site-packages/sklearn/tree/_classes.py:1262: FutureWarning: the n_classes_ attribute is to be deprecated from version 0.22 and will be removed in 0.24. warnings.warn(msg, FutureWarning)
print(final_dt)

DecisionTreeRegressor(ccp_alpha=0.0, criterion='friedman_mse', max_depth=10, max_features=1.0, max_leaf_nodes=None, min_impurity_decrease=0.01, min_impurity_split=None, min_samples_leaf=2, min_samples_split=9, min_weight_fraction_leaf=0.0, presort='deprecated', random_state=123, splitter='best')
predict_model(final_dt);

Model MAE MSE RMSE R2 RMSLE MAPE
0 Decision Tree Regressor 769.7208 2705921.7672 1644.9686 0.9753 0.0831 0.0623

Predict on Unseen Data

見ていないデータセットに対する予測にも,predict_model()関数を使います.上のセクション11との唯一の違いは、今回はdata_unseenというパラメータを渡すことです。data_unseen`はチュートリアルの最初に作成された変数で、PyCaretに公開されていないオリジナルのデータセットの10%(600サンプル)を含んでいます。(

unseen_predictions = predict_model(final_dt, data=data_unseen)
unseen_predictions.head()

Model MAE MSE RMSE R2 RMSLE MAPE
0 Decision Tree Regressor 1100.4945 4101670.6444 2025.2582 0.9601 0.1042 0.0809
Carat Weight Cut Color Clarity Polish Symmetry Report Price Label
0 1.53 Ideal E SI1 ID ID AGSL 12791 11100.000000
1 1.50 Fair F SI1 VG VG GIA 10450 11258.107143
2 1.01 Good E SI1 G G GIA 5161 5243.827586
3 2.51 Very Good G VS2 VG VG GIA 34361 38788.600000
4 1.01 Good I SI1 VG VG GIA 4238 4107.533333

data_unseenにLabel列が追加されています。Labelは,final_lightgbmモデルを用いた予測値です.予測値を丸めたい場合には、predict_model()の中でroundパラメータを使用します。また、実際のターゲットカラムである Price が利用できるので、これに関するメトリクスをチェックすることもできます。そのためには、pycaret.utilsモジュールを使います。以下の例をご覧ください。

from pycaret.utils import check_metric
check_metric(unseen_predictions.Price, unseen_predictions.Label, 'R2')

0.9601

モデルの save と load

保存は save_model() を使い、 load は load_model() を使います。

save_model(final_dt,'Final DT Model')

Transformation Pipeline and Model Successfully Saved
(Pipeline(memory=None, steps=[('dtypes', DataTypes_Auto_infer(categorical_features=[], display_types=True, features_todrop=[], id_columns=[], ml_usecase='regression', numerical_features=[], target='Price', time_features=[])), ('imputer', Simple_Imputer(categorical_strategy='not_available', fill_value_categorical=None, fill_value_numerical=None, numeric_strategy='... ('dfs', 'passthrough'), ('pca', 'passthrough'), ['trained_model', DecisionTreeRegressor(ccp_alpha=0.0, criterion='friedman_mse', max_depth=10, max_features=1.0, max_leaf_nodes=None, min_impurity_decrease=0.01, min_impurity_split=None, min_samples_leaf=2, min_samples_split=9, min_weight_fraction_leaf=0.0, presort='deprecated', random_state=123, splitter='best')]], verbose=False), 'Final DT Model.pkl')
saved_final_dt = load_model('Final DT Model')

Transformation Pipeline and Model Successfully Loaded

save & load したモデルを使って再度予測を実行してみます。

new_prediction = predict_model(saved_final_dt, data=data_unseen)

Model MAE MSE RMSE R2 RMSLE MAPE
0 Decision Tree Regressor 1100.4945 4101670.6444 2025.2582 0.9601 0.1042 0.0809
new_prediction.head()

Carat Weight Cut Color Clarity Polish Symmetry Report Price Label
0 1.53 Ideal E SI1 ID ID AGSL 12791 11100.000000
1 1.50 Fair F SI1 VG VG GIA 10450 11258.107143
2 1.01 Good E SI1 G G GIA 5161 5243.827586
3 2.51 Very Good G VS2 VG VG GIA 34361 38788.600000
4 1.01 Good I SI1 VG VG GIA 4238 4107.533333
from pycaret.utils import check_metric
check_metric(new_prediction.Price, new_prediction.Label, 'R2')

0.9601

無事、予測できていることがわかります。

このチュートリアルでは、データの取り込み、前処理、モデルのトレーニング、ハイパーパラメータのチューニング、予測、モデルの保存など、機械学習のパイプライン全体をカバーしています。これらのステップは10個以下のコマンドで完了しており、「create_model()」、「tune_model()」、「compare_models()」などの自然な構成で、直感的に覚えられます。この実験全体をPyCaretなしで再作成すると、ほとんどのライブラリで100行以上のコードが必要になります。

以上、 PyCaret の回帰分析のビギナーのチュートリアルでした。 数行書くだけで機械学習できちゃうのは本当に便利ですね。

次のチュートリアルでは、高度な前処理、アンサンブル、一般化されたスタッキングなどが含まれているとのことですので、そちらも次回試してみたいと思います。

Discussion

コメントにはログインが必要です。