Matlab 별이 있는 ROC 커브 그리기
6294 단어 Matlab 언어 기술 생물정보학 패턴 식별
function [AUC betterthreshold MCCMAXThreshold Confuse_Matrix MCCs FPRs TPRs F1S] = ROC_2_Class_OneFile_WithStar(begin_threshold, end_threshold, step, ismaxmcc,classificationFileName, star_num, star_style)
%ROC_2_CLASS_ONEFILE_WITHSTAR Threshold sweep and marker-sampled ROC curve for a two-class classifier.
%
% The file at 'classificationFileName' is read with LOAD; each row is
%   truelabel probabilityOfFirstClass probabilityOfSecondClass
%   1 0.78 0.22
%   1 0.68 0.32
%   2 0.56 0.44
%   .... .... ....
% A true label of 1 marks a positive sample; any other label is negative.
% Only columns 1 and 2 are used. A sample is predicted positive when its
% column-2 value is >= the threshold.
%
% Inputs
%   begin_threshold, end_threshold, step
%       The swept thresholds are begin_threshold:step:end_threshold
%       (requires begin_threshold <= end_threshold and step > 0).
%   ismaxmcc
%       Selects the threshold used for Confuse_Matrix:
%         1     -> MCCMAXThreshold (maximum MCC over the sweep)
%         0     -> betterthreshold (TPR + FPR closest to 1)
%         5..10 -> largest swept threshold whose FPR lies in
%                  [k/100, k/100 + 0.005]; 0 if none does.
%   classificationFileName
%       Path of the data file described above.
%   star_num   (optional, default 50)
%       Approximate number of markers drawn on the ROC curve.
%   star_style (optional, default 'b-')
%       Line specification passed to PLOT.
%
% Outputs
%   AUC             Area under the ROC curve, computed by roc_curve.
%   betterthreshold Swept threshold whose TPR + FPR is closest to 1, among
%                   thresholds with 0 < TPR < 1 and 0 < FPR < 1;
%                   begin_threshold if no such threshold exists.
%   MCCMAXThreshold Swept threshold with the largest MCC (0.5 if every MCC is NaN).
%   Confuse_Matrix  [TP FN; FP TN] at the threshold chosen by ismaxmcc.
%   MCCs, FPRs, TPRs, F1S
%                   Row vectors with one entry per swept threshold; an entry
%                   whose denominator is zero is NaN.
%
% e.g., [AUC betterthreshold MCCMAXThreshold Confuse_Matrix MCCs FPRs TPRs F1S] = ROC_2_Class_OneFile_WithStar(0, 1, 0.01, 1, 'classificationFileName', 50, 'b--')
%
% Side effects: draws the ROC curve into the current axes and echoes the
% selected threshold to the command window.

if nargin < 5, error('NNET:Arguments','Not enough input arguments.'); end
if nargin < 6 || isempty(star_num), star_num = 50; end
if nargin < 7 || isempty(star_style), star_style = 'b-'; end
% step == 0 would produce an empty sweep, so it is rejected as well.
if ((begin_threshold>end_threshold) || step <= 0.0)
    error('NNET:Arguments','The pattern of a parameter is not right.');
end
% Validate the selector before any work or plotting, instead of silently
% falling back to the last swept threshold for an unknown value.
if ~(isscalar(ismaxmcc) && any(ismaxmcc == [0 1 5 6 7 8 9 10]))
    error('NNET:Arguments','ismaxmcc must be 0, 1, or an integer from 5 to 10.');
end

A = load(classificationFileName);
scores = A(:, 2);
is_pos = (A(:, 1) == 1);

thresholds = begin_threshold:step:end_threshold;
K = numel(thresholds);
MCCs = zeros(1, K);
FPRs = zeros(1, K);
TPRs = zeros(1, K);
F1S  = zeros(1, K);

% FPR windows for ismaxmcc = 5..10: [0.05, 0.055], [0.06, 0.065], ...
% Written as literals so the bounds are exactly the decimal values intended.
fpr_lo = [0.05  0.06  0.07  0.08  0.09  0.10];
fpr_hi = [0.055 0.065 0.075 0.085 0.095 0.105];
th_fpr = zeros(size(fpr_lo));

MCCMAXThreshold = 0.5;
maxMcc = -Inf;
for k = 1:K
    threshold = thresholds(k);
    [tp, fn, fp, tn] = confusion_counts(scores, is_pos, threshold);
    tpr = tp / (tp + fn);
    fpr = fp / (fp + tn);
    pre = tp / (tp + fp);
    TPRs(k) = tpr;
    FPRs(k) = fpr;
    % F1 is the harmonic mean of precision and recall (TPR).
    F1S(k) = (2 * tpr * pre) / (tpr + pre);
    mcc = (tp*tn - fp*fn) / sqrt((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn));
    MCCs(k) = mcc;
    % Strict '>' keeps the first threshold on ties; a NaN MCC never wins.
    if mcc > maxMcc
        maxMcc = mcc;
        MCCMAXThreshold = threshold;
    end
    % Thresholds increase with k, so the last hit per window is the largest threshold.
    in_window = (fpr >= fpr_lo) & (fpr <= fpr_hi);
    th_fpr(in_window) = threshold;
end

% Balance threshold: skip degenerate points where TPR or FPR is 0 or 1, then
% take the first threshold minimising |TPR + FPR - 1| (the gap must be < 1).
gap = abs(TPRs + FPRs - 1);
gap(TPRs == 0 | TPRs == 1 | FPRs == 0 | FPRs == 1) = Inf;
[min_gap, tag] = min(gap);
if min_gap < 1
    % Use the swept value itself (the old begin + tag*step was one step too far).
    betterthreshold = thresholds(tag);
else
    betterthreshold = begin_threshold;
end

AUC = roc_curve(A(:, 2), A(:, 1), star_num, star_style);

switch ismaxmcc
    case 1
        threshold = MCCMAXThreshold;
    case 0
        threshold = betterthreshold;
    otherwise
        % ismaxmcc is one of 5..10 (validated above).
        threshold = th_fpr(ismaxmcc - 4);
end
% Echo the selected threshold: for the FPR-based modes it is not returned.
threshold

[tp, fn, fp, tn] = confusion_counts(scores, is_pos, threshold);
Confuse_Matrix = [tp fn; fp tn];
end

function [tp, fn, fp, tn] = confusion_counts(scores, is_pos, threshold)
%CONFUSION_COUNTS Confusion counts when score >= threshold is predicted positive.
%   scores    - per-sample scores.
%   is_pos    - logical mask of truly positive samples (same size as scores).
%   threshold - decision threshold.
%   Returns TP, FN, FP and TN as doubles.
pred_pos = (scores >= threshold);
tp = sum(pred_pos & is_pos);
fn = sum(~pred_pos & is_pos);
fp = sum(pred_pos & ~is_pos);
tn = sum(~pred_pos & ~is_pos);
end
function auc = roc_curve(deci,label_y, star_num, star_style)
%ROC_CURVE Plot a marker-sampled ROC curve and return its area.
%   auc = roc_curve(deci, label_y, star_num, star_style)
%   deci       - vector of scores; samples are ranked by descending score.
%   label_y    - vector of true labels; 1 is positive, any other value negative.
%   star_num   - approximate number of curve points to plot; the last ranked
%                point is always included.
%   star_style - line specification passed to PLOT.
%   auc        - area under the ROC curve, accumulated as a step function over
%                the ranked samples. Tied scores are not grouped: they are
%                ranked in the order SORT returns them.
%
%   Row or column vectors are accepted. Draws into the current axes.

% Force column vectors so row-vector input works as well.
deci = deci(:);
label_y = label_y(:);

[~, order] = sort(deci, 'descend');
is_pos = (label_y(order) == 1);

% Cumulative FPR/TPR after admitting each ranked sample as positive.
fpr = cumsum(~is_pos) / sum(~is_pos);
tpr = cumsum(is_pos) / sum(is_pos);

% Each FPR increment (one more negative) adds a rectangle of height TPR.
auc = sum(diff(fpr) .* tpr(2:end));

n = numel(fpr);
if n > 0
    % Keep the stride >= 1: with star_num > n a zero stride gave an empty plot.
    stride = max(1, floor(n / star_num));
    % Always include the last ranked sample so the curve reaches its end point.
    marker_idx = unique([1:stride:n, n]);
else
    marker_idx = [];
end
plot(fpr(marker_idx), tpr(marker_idx), star_style);
set(gca,'FontName','Times New Roman','FontSize', 12)
xlabel('False Positive Rate', 'FontName', 'Times New Roman', 'FontSize', 20);
ylabel('True Positive Rate', 'FontName', 'Times New Roman', 'FontSize', 20);
title('ROC Curve', 'FontName', 'Times New Roman', 'FontSize', 20);
end
This MATLAB code draws an ROC curve with markers ("stars") placed at regular intervals along it.