Opencv 연구 노트:haartraining 프로그램의 cvCreate CARTClassifier 함수 설명 (CART 트리 모양 약분류기 생성) ~

cvCreate CARTClassifier 함수는haartraining 프로그램에서 CART 트리 모양의 약한 분류기를 만드는 데 사용되지만 일반적으로 단일 노드의 CART 분류기, 즉 말뚝 분류기만 사용하는데 다중 노드의 CART 분류기는 훈련에 많은 시간을 소모한다.자신의 테스트에 따르면 10분(2000정 샘플, 2000마이너스 샘플)을 기다려야 3개 노드의 약분류기를 훈련할 수 있다. 물론 전체적인 나무형 약분류기의 수량도 1/2 감소할 수 있다.이 함수를 꺼내서 이야기하는 이유는 인터넷에서 이 함수에 대한 상세한 설명을 찾을 수 없기 때문이다. 또한 CART의 응용이 매우 광범위하기 때문에 자신도 이 기회를 틈타 잘 배워서 자신의 이해를 여러분께 공유할 수 있기 때문이다.
1. CART 트리의 디자인 문제, 즉 CvCARTClassifier라는 구조체를 먼저 말하자면 구조체에서 변수의 의미가 나를 골치 아프게 한다.이제 변수의 의미가 다음과 같이 추가됩니다.
4
typedef struct CvCARTClassifier
{
    CV_CLASSIFIER_FIELDS()
    /* number of internal nodes */
    int count;                      //        

    /* internal nodes (each array of <count> elements) */
    int* compidx;                   //         Haar    
    float* threshold;               //         Haar    
    int* left;                      //             (       ,        )
    int* right;                     //             (       ,        )

    /* leaves (array of <count>+1 elements) */
    float* val;                     //          
} CvCARTClassifier;
그 중에서count는main 주함수 중의 매개 변수 nsplits로 비엽 노드 수를 정의하거나 중간 노드 수라고 부른다.개인적으로 이렇게 나무 한 그루를 설계하는 것은 매우 과학적이고 비엽자 노드와 엽자 노드를 분리하여 표현하는데 구조체가 매우 간결하다. 단지 당시left의 진실한 의미는 나로 하여금 오랫동안 궁리하게 했다.
2. cvCreatecartClassifier에서 노드가 분열된'후보 속성 집합'은 여전히Haar특징이고'분류 준칙'은 분류 오류율(error)의 하락 정도이다. 이 함수에서 분류 준칙의 구체적인 도량은 부모 노드의 하위 노드 자체의 error에서 이 하위 노드의 두 하위 노드의 error를 빼고 이 변수는 코드에서errdrop로 표시한다.
3. CART 트리 분류기의 형식은 다양하다. 3개의 비엽자 노드에 대해 말하자면 내가 디버깅을 한 후에 다음과 같은 두 가지 약분류기를 만났다.
4. 그러나 이 함수의 노드 첨가 방식은 일반적인 카트 트리와 다르기 때문에 좌우 두 개의 서브 노드를 분열시켜 더욱 선별하는 과정이 많다.이 점에 대해 나는 아직 작가의 의도를 꿰뚫어 보지 못했다. 내 말은, 만약 선별하지 않는다면, 무엇이 타당하지 않겠는가?
5. 어린이 신발은 왜 나무 모양의 약분류기를 사용하느냐고 물어볼 수 있다. 나의 이해는 한 나무 모양의 분류기는 테스트 과정에서 특징을 비교하는 횟수가 상대적으로 직렬적인 약분류기보다 훨씬 적다는 것이다. 예를 들어 세 개의 직렬의Haar특징은 비교 횟수가 세 번이지만 3개 노드의 CART 나무라면 비교 횟수는 두 번밖에 안 될 수 있다.또한 나무형 약분류기에서 하위 노드가 겨냥한 데이터 집합이 더욱 구체적이고 목적성이 있어 정밀도가 더욱 높을 수 있다.
이상은 cvCreatecARTClassifier 함수에 대한 자신의 이해입니다. 주석이 있는 원본 코드는 다음과 같습니다.
(전재:http://blog.csdn.net/wsj998689aa/article/details/43411809)
CV_BOOST_IMPL
CvClassifier* cvCreateCARTClassifier( CvMat* trainData,                     //          
                                     int flags,                             //       
                                     CvMat* trainClasses,                   //         
                                     CvMat* typeMask,           
                                     CvMat* missedMeasurementsMask,
                                     CvMat* compIdx,                        //       
                                     CvMat* sampleIdx,                      //       
                                     CvMat* weights,                        //       
                                     CvClassifierTrainParams* trainParams ) //   
{
    CvCARTClassifier* cart = NULL;          // CART      
    size_t datasize = 0;
    int count = 0;                          // CART      
    int i = 0;
    int j = 0;

    CvCARTNode* intnode = NULL;             // CART  
    CvCARTNode* list = NULL;                //       
    int listcount = 0;                      //       
    CvMat* lidx = NULL;                     //         
    CvMat* ridx = NULL;                     //         

    float maxerrdrop = 0.0F;
    int idx = 0;

    //           
    void (*splitIdxCallback)( int compidx, float threshold,
        CvMat* idx, CvMat** left, CvMat** right,
        void* userdata );
    void* userdata;

    //          
    count = ((CvCARTTrainParams*) trainParams)->count;

    assert( count > 0 );

    datasize = sizeof( *cart ) + (sizeof( float ) + 3 * sizeof( int )) * count + 
        sizeof( float ) * (count + 1);

    cart = (CvCARTClassifier*) cvAlloc( datasize );
    memset( cart, 0, datasize );

    cart->count = count;

    //            
    cart->eval = cvEvalCARTClassifier;
    
    cart->save = NULL;
    cart->release = cvReleaseCARTClassifier;

    cart->compidx = (int*) (cart + 1);                      //         Haar    
    cart->threshold = (float*) (cart->compidx + count);     //         Haar    
    cart->left  = (int*) (cart->threshold + count);         //       ,        
    cart->right = (int*) (cart->left + count);              //       ,        
    cart->val = (float*) (cart->right + count);             //          

    datasize = sizeof( CvCARTNode ) * (count + count);
    intnode = (CvCARTNode*) cvAlloc( datasize );
    memset( intnode, 0, datasize );
    list = (CvCARTNode*) (intnode + count);

    //         ,   icvSplitIndicesCallback  
    splitIdxCallback = ((CvCARTTrainParams*) trainParams)->splitIdx;
    userdata = ((CvCARTTrainParams*) trainParams)->userdata;

    // R        ,C        
    if( splitIdxCallback == NULL )
    {
        splitIdxCallback = ( CV_IS_ROW_SAMPLE( flags ) )
            ? icvDefaultSplitIdx_R : icvDefaultSplitIdx_C;
        userdata = trainData;
    }

    //   CART   
    intnode[0].sampleIdx = sampleIdx;
    intnode[0].stump = (CvStumpClassifier*)
        ((CvCARTTrainParams*) trainParams)->stumpConstructor( trainData, flags,
        trainClasses, typeMask, missedMeasurementsMask, compIdx, sampleIdx, weights,
        ((CvCARTTrainParams*) trainParams)->stumpTrainParams );
    cart->left[0] = cart->right[0] = 0;

    //         ,lerror  rerror  0             
    listcount = 0;
    for( i = 1; i < count; i++ )
    {
        //             ,         
        //   lidx ridx,                 
        splitIdxCallback( intnode[i-1].stump->compidx, intnode[i-1].stump->threshold,
            intnode[i-1].sampleIdx, &lidx, &ridx, userdata );

        //                  
        if( intnode[i-1].stump->lerror != 0.0F )
        {
            //          
            list[listcount].sampleIdx = lidx;

            //              
            list[listcount].stump = (CvStumpClassifier*)
                ((CvCARTTrainParams*) trainParams)->stumpConstructor( trainData, flags,
                trainClasses, typeMask, missedMeasurementsMask, compIdx,
                list[listcount].sampleIdx,
                weights, ((CvCARTTrainParams*) trainParams)->stumpTrainParams );

            //       (   error     )
            list[listcount].errdrop = intnode[i-1].stump->lerror
                - (list[listcount].stump->lerror + list[listcount].stump->rerror);
            list[listcount].leftflag = 1;
            list[listcount].parent = i-1;
            listcount++;
        }
        else
        {
            cvReleaseMat( &lidx );
        }

        //   ,        ,      
        if( intnode[i-1].stump->rerror != 0.0F )
        {
            list[listcount].sampleIdx = ridx;
            list[listcount].stump = (CvStumpClassifier*)
                ((CvCARTTrainParams*) trainParams)->stumpConstructor( trainData, flags,
                trainClasses, typeMask, missedMeasurementsMask, compIdx,
                list[listcount].sampleIdx,
                weights, ((CvCARTTrainParams*) trainParams)->stumpTrainParams );
            list[listcount].errdrop = intnode[i-1].stump->rerror
                - (list[listcount].stump->lerror + list[listcount].stump->rerror);
            list[listcount].leftflag = 0;       //            
            list[listcount].parent = i-1;
            listcount++;
        }
        else
        {
            cvReleaseMat( &ridx );
        }

        if( listcount == 0 ) break;

        idx = 0;
        maxerrdrop = list[idx].errdrop;
        for( j = 1; j < listcount; j++ )
        {
            if( list[j].errdrop > maxerrdrop )
            {
                idx = j;	//       ,         idx   
                maxerrdrop = list[j].errdrop;
            }
        }

        //      
        intnode[i] = list[idx];

        //                 
        if( list[idx].leftflag )
        {
            cart->left[list[idx].parent] = i;
        }
        else
        {
            cart->right[list[idx].parent] = i;
        }

        //       ,             ,                
        if( idx != (listcount - 1) )
        {
            list[idx] = list[listcount - 1];
        }
        listcount--;
    }

    //                   、               
    // left right    0, 0      
    //   CART       ,          
    j = 0;
    cart->count = 0;
    for( i = 0; i < count && (intnode[i].stump != NULL); i++ )
    {
        cart->count++;
        cart->compidx[i] = intnode[i].stump->compidx;	// haar     
        cart->threshold[i] = intnode[i].stump->threshold;

        //                
        if( cart->left[i] <= 0 )
        {
            cart->left[i] = -j;
            cart->val[j] = intnode[i].stump->left;      //   left float ,  CVMat*
            j++;
        }
        if( cart->right[i] <= 0 )
        {
            cart->right[i] = -j;
            cart->val[j] = intnode[i].stump->right;
            j++;
        }
    }

    //     
    for( i = 0; i < count && (intnode[i].stump != NULL); i++ )
    {
        intnode[i].stump->release( (CvClassifier**) &(intnode[i].stump) );
        if( i != 0 )
        {
            cvReleaseMat( &(intnode[i].sampleIdx) );
        }
    }
    for( i = 0; i < listcount; i++ )
    {
        list[i].stump->release( (CvClassifier**) &(list[i].stump) );
        cvReleaseMat( &(list[i].sampleIdx) );
    }

    cvFree( &intnode );

    return (CvClassifier*) cart;
}

이 프로그램의 일부 세부 사항을 제가 정확하게 이해하지 못했을 수도 있습니다. 예를 들어 좌우 지점의 error가 동시에 0이 되지 않을 때 제 설명은 프로그램이 오른쪽 지점의 우선순위를 좀 더 높게 설정한 것입니다. 바로 틀릴 수 있는 부분입니다. 그리고 어린이 신발들과 함께 연구하고 싶습니다. 감사합니다!
후기: 위의 그 결론은 확실히 틀렸다. 좌우분지의 우선순위는 같다.이 함수는 나중에 다시 한 번 봤는데 코드에 대한 새로운 인식이 생긴 것 같아서 제가 전에 범한 오류를 다시 수정하고 여러분께 계속 공유해 드리겠습니다. 여러분들이 계속 저의 결점을 지적해 주셔서 감사합니다. 감사하기 그지없습니다!

좋은 웹페이지 즐겨찾기