Magicode logo
Magicode
0
1 min read

決定木の可視化(plot_treeを使う)

# Visualize a decision tree trained on the Iris dataset with
# sklearn.tree.plot_tree (drawn with matplotlib, so Graphviz is not needed).
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn import tree
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt

iris = load_iris()

data = iris.data
target = iris.target

# random_state fixes the train/test split so results are reproducible.
train_x, test_x, train_y, test_y = train_test_split(data, target, random_state=1)

# max_depth keeps the plotted tree small enough to read; random_state pins
# sklearn's per-split feature permutation so the fitted tree is reproducible.
model = tree.DecisionTreeClassifier(max_depth=5, random_state=1)
model.fit(train_x, train_y)

# Outside a notebook, bare predict()/score() expressions show nothing;
# print them so the script reports how the plotted tree performs.
print("test predictions:", model.predict(test_x))
print("test accuracy:", model.score(test_x, test_y))

plt.figure(figsize=(15, 10))
# Plot the model fitted above. Re-fitting inside this call (as before) was
# redundant and could yield a different tree than the one that was scored.
plot_tree(
    model,
    filled=True,
    rounded=True,
    feature_names=iris.feature_names,
    class_names=list(iris.target_names),
)
plt.show()

graphviz を使った方法ではうまく可視化できなかった・・・

Discussion

コメントにはログインが必要です。