Matlab 별이 있는 ROC 커브 그리기

function [AUC betterthreshold MCCMAXThreshold Confuse_Matrix MCCs FPRs TPRs F1S] = ROC_2_Class_OneFile_WithStar(begin_threshold, end_threshold, step, ismaxmcc, classificationFileName, star_num, star_style)
%ROC_2_CLASS_ONEFILE_WITHSTAR  Threshold sweep, ROC curve and AUC for a two-class result file.
%
%   The file 'classificationFileName' is read with LOAD and holds one row per sample:
%       truelabel  probabilityOfFirstClass   probabilityOfSecondClass
%           1           0.78                        0.22
%           1           0.68                        0.32
%           2           0.56                        0.44
%           1           0.89                        0.11
%           2           0.22                        0.78
%           ....        ....                        ....
%   Label 1 is the positive class; every other label is negative. Only the
%   first two columns are used. A sample is predicted positive when its
%   column-2 value is >= the threshold.
%
%   Inputs
%     begin_threshold, end_threshold, step
%                The swept thresholds are begin_threshold:step:end_threshold
%                (step must be > 0).
%     ismaxmcc   Selects the threshold used for Confuse_Matrix:
%                  1      threshold with the maximum MCC (first one on ties)
%                  0      "balance" threshold minimising |TPR + FPR - 1|
%                  5..10  last swept threshold whose FPR lies in
%                         [ismaxmcc/100, (ismaxmcc+0.5)/100]; 0 if none does
%     star_num   Approximate number of points drawn on the ROC curve (default 50).
%     star_style Line specification passed to PLOT (default 'b--').
%
%   Outputs
%     AUC              Area under the ROC curve, computed from the score ranking.
%     betterthreshold  Balance threshold, over swept thresholds with
%                      0 < TPR < 1 and 0 < FPR < 1; begin_threshold if none qualifies.
%     MCCMAXThreshold  Threshold with the maximum MCC (0.5 if every MCC is NaN).
%     Confuse_Matrix   [TP FN; FP TN] at the threshold selected by ismaxmcc.
%     MCCs, FPRs, TPRs, F1S
%                      Per-threshold metrics (row vectors, one entry per swept
%                      threshold). A ratio with a zero denominator is NaN.
%
%   e.g., [AUC betterthreshold MCCMAXThreshold Confuse_Matrix MCCs FPRs TPRs F1S] = ROC_2_Class_OneFile_WithStar(0, 1, 0.01, 1, 'classificationFileName', 50, 'b--')
%

if nargin < 5, error('NNET:Arguments','Not enough input arguments.'); end
if nargin < 6 || isempty(star_num), star_num = 50; end
if nargin < 7 || isempty(star_style), star_style = 'b--'; end
if begin_threshold > end_threshold || step <= 0
    error('NNET:Arguments','The pattern of a parameter is not right.');
end
% Reject unknown selection modes up front instead of silently reusing a
% leftover threshold value.
if ~isscalar(ismaxmcc) || ~any(ismaxmcc == [0 1 5:10])
    error('NNET:Arguments','Unsupported value of ismaxmcc.');
end

A = load(classificationFileName);
scores = A(:, 2);
isPositive = (A(:, 1) == 1);

% Keep the swept thresholds so each metric index maps back to its exact
% threshold value (avoids begin + k*step off-by-one / rounding issues).
thresholds = begin_threshold:step:end_threshold;
nThresholds = numel(thresholds);
MCCs = zeros(1, nThresholds);
FPRs = zeros(1, nThresholds);
TPRs = zeros(1, nThresholds);
F1S  = zeros(1, nThresholds);

MCCMAXThreshold = 0.5;
maxMcc = -Inf;
for k = 1:nThresholds
    [tp, fn, fp, tn] = count_confusion(scores, isPositive, thresholds(k));
    tpr = tp / (tp + fn);
    fpr = fp / (fp + tn);
    pre = tp / (tp + fp);

    TPRs(k) = tpr;
    FPRs(k) = fpr;
    % F1 = harmonic mean of precision and recall (recall == TPR).
    F1S(k) = (2 * tpr * pre) / (tpr + pre);
    MCCs(k) = (tp*tn - fp*fn) / sqrt((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn));

    % Strict '>' keeps the first (lowest) threshold on ties; NaN never wins.
    if MCCs(k) > maxMcc
        maxMcc = MCCs(k);
        MCCMAXThreshold = thresholds(k);
    end
end

% Balance threshold: skip degenerate points (TPR or FPR equal to 0 or 1),
% then take the first threshold minimising |TPR + FPR - 1|, provided it is < 1.
gap = abs(TPRs + FPRs - 1);
gap(TPRs == 0 | TPRs == 1 | FPRs == 0 | FPRs == 1) = Inf;
betterthreshold = begin_threshold;
[bestGap, tag] = min(gap);
if ~isempty(bestGap) && bestGap < 1
    betterthreshold = thresholds(tag);
end

AUC = roc_curve(scores, A(:, 1), star_num, star_style);

if 1 == ismaxmcc
    threshold = MCCMAXThreshold;
elseif 0 == ismaxmcc
    threshold = betterthreshold;
else
    % FPR target of ismaxmcc percent: last swept threshold whose FPR falls in
    % [ismaxmcc/100, (ismaxmcc+0.5)/100], or 0 when none does.
    lowFpr = ismaxmcc / 100;
    highFpr = (ismaxmcc + 0.5) / 100;
    idx = find(FPRs >= lowFpr & FPRs <= highFpr, 1, 'last');
    if isempty(idx)
        threshold = 0;
    else
        threshold = thresholds(idx);
    end
end
threshold   % echo the selected threshold to the command window

[tp, fn, fp, tn] = count_confusion(scores, isPositive, threshold);
Confuse_Matrix = [tp fn; fp tn];
end

function [tp, fn, fp, tn] = count_confusion(scores, isPositive, threshold)
% Confusion-matrix counts when samples with score >= threshold are predicted
% positive (a NaN score compares false and so counts as predicted negative).
predictedPositive = (scores >= threshold);
tp = sum(predictedPositive & isPositive);
fn = sum(~predictedPositive & isPositive);
fp = sum(predictedPositive & ~isPositive);
tn = sum(~predictedPositive & ~isPositive);
end

function auc = roc_curve(deci, label_y, star_num, star_style)
% ROC curve from ranking samples by decreasing score (label 1 positive, all
% other labels negative). Returns the step-wise AUC and plots about star_num
% evenly spaced points of the curve with PLOT(..., star_style).
deci = deci(:);
label_y = label_y(:);
label_y(label_y ~= 1) = -1;
[tmp, ind] = sort(deci, 'descend');
roc_y = label_y(ind);
stack_x = cumsum(roc_y == -1) / sum(roc_y == -1);
stack_y = cumsum(roc_y == 1) / sum(roc_y == 1);
n = numel(roc_y);
auc = sum((stack_x(2:n) - stack_x(1:n-1)) .* stack_y(2:n));

% A stride of at least 1 keeps the plot non-empty when star_num > n, and the
% last point is always drawn so the curve reaches (1, 1).
stride = max(1, floor(n / star_num));
select_indexes = 1:stride:n;
if n > 0 && select_indexes(end) ~= n
    select_indexes = [select_indexes, n];
end
plot(stack_x(select_indexes), stack_y(select_indexes), star_style);
set(gca,'FontName','Times New Roman','FontSize', 12);
xlabel('False Positive Rate', 'FontName', 'Times New Roman', 'FontSize', 20);
ylabel('True Positive Rate', 'FontName', 'Times New Roman', 'FontSize', 20);
title('ROC Curve', 'FontName', 'Times New Roman', 'FontSize', 20);
end

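This MATLAB code sweeps decision thresholds over a two-class result file and reports the per-threshold TPR, FPR, MCC and F1. It also returns the AUC and a confusion matrix at a selected threshold. It draws an ROC curve through a subsample of `star_num` points, which appear as stars when a marker style such as `'b*-'` is passed.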

좋은 웹페이지 즐겨찾기