tensorflow 계산 혼동 매트릭스와 각종 평가 지표 실현

2186 단어 tensorflow
tf.confusion_matrix(y,pred_y,num_classes)
두 가지 분류를 예로 들면 혼동 행렬이 2*2인 행렬이다. 만약에 우리의 실제 라벨이real=[0,1,1,0,1]이면 예측 라벨은predict=[0,1,0,1]이다.
num_classes는 분류수입니다. 이걸 꼭 설정해야 합니다!!!그렇지 않으면 기본값은 None 테스트입니다.
import numpy as np
import tensorflow as tf
y=np.array([[1,0],[1,0],[1,0],[1,0],[1,0]])
y=tf.convert_to_tensor(y)
predict=np.array([[1,0],[1,0],[1,0],[1,0],[1,0]])
predict=tf.convert_to_tensor(predict)

confusion_matrix=tf.confusion_matrix(tf.argmax(y,1),tf.argmax(predict,1),num_classes=2)
       
with tf.Session() as sess: #       
    matrix=sess.run(confusion_matrix)
    print(matrix)
  #  [[5 0] [0 0]]
 #    num_classes=2    [[5]]

2 평가 지표
혼동 행렬을 얻으면 TP, TN, FP, FN을 얻을 수 있다.각종 지표의 공식에 근거하여 구할 수 있다
def evaluate(confusion_metrics):
    TP=confusion_metrics[0][0]
    FP=confusion_metrics[0][1]
    FN=confusion_metrics[1][0]
    TN=confusion_metrics[1][1]
    
    
    ACC=(TP+TN)/(TP+TN+FP+FN)
    SEN=TP/(TP+FN)
    SPE=TN/(TN+FP)

    return ACC,SEN,SPE

삼실훈련
실제 훈련을 할 때 우리의 데이터는 모두 차례로 네트워크에 넣고 훈련하기 때문에 2*2의 0 행렬을 초기화하고 한 번에 혼동 행렬을 계산한 다음에 합칠 수 있다
        real=tf.argmax(y,1) #one-hot  
        pred=tf.argmax(predict,1)
        confusion_matrix=tf.confusion_matrix(real,pred,num_classes=2)

        with tf.Session() as sess: #      
            sess.run(tf.global_variables_initializer())
            avg_cost=0.
            all_matrix=np.zeros([2,2])
            for i in range(300):
                _x,_y=sess.run([train_x,train_y])
              
                matrix=sess.run([confusion_matrix],feed_dict={x:_x,y:_y})
                all_matrix+=matrix
                avg_cost+=cost/display_step
                
                if (i+1)%display_step==0: #  display_step    cost         0
                    train_acc,train_sen,train_spe=evaluate(all_matrix)           
                    print("step:%d,train_acc:%s,train_spe:%s,train_sen:%s"%(i+1,str(train_acc),str(train_spe),str(train_sen)))
                    avg_cost=0
                    all_matrix=np.zeros([2,2])

                    

주의: matrix는numpy 형식이기 때문에np를 사용해야 합니다.zeros([2,2])로 초기화

좋은 웹페이지 즐겨찾기