다중 태그 KNN 알고리즘 구현(Python3.6)
5188 단어 01
MLKNN 알고리즘
새로운 실례에 대해 가장 가까운 k개의 실례를 취한 다음, 그 k개 이웃이 가진 라벨의 통계를 얻고, 마지막으로 사전 확률과 최대 사후 확률(MAP)을 통해 새로운 실례의 라벨 집합을 확정한다. 상세한 내용은 저우즈화(周志华) 교수와 장민링(张敏灵) 교수의 《다중 표기 학습》을 참조하시오.
알고리즘 실현
데이터 사전 처리
MLkNNDemo.py
# MLkNNDemo.py -- driver: load the scene dataset, flatten the bag
# representation into plain feature vectors, then train and evaluate ML-kNN.

# Load the dataset (scene.mat must be in the working directory).
data = sio.loadmat('scene.mat')
train_bags = data['train_bags']
test_bags = data['test_bags']
train_targets = data['train_targets']
test_targets = data['test_targets']

# Each bag is a 9*15 matrix; flatten it into a single 1*135 feature vector.
train_data = np.array([train_bags[i, 0].flatten().tolist()
                       for i in range(len(train_bags))])
test_data = np.array([test_bags[i, 0].flatten().tolist()
                      for i in range(len(test_bags))])

# Constants: Num is k (the neighbour count), Smooth is the Laplace smoother.
Num = 10
Smooth = 1

# Training: estimate prior and conditional probabilities from the train set.
Prior, PriorN, Cond, CondN = MLkNNTrain.trainClass(
    train_data, train_targets, Num, Smooth)

# Testing: posterior outputs and the predicted +1/-1 label matrix.
outPuts, preLabels = MLkNNTest.testClass(
    train_data, train_targets, test_data, test_targets,
    Num, Prior, PriorN, Cond, CondN)
훈련
MLkNNTrain.py
#Train
def trainClass(train_data, train_targets, Num, Smooth):
    """Estimate the ML-kNN probability tables from the training set.

    Parameters
    ----------
    train_data : ndarray, shape (num_training, num_features)
        One flattened feature vector per training instance.
    train_targets : ndarray, shape (num_class, num_training)
        Label matrix with +1 for "has label", anything else for "does not".
    Num : int
        k, the number of nearest neighbours to consider.
    Smooth : number
        Laplace smoothing constant (the demo uses 1).

    Returns
    -------
    Prior, PriorN : ndarray, shape (num_class, 1)
        Smoothed prior probability that a label is present / absent.
    Cond, CondN : ndarray, shape (num_class, Num + 1)
        Cond[c, t]  = P(t of the k neighbours carry label c | instance has c)
        CondN[c, t] = same, conditioned on the instance NOT having c.
    """
    num_class, num_training = np.mat(train_targets).shape

    # Pairwise Euclidean distances between training instances; the diagonal
    # is set to a huge value so an instance is never its own neighbour.
    dist_matrix = np.diagflat(np.ones((1, num_training)) * sys.maxsize)
    for i in range(num_training - 1):
        vector1 = train_data[i, :]
        for j in range(i + 1, num_training):
            d = np.sum((vector1 - train_data[j, :]) ** 2) ** 0.5
            dist_matrix[i, j] = d
            dist_matrix[j, i] = d

    # Priors with Laplace smoothing.
    # BUG FIX: the original used a hard-coded "+1" in the numerator, which is
    # only correct when Smooth == 1; use the Smooth parameter consistently.
    Prior = np.zeros((num_class, 1))
    PriorN = np.zeros((num_class, 1))
    for i in range(num_class):
        label_count = np.sum(train_targets[i, :] == 1)
        Prior[i, 0] = (Smooth + label_count) / (Smooth * 2 + num_training)
        PriorN[i, 0] = 1 - Prior[i, 0]

    # For every training instance, count how many of its k nearest
    # neighbours carry each label, and tally by whether the instance
    # itself carries that label (tempCi) or not (tempNci).
    disMatIndex = np.argsort(dist_matrix)
    tempCi = np.zeros((num_class, Num + 1))
    tempNci = np.zeros((num_class, Num + 1))
    for i in range(num_training):
        neighbor_idx = disMatIndex[i, :Num]
        for j in range(num_class):
            t = int(np.sum(train_targets[j, neighbor_idx] == 1))
            if train_targets[j, i] == 1:
                tempCi[j, t] += 1
            else:
                tempNci[j, t] += 1

    # Smoothed conditional probabilities (num_class x (Num+1) tables).
    Cond = np.zeros((num_class, Num + 1))
    CondN = np.zeros((num_class, Num + 1))
    for i in range(num_class):
        temp1 = np.sum(tempCi[i, :])
        temp2 = np.sum(tempNci[i, :])
        for j in range(Num + 1):
            Cond[i, j] = (Smooth + tempCi[i, j]) / (Smooth * (Num + 1) + temp1)
            CondN[i, j] = (Smooth + tempNci[i, j]) / (Smooth * (Num + 1) + temp2)

    return Prior, PriorN, Cond, CondN
테스트
MLkNNTest.py
def testClass(train_data,train_targets,test_data,test_targets,Num,Prior,PriorN,Cond,CondN):
num_class,num_training = np.mat(train_targets).shape
num_class,num_testing = np.mat(test_targets).shape
#init matrix about distance
distMatrix = zeros((num_testing,num_training))
for i in range(num_testing):
vector1 = test_data[i,:]
for j in range(num_training):
vector2 = train_data[j,:]
distMatrix[i,j] = sum((vector1-vector2)**2)**0.5
#Sort by distance and get index
#find neighbors
disMatIndex = argsort(distMatrix)
#computing outputs
outPuts = zeros((num_class,num_testing))
for i in range(num_testing):
temp = zeros((1,num_class))
neighborLabels = []
for j in range(Num):
neighborLabels.append(train_targets[:,disMatIndex[i,j]])
neighborLabels = np.mat(neighborLabels)
#transposition
neighborLabels = np.transpose(neighborLabels)
for j in range(num_class):
temp[0,j] = sum((neighborLabels[j,:] == ones((1,Num))))
for j in range(num_class):
t = int((temp[0,j]))
Prob_in=Prior[j]*Cond[j,t]
Prob_out=PriorN[j]*CondN[j,t]
if((Prob_in+Prob_out)==0):
outPuts[j,i]=Prior[j]
else:
outPuts[j,i]=Prob_in/[Prob_in+Prob_out]
#Evaluation
preLabels=zeros((num_class,num_testing))
for i in range(num_testing):
for j in range(num_class):
if(outPuts[j,i]>=0.5): # 0.5
preLabels[j,i]=1
else:
preLabels[j,i]=-1
return outPuts,preLabels
주: 상술한 내용은 단지 개인의 학습 과정 중의 필기일 뿐이니, 만약 부적절한 부분이 있으면 바로잡아 주시기 바랍니다.