파이썬 결정 트리를 dtreeviz로 스마트하게 시각화

소개



결정 트리는 설명 가능성이 높고 유용한 기법이지만 파이썬에서는 시각화가 어색하기 때문에 선택에 들어가기 어려워졌다고 개인적으로 생각합니다.
그런 가운데, dtreeviz라는 라이브러리가 공개되어 깨끗하게 가시화할 수 있게 됐어! 라고 이야기.

파이썬 결정 트리 시각화 Before/After



먼저 어떻게 바뀌었는지를 보여주는 것이 알기 쉽기 때문에, iris 데이터를 사용한 결정 트리의 예.

결정 트리 학습
from sklearn.datasets import load_iris
from sklearn import tree

clf = tree.DecisionTreeClassifier(max_depth=2)  # limit depth of tree
iris = load_iris()
clf.fit(iris.data, iris.target)


Before



아마도 주요 graphviz에서 시각화

graphviz에 의한 결정 트리 시각화
import pydotplus
from IPython.display import Image
from graphviz import Digraph

dot_data = tree.export_graphviz(
    clf,
    out_file=None,
    feature_names=iris.feature_names,
    class_names=iris.target_names,
    filled=True,
    proportion=True)
graph = pydotplus.graph_from_dot_data(dot_data)
Image(graph.create_png())



After



dtreeviz라고 이런 느낌

dtreeviz로 시각화
from dtreeviz.trees import dtreeviz

viz = dtreeviz(
    clf,
    iris.data, 
    iris.target,
    target_name='variety',
    feature_names=iris.feature_names,
    class_names=[str(i) for i in iris.target_names],
) 

viz.view()



dtreeviz에서 바뀐 곳


  • 어쨌든 디자인이 멋졌다
  • 이대로 설명 자료에 사용할 수 있을 것 같다

  • 분류에 사용되는 특징 량의 분포와 결정 경계가 나타났다
  • 잎의 상세한 순도가 떨어지는 대신 원형 차트 구성 비율로 이해하기 쉬워졌습니다
  • 샘플 데이터를 제공하면 추론 과정과 근거가되는 특징을 표시 할 수 있습니다.

    결정 트리 추론 과정의 시각화
    X = iris.data[29]  # サンプルデータ
    
    viz = dtreeviz(
        classifier,
        iris.data, 
        iris.target,
        target_name='variety',
        feature_names=iris.feature_names,
        class_names=[str(i) for i in iris.target_names],    
        X=X, # サンプルデータを与えると、分類の過程が表示される
    ) 
    
    viz.view()
    



    덧붙여서, 여기까지는 분류 나무의 예뿐이었지만, 회귀 나무의 가시화도 할 수 있습니다.
    (단지 회귀목을 적극적으로 사용할 기회는 별로 없네요・・・.)

    dtreeviz 함수의 인수




    인수
    금형
    기본값
    설명


    tree_model

    sklearn의 DecisionTreeRegressor 또는 DecisionTreeClassifier

    X_train
    (pd.DataFrame, np.ndarray)

    모델 훈련에 사용된 설명 변수 데이터

    y_train
    (pd.Series, np.ndarray)

    모델 훈련에 사용된 목적 변수 데이터

    feature_names
    List[str]

    X_train의 각 특징량명

    target_name
    str

    목적 변수의 이름

    class_names
    (Mapping[Number, str], List[str])

    (분류 트리의 경우 필수) 각 클래스에 해당하는 이름

    precision
    int
    2
    특징량의 경계치를 표시하는 소수점 이하의 자리수

    orientation
    ( 'TD', 'LR')
    'TD'
    나무가 분기하는 방향 TD: top-down or LR: left-right

    show_root_edge_labels
    bool
    True
    루트에서 노드로의 분기 값 관계를 표시할지 여부

    show_node_labels
    bool
    거짓
    노드 번호를 표시하거나

    fancy
    bool
    True
    특징량의 결정 경계를 시각화할까

    histtype
    ( 'bar', 'barstacked')
    'barstacked'
    분류 트리 때, 히스토그램 표시 형식

    X
    np.ndarray
    None
    추론 과정을 시각화하는 샘플 데이터

    max_X_features_LR
    int
    10
    orientation='LR'일 때 표시할 샘플 데이터의 특징량 수

    max_X_features_TD
    int
    20
    orientation='TD'일 때 표시할 샘플 데이터의 특징량 수


    주의



    jupyter lab이나 google colab의 경우 view를 사용할 수 없으므로 display(viz)로 시각화해야합니다.

    참고


  • dtreeviz : Decision Tree Visualization
  • 결정 트리의 시각화 라이브러리 「dtreeviz」가 굉장했기 때문에 정리한다
  • 좋은 웹페이지 즐겨찾기