Shark 소스 코드 분석 (11): 랜 덤 숲 알고리즘

Shark 소스 코드 분석 (11): 랜 덤 숲 알고리즘
이 알고리즘 에 대한 소 개 는 이전에 통합 알고리즘 에 관 한 블 로 그 를 참조 할 수 있 습 니 다.Shark 에 서 는 의사 결정 트 리 알고리즘 에 대해 CART 알고리즘 만 실 현 했 기 때문에 랜 덤 숲 알고리즘 에 도 CART 알고리즘 만 포함 되 어 있 습 니 다.만약 당신 이 내 가 이전에 쓴 CART 알고리즘 소스 코드 분석 에 관 한 블 로 그 를 보 았 다 면, 뒤에 보면 랜 덤 숲 알고리즘 과 의 코드 가 많 지 않다 는 것 을 알 수 있 을 것 이다.다만 가장 좋 은 속성 을 선택 할 때 무 작위 로 후보 집 을 선택 하 는 과정 이 하나 더 생 겼 다.이것 도 랜 덤 숲 알고리즘 의 큰 특징 이다.CART 알고리즘 은 분류 작업 에 도 사용 할 수 있 고 복귀 작업 에 도 사용 할 수 있 기 때문에 CART 알고리즘 을 기반 으로 한 랜 덤 숲 도 이 두 작업 에 사용 할 수 있다.분류 작업 에 사용 할 코드 만 소개 합 니 다.
MeanModel 클래스
이 종 류 는 통합 알고리즘 의 기본 클래스 라 고 할 수 있 으 며, 여러 개의 기본 학습 기의 출력 결 과 를 어떻게 종합 하 는 지 를 나타 낸다.이 종류의 정 의 는 에 있다.
template<class ModelType> // ModelType         
class MeanModel : public AbstractModel<typename ModelType::InputType, typename ModelType::OutputType>
{
private:
    typedef AbstractModel<typename ModelType::InputType, typename ModelType::OutputType> base_type;
public:

    MeanModel():m_weightSum(0){}

    std::string name() const
    { return "MeanModel"; }

    using base_type::eval;
    //           ,           ,             
    void eval(typename base_type::BatchInputType const& patterns, typename base_type::BatchOutputType& outputs)const{
        m_models[0].eval(patterns,outputs);
        outputs *=m_weight[0];
        for(std::size_t i = 1; i != m_models.size(); i++) 
            noalias(outputs) += m_weight[i] * m_models[i](patterns);
        outputs /= m_weightSum;
    }

    void eval(typename base_type::BatchInputType const& patterns, typename base_type::BatchOutputType& outputs, State& state)const{
        eval(patterns,outputs);
    }

    RealVector parameterVector() const {
        return RealVector();
    }

    void setParameterVector(const RealVector& param) {
        SHARK_ASSERT(param.size() == 0);
    }

    void read(InArchive& archive){
        archive >> m_models;
        archive >> m_weight;
        archive >> m_weightSum;
    }

    void write(OutArchive& archive)const{
        archive << m_models;
        archive << m_weight;
        archive << m_weightSum;
    }

    void clearModels(){
        m_models.clear();
        m_weight.clear();
        m_weightSum = 0.0;
    }

    //         
    void addModel(ModelType const& model, double weight = 1.0){
        SHARK_CHECK(weight > 0, "Weights must be positive");
        m_models.push_back(model);
        m_weight.push_back(weight);
        m_weightSum+=weight;
    }

    double const& weight(std::size_t i)const{
        return m_weight[i];
    }

    void setWeight(std::size_t i, double newWeight){
        m_weightSum=newWeight - m_weight[i];
        m_weight[i] = newWeight;
    }

    std::size_t numberOfModels()const{
        return m_models.size();
    }

protected:
    //          ,             ,         ,         
    std::vector m_models;

    //            
    std::vector<double> m_weight;

    //       
    double m_weightSum;
};

RFclassifier 클래스
이 종 류 는 무 작위 숲 을 나타 내 는 것 으로 에 정의 된다.
class RFClassifier : public MeanModel >
{
public:
    std::string name() const
    { return "RFClassifier"; }

    //        OOB  ,      OOB      ,          
    void computeOOBerror(){
        std::size_t n_trees = numberOfModels();
        m_OOBerror = 0;
        for(std::size_t j=0;j!=n_trees;++j){
            m_OOBerror += m_models[j].OOBerror();
        }
        m_OOBerror /= n_trees;
    }

    //               ,               
    void computeFeatureImportances(){
        m_featureImportances.resize(m_inputDimension);
        std::size_t n_trees = numberOfModels();

        for(std::size_t i=0;i!=m_inputDimension;++i){
            m_featureImportances[i] = 0;
            for(std::size_t j=0;j!=n_trees;++j){
                m_featureImportances[i] += m_models[j].featureImportances()[i];
            }
            m_featureImportances[i] /= n_trees;
        }
    }

    double const OOBerror() const {
        return m_OOBerror;
    }

    RealVector const& featureImportances() const {
        return m_featureImportances;
    }

    //            ,                     
    UIntVector countAttributes() const {
        std::size_t n = m_models.size();
        if(!n) return UIntVector();
        UIntVector r = m_models[0].countAttributes();
        for(std::size_t i=1; i< n; i++ ) {
            noalias(r) += m_models[i].countAttributes();
        }
        return r;
    }

    void setLabelDimension(std::size_t in){
        m_labelDimension = in;
    }

    void setInputDimension(std::size_t in){
        m_inputDimension = in;
    }

    typedef CARTClassifier::TreeType TreeType;
    typedef std::vector ForestInfo; //      

    //            
    ForestInfo getForestInfo() const {
        ForestInfo finfo(m_models.size());
        for (std::size_t i=0; ireturn finfo;
    }

    //  finfo   m_model,weights  m_weight
    void setForestInfo(ForestInfo const& finfo, std::vector<double> const& weights = std::vector<double>()) {
        std::size_t n_tree = finfo.size();
        std::vector<double> we(weights);
        m_models.resize(n_tree);
        if (weights.empty()) // set default weights to 1
            we.resize(n_tree, 1);
        else if (weights.size() != n_tree)
            throw SHARKEXCEPTION("Weights must be the same number as trees");

        for (std::size_t i=0; iprotected:
    //             ,             
    std::size_t m_labelDimension;

    std::size_t m_inputDimension;

    // out-of-bag  
    double m_OOBerror;

    //          
    RealVector m_featureImportances;

};

RFtrainer 클래스
랜 덤 숲 을 만 들 때 는 보통 100 그루 이상 의 의사 결정 트 리 를 구축한다.각 결정 트 리 의 수출 결 과 를 결합 할 때 분류 문제 에 대해 다수 투표 법 을 사용 하고 복귀 임무 에 대해 평균 법 을 사용한다.
의사 결정 트 리 의 규모 가 정지 조건 으로 늘 어 나 면 의사 결정 트 리 의 구축 이 멈춘다.이곳 에 서 는 결정 트 리 에 대해 가 지 를 자 를 필요 가 없다.Bagging 통합 방법 은 편차 에 대한 통제 가 잘 되 기 때문이다.
이러한 정 의 는 에 있 고 에 실현 된다.이 클래스 에는 분류 와 회귀 에 대한 코드 가 포함 되 어 있 으 며, 분류 에 관 한 코드 만 소개 합 니 다.
class RFTrainer 
: public AbstractTrainerunsigned int>
, public AbstractTrainer,
  public IParameterizable
{

public:
    SHARK_EXPORT_SYMBOL RFTrainer(bool computeFeatureImportances = false, bool computeOOBerror = false){
        m_try = 0;
        m_B = 0;
        m_nodeSize = 0;
        m_OOBratio = 0;
        m_regressionLearner = false;
        m_computeFeatureImportances = computeFeatureImportances;
        m_computeOOBerror = computeOOBerror;
    }

    std::string name() const
    { return "RFTrainer"; }

    //              
    SHARK_EXPORT_SYMBOL void train(RFClassifier& model, ClassificationDataset const& dataset){
        model.clearModels();

        m_inputDimension = inputDimension(dataset);

        model.setInputDimension(m_inputDimension);
        model.setLabelDimension(numberOfClasses(dataset));

        m_maxLabel = static_cast<unsigned int>(numberOfClasses(dataset))-1;

        m_regressionLearner = false;
        setDefaults();

        std::size_t subsetSize = static_cast<std::size_t>(dataset.numberOfElements()*m_OOBratio);
        DataViewconst> elements(dataset);

        //Generate m_B trees
        SHARK_PARALLEL_FOR(int i = 0; i < (int)m_B; ++i){
            //             ,              ,        
            std::vector<std::size_t> subsetIndices(dataset.numberOfElements());
            boost::iota(subsetIndices,0);
            boost::random_shuffle(subsetIndices);

            // create oob indices
            std::vector<std::size_t>::iterator oobStart = subsetIndices.begin() + subsetSize;
            std::vector<std::size_t>::iterator oobEnd   = subsetIndices.end();

            //               
            subsetIndices.erase(oobStart, oobEnd);
            ClassificationDataset dataTrain = toDataset(subset(elements,subsetIndices));

            //Create attribute tables
            boost::unordered_map<std::size_t, std::size_t> cAbove;
            AttributeTables tables;
            createAttributeTables(dataTrain.inputs(), tables);
            createCountMatrix(dataTrain, cAbove);

            CARTClassifier::TreeType tree = buildTree(tables, dataTrain, cAbove, 0);
            CARTClassifier cart(tree, m_inputDimension); //       

            // if oob error or importances have to be computed, create an oob sample
            if(m_computeOOBerror || m_computeFeatureImportances){
                std::vector<std::size_t> subsetIndicesOOB(oobStart, oobEnd);
                ClassificationDataset dataOOB = toDataset(subset(elements, subsetIndicesOOB));

                //            ,             ,                    
                if(m_computeFeatureImportances){
                    cart.computeFeatureImportances(dataOOB);
                }
                else{
                    cart.computeOOBerror(dataOOB);
                }
            }

            SHARK_CRITICAL_REGION{
                model.addModel(cart); //               
            }
        }

        if(m_computeOOBerror){
            model.computeOOBerror();
        }

        if(m_computeFeatureImportances){
            model.computeFeatureImportances();
        }
    }

    /// Train a random forest for regression.
    SHARK_EXPORT_SYMBOL void train(RFClassifier& model, RegressionDataset const& dataset);

    /// Set the number of random attributes to investigate at each node.
    SHARK_EXPORT_SYMBOL void setMTry(std::size_t mtry);

    /// Set the number of trees to grow.
    SHARK_EXPORT_SYMBOL void setNTrees(std::size_t nTrees);

    /// Controls when a node is considered pure. If set to 1, a node is pure
    /// when it only consists of a single node.
    SHARK_EXPORT_SYMBOL void setNodeSize(std::size_t nTrees);

    /// Set the fraction of the original training dataset to use as the
    /// out of bag sample. The default value is 0.66.
    SHARK_EXPORT_SYMBOL void setOOBratio(double ratio);

    /// Return the parameter vector.
    RealVector parameterVector() const
    {
        RealVector ret(1); // number of trees
        init(ret) << (double)m_B;
        return ret;
    }

    /// Set the parameter vector.
    void setParameterVector(RealVector const& newParameters)
    {
        SHARK_ASSERT(newParameters.size() == numberOfParameters());
        setNTrees((size_t) newParameters[0]);
    }

protected:
    struct RFAttribute {
        double value; //            
        std::size_t id; //             
    };

    //                     
    typedef std::vector < RFAttribute > AttributeTable;
    //                     
    typedef std::vector < AttributeTable > AttributeTables;

    //                       
    SHARK_EXPORT_SYMBOL void createAttributeTables(Data const& dataset, AttributeTables& tables){
        std::size_t elements = dataset.numberOfElements();
        //Each entry in the outer vector is an attribute table
        AttributeTable table;
        RFAttribute a;
        for(std::size_t j=0; jfor(std::size_t i=0; i//Store Attribute value, class and rid
                a.value = dataset.element(i)[j];
                a.id = i;
                table.push_back(a);
            }
            std::sort(table.begin(), table.end(), tableSort); // vector   ,              
            tables.push_back(table);
        }
    }

    //              ,   CAbove 
    SHARK_EXPORT_SYMBOL void createCountMatrix(ClassificationDataset const& dataset, boost::unordered_map<std::size_t, std::size_t>& cAbove){
        std::size_t elements = dataset.numberOfElements();
        for(std::size_t i = 0 ; i < elements; i++){
            cAbove[dataset.element(i).label]++;
        }
    }

    // Split attribute tables into left and right parts.
    SHARK_EXPORT_SYMBOL void splitAttributeTables(AttributeTables const& tables, std::size_t index, std::size_t valIndex, AttributeTables& LAttributeTables, AttributeTables& RAttributeTables){
        AttributeTable table;

        //               ,           
        boost::unordered_map<std::size_t, bool> hash; //                             
        for(std::size_t i = 0; i< tables[index].size(); i++){
            hash[tables[index][i].id] = (i<=valIndex);
        }

        for(std::size_t i = 0; i < tables.size(); i++){
            LAttributeTables.push_back(table);
            RAttributeTables.push_back(table);
            for(std::size_t j = 0; j < tables[i].size(); j++){
                if(hash[tables[i][j].id]){
                    //Left
                    LAttributeTables[i].push_back(tables[i][j]);
                }else{
                    //Right
                    RAttributeTables[i].push_back(tables[i][j]);
                }
            }
        }
    }

    //        
    SHARK_EXPORT_SYMBOL CARTClassifier::TreeType buildTree(AttributeTables& tables, ClassificationDataset const& dataset, boost::unordered_map<std::size_t, std::size_t>& cAbove, std::size_t nodeId){
        CARTClassifier::TreeType lTree, rTree; //              

        CARTClassifier::NodeInfo nodeInfo; //            

        nodeInfo.nodeId = nodeId;
        nodeInfo.attributeIndex = 0;
        nodeInfo.attributeValue = 0.0;
        nodeInfo.leftNodeId = 0;
        nodeInfo.rightNodeId = 0;
        nodeInfo.misclassProp = 0.0;
        nodeInfo.r = 0;
        nodeInfo.g = 0.0;

        std::size_t n = tables[0].size(); //                 

        bool isLeaf = false;
        if(gini(cAbove,tables[0].size())==0 || n <= m_nodeSize){
            isLeaf = true;
        }else{
            boost::unordered_map<std::size_t, std::size_t> cBelow, cBestBelow, cBestAbove;

            set<std::size_t> tableIndicies; //                  
            generateRandomTableIndicies(tableIndicies);

            std::size_t bestAttributeIndex, bestAttributeValIndex;

            double bestAttributeVal;
            double bestImpurity = n+1.0;

            //         
            for (set<std::size_t>::iterator it=tableIndicies.begin() ; it != tableIndicies.end(); it++ ){
                std::size_t attributeIndex = *it;
                boost::unordered_map<std::size_t, std::size_t> cTmpAbove = cAbove;
                cBelow.clear();

                //       ,       
                for(std::size_t i=1; istd::size_t prev = i-1;

                    //Update the count of the label
                    cBelow[dataset.element(tables[attributeIndex][prev].id).label]++;
                    cTmpAbove[dataset.element(tables[attributeIndex][prev].id).label]--;

                    //                            
                    if(tables[attributeIndex][prev].value!=tables[attributeIndex][i].value){
                        std::size_t n1 = i;
                        std::size_t n2 = n-n1;

                        //Calculate the Gini impurity of the split
                        double impurity = n1*gini(cBelow,n1)+n2*gini(cTmpAbove,n2);
                        if(impurity//Found a more pure split, store the attribute index and value
                            bestImpurity = impurity;
                            bestAttributeIndex = attributeIndex;
                            bestAttributeValIndex = prev;
                            bestAttributeVal = tables[attributeIndex][bestAttributeValIndex].value;
                            cBestAbove = cTmpAbove;
                            cBestBelow = cBelow;
                        }
                    }
                }
            }

            //                  ,       
            if(bestImpurity1){
                AttributeTables rTables, lTables;
                splitAttributeTables(tables, bestAttributeIndex, bestAttributeValIndex, lTables, rTables);
                tables.clear();

                nodeInfo.attributeIndex = bestAttributeIndex;
                nodeInfo.attributeValue = bestAttributeVal;
                nodeInfo.leftNodeId = 2*nodeId+1; //               
                nodeInfo.rightNodeId = 2*nodeId+2;

                lTree = buildTree(lTables, dataset, cBestBelow, nodeInfo.leftNodeId);
                rTree = buildTree(rTables, dataset, cBestAbove, nodeInfo.rightNodeId);
            }else{
                //         ,        
                isLeaf = true;
            }
        }

        CARTClassifier::TreeType tree;

        if(isLeaf){
            //        ,       ,      
            nodeInfo.label = hist(cAbove);
            tree.push_back(nodeInfo);
            return tree;
        }

        //        ,            ,              
        tree.push_back(nodeInfo);
        tree.insert(tree.end(), lTree.begin(), lTree.end());
        tree.insert(tree.end(), rTree.begin(), rTree.end());

        return tree;
    }

    /// Builds a decision tree for regression
    SHARK_EXPORT_SYMBOL CARTClassifier::TreeType buildTree(AttributeTables& tables, RegressionDataset const& dataset, std::vector const& labels, std::size_t nodeId);

    /// comparison function for sorting an attributeTable
    SHARK_EXPORT_SYMBOL static bool tableSort(RFAttribute const& v1, RFAttribute const& v2);

    //                   
    SHARK_EXPORT_SYMBOL RealVector hist(boost::unordered_map<std::size_t, std::size_t> countMatrix){

        RealVector histogram(m_maxLabel+1,0.0);

        std::size_t totalElements = 0;

        boost::unordered_map<std::size_t, std::size_t>::iterator it;
        for ( it=countMatrix.begin() ; it != countMatrix.end(); it++ ){
            histogram(it->first) = (double)it->second;
            totalElements += it->second;
        }
        histogram /= totalElements;

        return histogram;
    }

    /// Average label over a vector.
    SHARK_EXPORT_SYMBOL RealVector average(std::vector const& labels);

    //        gini  
    SHARK_EXPORT_SYMBOL double gini(boost::unordered_map<std::size_t, std::size_t> & countMatrix, std::size_t n){
        double res = 0;
        boost::unordered_map<std::size_t, std::size_t>::iterator it;
        if(n){
            n = n*n;
            for ( it=countMatrix.begin() ; it != countMatrix.end(); it++ ){
                res += sqr(it->second)/(double)n;
            }
        }
        return 1-res;
    }

    /// Total Sum Of Squares
    SHARK_EXPORT_SYMBOL double totalSumOfSquares(std::vector& labels, std::size_t from, std::size_t to, RealVector const& sumLabel);

    //              ,   tableIndicies 
    SHARK_EXPORT_SYMBOL void generateRandomTableIndicies(std::set<std::size_t>& tableIndicies){
        while(tableIndicies.size()//               
    SHARK_EXPORT_SYMBOL void setDefaults(){
        if(!m_try){
            if(m_regressionLearner){
                setMTry(static_cast<std::size_t>(std::ceil(m_inputDimension/3.0)));
            }else{
                //                                  
                setMTry(static_cast<std::size_t>(std::ceil(std::sqrt((double)m_inputDimension))));
            }
        }

        if(!m_B){
            //             100 
            setNTrees(100);
        }

        if(!m_nodeSize){
            if(m_regressionLearner){
                setNodeSize(5);
            }else{
                //                 1,             
                setNodeSize(1);
            }
        }

        if(m_OOBratio <= 0 || m_OOBratio>1){
            //           ,             0.66,   bootstarp  ,      bootstrap     
            setOOBratio(0.66);
        }
    }

    std::size_t m_inputDimension;

    std::size_t m_labelDimension;

    //        
    unsigned int m_maxLabel;

    //            ,      
    std::size_t m_try;

    //      ,      
    std::size_t m_B;

    //                 
    std::size_t m_nodeSize;

    //          ,            
    double m_OOBratio;

    //             
    bool m_regressionLearner;

    //     ,      m_FeatureImportances
    bool m_computeFeatureImportances;

    //     ,      m_OOBerror
    bool m_computeOOBerror;
};

하나의 예
#include  //importing the file
#include  //the random forest trainer
#include  //zero one loss for evaluation

#include  

using namespace std; 
using namespace shark;

int main() {

    //*****************LOAD AND PREPARE DATA***********************//
    //Read Sample data set C.csv

    ClassificationDataset data;
    importCSV(data, "data/C.csv", LAST_COLUMN, ' ');

    //Split the dataset into a training and a test dataset
    ClassificationDataset dataTest = splitAtElement(data,311);

    cout << "Training set - number of data points: " << data.numberOfElements()
        << " number of classes: " << numberOfClasses(data)
        << " input dimension: " << inputDimension(data) << endl;

    cout << "Test set - number of data points: " << dataTest.numberOfElements()
        << " number of classes: " << numberOfClasses(dataTest)
        << " input dimension: " << inputDimension(dataTest) << endl;

    //Generate a random forest
    RFTrainer trainer;
    RFClassifier model;
    trainer.train(model, data);

    // evaluate Random Forest classifier
    ZeroOneLoss<unsigned int, RealVector> loss;
    Data prediction = model(data.inputs());
    cout << "Random Forest on training set accuracy: " << 1. - loss.eval(data.labels(), prediction) << endl;

    prediction = model(dataTest.inputs());
    cout << "Random Forest on test set accuracy:     " << 1. - loss.eval(dataTest.labels(), prediction) << endl;
}

좋은 웹페이지 즐겨찾기