python 의사 결정 트 리 분류 실현
원본 데이터 세트:
변 화 된 데이터 세트 가 프로그램 코드 에 나타 나 면 캡 처 하지 않 습 니 다.
의사 결정 트 리 구축 코드 는 다음 과 같 습 니 다.
#coding :utf-8
'''
2017.6.25 author :Erin
function: "decesion tree" ID3
'''
import numpy as np
import pandas as pd
from math import log
import operator
def load_data():
#data=np.array(data)
data=[['teenager' ,'high', 'no' ,'same', 'no'],
['teenager', 'high', 'no', 'good', 'no'],
['middle_aged' ,'high', 'no', 'same', 'yes'],
['old_aged', 'middle', 'no' ,'same', 'yes'],
['old_aged', 'low', 'yes', 'same' ,'yes'],
['old_aged', 'low', 'yes', 'good', 'no'],
['middle_aged', 'low' ,'yes' ,'good', 'yes'],
['teenager' ,'middle' ,'no', 'same', 'no'],
['teenager', 'low' ,'yes' ,'same', 'yes'],
['old_aged' ,'middle', 'yes', 'same', 'yes'],
['teenager' ,'middle', 'yes', 'good', 'yes'],
['middle_aged' ,'middle', 'no', 'good', 'yes'],
['middle_aged', 'high', 'yes', 'same', 'yes'],
['old_aged', 'middle', 'no' ,'good' ,'no']]
features=['age','input','student','level']
return data,features
def cal_entropy(dataSet):
'''
data ,
{' ': 9, ' ': 5}
0.9402859586706309
'''
numEntries = len(dataSet)
labelCounts = {}
for featVec in dataSet:
label = featVec[-1]
if label not in labelCounts.keys():
labelCounts[label] = 0
labelCounts[label] += 1
entropy = 0.0
for key in labelCounts.keys():
p_i = float(labelCounts[key]/numEntries)
entropy -= p_i * log(p_i,2)#log(x,10) 10
return entropy
def split_data(data,feature_index,value):
'''
feature_index: , “ ”
value: : “ ”
'''
data_split=[]#
for feature in data:
if feature[feature_index]==value:
reFeature=feature[:feature_index]
reFeature.extend(feature[feature_index+1:])
data_split.append(reFeature)
return data_split
def choose_best_to_split(data):
'''
,
'''
count_feature=len(data[0])-1# 4
#print(count_feature)#4
entropy=cal_entropy(data)#
#print(entropy)#0.9402859586706309
max_info_gain=0.0#
split_fea_index = -1# ,
for i in range(count_feature):
feature_list=[fe_index[i] for fe_index in data]#
#######################################
'''
print('feature_list')
[' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ',
' ', ' ', ' ', ' ']
0.3467680694480959 # =(1)*5/14
0.3467680694480959
0.6935361388961918
'''
# print(feature_list)
unqval=set(feature_list)#
Pro_entropy=0.0#
for value in unqval:#
sub_data=split_data(data,i,value)
pro=len(sub_data)/float(len(data))
Pro_entropy+=pro*cal_entropy(sub_data)
#print(Pro_entropy)
info_gain=entropy-Pro_entropy
if(info_gain>max_info_gain):
max_info_gain=info_gain
split_fea_index=i
return split_fea_index
##################################################
def most_occur_label(labels):
#sorted_label_count[0][0]
label_count={}
for label in labels:
if label not in label_count.keys():
label_count[label]=0
else:
label_count[label]+=1
sorted_label_count = sorted(label_count.items(),key = operator.itemgetter(1),reverse = True)
return sorted_label_count[0][0]
def build_decesion_tree(dataSet,featnames):
'''
,
'''
featname = featnames[:] ################
classlist = [featvec[-1] for featvec in dataSet] #
if classlist.count(classlist[0]) == len(classlist): #
return classlist[0]
if len(dataSet[0]) == 1: # ,
return Vote(classlist) #
#
bestFeat = choose_best_to_split(dataSet)
bestFeatname = featname[bestFeat]
del(featname[bestFeat]) #
DecisionTree = {bestFeatname:{}}
# , ,
allvalue = [vec[bestFeat] for vec in dataSet]
specvalue = sorted(list(set(allvalue))) #
for v in specvalue:
copyfeatname = featname[:]
DecisionTree[bestFeatname][v] = build_decesion_tree(split_data(dataSet,bestFeat,v),copyfeatname)
return DecisionTree
시각 화 된 그림 을 그 리 는 코드 는 다음 과 같 습 니 다.
def getNumLeafs(myTree):
' '
#
numLeafs = 0
#
sides = list(myTree.keys())
firstStr =sides[0]
#
secondDict = myTree[firstStr]
for key in secondDict.keys(): #
#
if type(secondDict[key]).__name__=='dict':
numLeafs += getNumLeafs(secondDict[key])
# +1
else: numLeafs +=1
return numLeafs
def getTreeDepth(myTree):
' '
#
maxDepth = 0
#
sides = list(myTree.keys())
firstStr =sides[0]
#
secondDict = myTree[firstStr]
for key in secondDict.keys(): #
#
if type(secondDict[key]).__name__=='dict':
thisDepth = 1 + getTreeDepth(secondDict[key])
# +1
else: thisDepth = 1
#
if thisDepth > maxDepth: maxDepth = thisDepth
return maxDepth
import matplotlib.pyplot as plt
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
# ==================================================
# :
# nodeTxt:
# centerPt:
# parentPt:
# nodeType:
# :
# ( )
# ==================================================
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
' ( )'
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction', va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )
# =================================================================
# :
# cntrPt:
# parentPt:
# txtString:
# :
# (cntrPt parentPt ) (txtString)
# =================================================================
def plotMidText(cntrPt, parentPt, txtString):
' '
#
xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
# ===================================
# :
# myTree:
# parentPt:
# nodeTxt:
# :
#
# ===================================
def plotTree(myTree, parentPt, nodeTxt):
' '
#
numLeafs = getNumLeafs(myTree)
#
sides = list(myTree.keys())
firstStr =sides[0]
# ( )
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
# ( )
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
#
secondDict = myTree[firstStr]
# , -1。
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
for key in secondDict.keys(): #
#
if type(secondDict[key]).__name__=='dict':
plotTree(secondDict[key],cntrPt,str(key))
#
else:
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
# , +1。
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
# ==============================
# :
# myTree:
# :
#
# ==============================
def createPlot(inTree):
' '
# -
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
#
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
#
plotTree.xOff = -0.5/plotTree.totalW;
plotTree.yOff = 1.0;
#
plotTree(inTree, (0.5,1.0), '')
plt.show()
if __name__ == '__main__':
data,features=load_data()
split_fea_index=choose_best_to_split(data)
newtree=build_decesion_tree(data,features)
print(newtree)
createPlot(newtree)
'''
{'age': {'old_aged': {'level': {'same': 'yes', 'good': 'no'}}, 'teenager': {'student': {'no': 'no', 'yes': 'yes'}}, 'middle_aged': 'yes'}}
'''
결 과 는 다음 과 같다.의사 결정 트 리 로 어떻게 분류 하 는 지다음 장.
이상 이 바로 본 고의 모든 내용 입 니 다.여러분 의 학습 에 도움 이 되 고 저 희 를 많이 응원 해 주 셨 으 면 좋 겠 습 니 다.
이 내용에 흥미가 있습니까?
현재 기사가 여러분의 문제를 해결하지 못하는 경우 AI 엔진은 머신러닝 분석(스마트 모델이 방금 만들어져 부정확한 경우가 있을 수 있음)을 통해 가장 유사한 기사를 추천합니다:
로마 숫자를 정수로 또는 그 반대로 변환그 중 하나는 로마 숫자를 정수로 변환하는 함수를 만드는 것이었고 두 번째는 그 반대를 수행하는 함수를 만드는 것이었습니다. 문자만 포함합니다'I', 'V', 'X', 'L', 'C', 'D', 'M' ; 문자열이 ...
텍스트를 자유롭게 공유하거나 복사할 수 있습니다.하지만 이 문서의 URL은 참조 URL로 남겨 두십시오.
CC BY-SA 2.5, CC BY-SA 3.0 및 CC BY-SA 4.0에 따라 라이센스가 부여됩니다.