37 #ifndef VIGRA_RANDOM_FOREST_HXX
38 #define VIGRA_RANDOM_FOREST_HXX
46 #include "mathutil.hxx"
47 #include "array_vector.hxx"
48 #include "sized_int.hxx"
51 #include "functorexpression.hxx"
52 #include "random_forest/rf_common.hxx"
53 #include "random_forest/rf_nodeproxy.hxx"
54 #include "random_forest/rf_split.hxx"
55 #include "random_forest/rf_decisionTree.hxx"
56 #include "random_forest/rf_visitors.hxx"
57 #include "random_forest/rf_region.hxx"
58 #include "sampling.hxx"
59 #include "random_forest/rf_preprocessing.hxx"
60 #include "random_forest/rf_online_prediction_set.hxx"
61 #include "random_forest/rf_earlystopping.hxx"
62 #include "random_forest/rf_ridge_split.hxx"
82 inline SamplerOptions make_sampler_opt ( RandomForestOptions & RF_opt)
84 SamplerOptions return_opt;
86 return_opt.
stratified(RF_opt.stratification_method_ == RF_EQUAL);
142 template <
class LabelType =
double ,
class PreprocessorTag = ClassificationTag >
149 typedef detail::DecisionTree DecisionTree_t;
156 typedef LabelType LabelT;
231 template<
class TopologyIterator,
class ParameterIterator>
233 TopologyIterator topology_begin,
234 ParameterIterator parameter_begin,
238 trees_(treeCount, DecisionTree_t(problem_spec)),
239 ext_param_(problem_spec),
242 for(
unsigned int k=0; k<treeCount; ++k, ++topology_begin, ++parameter_begin)
244 trees_[k].topology_ = *topology_begin;
245 trees_[k].parameters_ = *parameter_begin;
264 vigra_precondition(ext_param_.used() ==
true,
265 "RandomForest::ext_param(): "
266 "Random forest has not been trained yet.");
282 vigra_precondition(ext_param_.used() ==
false,
283 "RandomForest::set_ext_param():"
284 "Random forest has been trained! Call reset()"
285 "before specifying new extrinsic parameters.");
309 DecisionTree_t
const &
tree(
int index)
const
311 return trees_[index];
316 DecisionTree_t &
tree(
int index)
318 return trees_[index];
328 return ext_param_.column_count_;
339 return ext_param_.column_count_;
347 return ext_param_.class_count_;
354 return options_.tree_count_;
359 template<
class U,
class C1,
372 bool adjust_thresholds=
false);
374 template <
class U,
class C1,
class U2,
class C2>
379 onlineLearn(features,
389 template<
class U,
class C1,
395 void reLearnTree(MultiArrayView<2,U,C1>
const & features,
396 MultiArrayView<2,U2,C2>
const & response,
403 template<
class U,
class C1,
class U2,
class C2>
404 void reLearnTree(MultiArrayView<2, U, C1>
const & features,
405 MultiArrayView<2, U2, C2>
const & labels,
408 RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
453 template <
class U,
class C1,
459 void learn( MultiArrayView<2, U, C1>
const & features,
460 MultiArrayView<2, U2,C2>
const & response,
464 Random_t
const & random);
466 template <
class U,
class C1,
471 void learn( MultiArrayView<2, U, C1>
const & features,
472 MultiArrayView<2, U2,C2>
const & response,
478 RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
487 template <
class U,
class C1,
class U2,
class C2,
class Visitor_t>
488 void learn( MultiArrayView<2, U, C1>
const & features,
489 MultiArrayView<2, U2,C2>
const & labels,
499 template <
class U,
class C1,
class U2,
class C2,
500 class Visitor_t,
class Split_t>
501 void learn( MultiArrayView<2, U, C1>
const & features,
502 MultiArrayView<2, U2,C2>
const & labels,
531 template <
class U,
class C1,
class U2,
class C2>
559 template <
class U,
class C,
class Stop>
562 template <
class U,
class C>
573 template <
class U,
class C>
574 LabelType
predictLabel(MultiArrayView<2, U, C>
const & features,
575 ArrayVectorView<double> prior)
const;
585 template <
class U,
class C1,
class T,
class C2>
589 vigra_precondition(features.
shape(0) == labels.
shape(0),
590 "RandomForest::predictLabels(): Label array has wrong size.");
591 for(
int k=0; k<features.
shape(0); ++k)
595 template <
class U,
class C1,
class T,
class C2,
class Stop>
600 vigra_precondition(features.
shape(0) == labels.
shape(0),
601 "RandomForest::predictLabels(): Label array has wrong size.");
602 for(
int k=0; k<features.
shape(0); ++k)
613 template <
class U,
class C1,
class T,
class C2,
class Stop>
615 MultiArrayView<2, T, C2> & prob,
617 template <
class T1,
class T2,
class C>
619 MultiArrayView<2, T2, C> & prob);
627 template <
class U,
class C1,
class T,
class C2>
634 template <
class U,
class C1,
class T,
class C2>
644 template <
class LabelType,
class PreprocessorTag>
645 template<
class U,
class C1,
651 void RandomForest<LabelType, PreprocessorTag>::onlineLearn(MultiArrayView<2,U,C1>
const & features,
652 MultiArrayView<2,U2,C2>
const & response,
658 bool adjust_thresholds)
660 online_visitor_.activate();
661 online_visitor_.adjust_thresholds=adjust_thresholds;
665 typedef Processor<PreprocessorTag,LabelType,U,C1,U2,C2> Preprocessor_t;
666 typedef UniformIntRandomFunctor<Random_t>
673 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
674 Default_Stop_t default_stop(options_);
675 typename RF_CHOOSER(Stop_t)::type stop
676 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
677 Default_Split_t default_split;
678 typename RF_CHOOSER(Split_t)::type split
679 = RF_CHOOSER(Split_t)::choose(split_, default_split);
680 rf::visitors::StopVisiting stopvisiting;
681 typedef rf::visitors::detail::VisitorNode
682 <rf::visitors::OnlineLearnVisitor,
683 typename RF_CHOOSER(Visitor_t)::type>
686 visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
693 ext_param_.class_count_=0;
694 Preprocessor_t preprocessor( features, response,
695 options_, ext_param_);
698 RandFunctor_t randint ( random);
701 split.set_external_parameters(ext_param_);
702 stop.set_external_parameters(ext_param_);
706 PoissonSampler<RandomTT800> poisson_sampler(1.0,
vigra::Int32(new_start_index),
vigra::Int32(ext_param().row_count_));
712 for(
int ii = 0; ii < (int)trees_.size(); ++ii)
714 online_visitor_.tree_id=ii;
715 poisson_sampler.sample();
716 std::map<int,int> leaf_parents;
717 leaf_parents.clear();
719 for(
int s=0;s<poisson_sampler.numOfSamples();++s)
721 int sample=poisson_sampler[s];
722 online_visitor_.current_label=preprocessor.response()(sample,0);
723 online_visitor_.last_node_id=StackEntry_t::DecisionTreeNoParent;
724 int leaf=trees_[ii].getToLeaf(
rowVector(features,sample),online_visitor_);
728 online_visitor_.add_to_index_list(ii,leaf,sample);
731 if(Node<e_ConstProbNode>(trees_[ii].topology_,trees_[ii].parameters_,leaf).prob_begin()[preprocessor.response()(sample,0)]!=1.0)
733 leaf_parents[leaf]=online_visitor_.last_node_id;
738 std::map<int,int>::iterator leaf_iterator;
739 for(leaf_iterator=leaf_parents.begin();leaf_iterator!=leaf_parents.end();++leaf_iterator)
741 int leaf=leaf_iterator->first;
742 int parent=leaf_iterator->second;
743 int lin_index=online_visitor_.trees_online_information[ii].exterior_to_index[leaf];
744 ArrayVector<Int32> indeces;
746 indeces.swap(online_visitor_.trees_online_information[ii].index_lists[lin_index]);
747 StackEntry_t stack_entry(indeces.begin(),
749 ext_param_.class_count_);
754 if(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(0)==leaf)
756 stack_entry.leftParent=parent;
760 vigra_assert(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(1)==leaf,
"last_node_id seems to be wrong");
761 stack_entry.rightParent=parent;
765 trees_[ii].continueLearn(preprocessor.features(),preprocessor.response(),stack_entry,split,stop,visitor,randint,-1);
767 online_visitor_.move_exterior_node(ii,trees_[ii].topology_.size(),ii,leaf);
780 online_visitor_.deactivate();
783 template<
class LabelType,
class PreprocessorTag>
784 template<
class U,
class C1,
805 ext_param_.class_count_=0;
813 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
815 typename RF_CHOOSER(Stop_t)::type stop
816 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
818 typename RF_CHOOSER(Split_t)::type split
819 = RF_CHOOSER(Split_t)::choose(split_, default_split);
823 typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
825 visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
827 vigra_precondition(options_.prepare_online_learning_,
"reLearnTree: Re learning trees only makes sense, if online learning is enabled");
828 online_visitor_.activate();
831 RandFunctor_t randint ( random);
837 Preprocessor_t preprocessor( features, response,
838 options_, ext_param_);
841 split.set_external_parameters(ext_param_);
842 stop.set_external_parameters(ext_param_);
849 preprocessor.strata().end(),
850 detail::make_sampler_opt(options_)
851 .sampleSize(ext_param().actual_msample_),
858 first_stack_entry( sampler.sampledIndices().begin(),
859 sampler.sampledIndices().end(),
860 ext_param_.class_count_);
862 .set_oob_range( sampler.oobIndices().begin(),
863 sampler.oobIndices().end());
864 online_visitor_.reset_tree(treeId);
865 online_visitor_.tree_id=treeId;
866 trees_[treeId].reset();
868 .learn( preprocessor.features(),
869 preprocessor.response(),
876 .visit_after_tree( *
this,
882 online_visitor_.deactivate();
885 template <
class LabelType,
class PreprocessorTag>
886 template <
class U,
class C1,
898 Random_t
const & random)
909 vigra_precondition(features.
shape(0) == response.
shape(0),
910 "RandomForest::learn(): shape mismatch between features and response.");
917 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
919 typename RF_CHOOSER(Stop_t)::type stop
920 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
922 typename RF_CHOOSER(Split_t)::type split
923 = RF_CHOOSER(Split_t)::choose(split_, default_split);
927 typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
929 visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
931 if(options_.prepare_online_learning_)
932 online_visitor_.activate();
934 online_visitor_.deactivate();
938 RandFunctor_t randint ( random);
945 Preprocessor_t preprocessor( features, response,
946 options_, ext_param_);
949 split.set_external_parameters(ext_param_);
950 stop.set_external_parameters(ext_param_);
954 trees_.resize(options_.tree_count_ , DecisionTree_t(ext_param_));
957 preprocessor.strata().end(),
958 detail::make_sampler_opt(options_)
959 .sampleSize(ext_param().actual_msample_),
962 visitor.visit_at_beginning(*
this, preprocessor);
965 for(
int ii = 0; ii < (int)trees_.size(); ++ii)
971 first_stack_entry( sampler.sampledIndices().begin(),
972 sampler.sampledIndices().end(),
973 ext_param_.class_count_);
975 .set_oob_range( sampler.oobIndices().begin(),
976 sampler.oobIndices().end());
978 .learn( preprocessor.features(),
979 preprocessor.response(),
986 .visit_after_tree( *
this,
993 visitor.visit_at_end(*
this, preprocessor);
995 online_visitor_.deactivate();
1001 template <
class LabelType,
class Tag>
1002 template <
class U,
class C,
class Stop>
1006 vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
1007 "RandomForestn::predictLabel():"
1008 " Too few columns in feature matrix.");
1009 vigra_precondition(
rowCount(features) == 1,
1010 "RandomForestn::predictLabel():"
1011 " Feature matrix must have a singlerow.");
1013 garbage_prediction_.reshape(Shp(1, ext_param_.class_count_), 0.0);
1015 predictProbabilities(features, garbage_prediction_, stop);
1016 ext_param_.to_classlabel(
argMax(garbage_prediction_), d);
1022 template <
class LabelType,
class PreprocessorTag>
1023 template <
class U,
class C>
1028 using namespace functor;
1029 vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
1030 "RandomForestn::predictLabel(): Too few columns in feature matrix.");
1031 vigra_precondition(
rowCount(features) == 1,
1032 "RandomForestn::predictLabel():"
1033 " Feature matrix must have a single row.");
1034 Matrix<double> prob(1,ext_param_.class_count_);
1035 predictProbabilities(features, prob);
1036 std::transform( prob.begin(), prob.end(),
1037 priors.
begin(), prob.begin(),
1040 ext_param_.to_classlabel(
argMax(prob), d);
1044 template<
class LabelType,
class PreprocessorTag>
1045 template <
class T1,
class T2,
class C>
1054 "RandomFroest::predictProbabilities():"
1055 " Feature matrix and probability matrix size mismatch.");
1058 vigra_precondition(
columnCount(predictionSet.features) >= ext_param_.column_count_,
1059 "RandomForestn::predictProbabilities():"
1060 " Too few columns in feature matrix.");
1063 "RandomForestn::predictProbabilities():"
1064 " Probability matrix must have as many columns as there are classes.");
1067 std::vector<T1> totalWeights(predictionSet.indices[0].size(),0.0);
1070 for(
int k=0; k<options_.tree_count_; ++k)
1072 set_id=(set_id+1) % predictionSet.indices[0].size();
1073 typedef std::set<SampleRange<T1> > my_set;
1074 typedef typename my_set::iterator set_it;
1077 std::vector<std::pair<int,set_it> > stack;
1079 for(set_it i=predictionSet.ranges[set_id].begin();
1080 i!=predictionSet.ranges[set_id].end();++i)
1081 stack.push_back(std::pair<int,set_it>(2,i));
1083 int num_decisions=0;
1084 while(!stack.empty())
1086 set_it range=stack.back().second;
1087 int index=stack.back().first;
1091 if(trees_[k].isLeafNode(trees_[k].topology_[index]))
1094 trees_[k].parameters_,
1095 index).prob_begin();
1096 for(
int i=range->start;i!=range->end;++i)
1099 for(
int l=0; l<ext_param_.class_count_; ++l)
1101 prob(predictionSet.indices[set_id][i], l) += (T2)weights[l];
1103 totalWeights[predictionSet.indices[set_id][i]] += (T1)weights[l];
1110 if(trees_[k].topology_[index]!=i_ThresholdNode)
1112 throw std::runtime_error(
"predicting with online prediction sets is only supported for RFs with threshold nodes");
1114 Node<i_ThresholdNode> node(trees_[k].topology_,trees_[k].parameters_,index);
1115 if(range->min_boundaries[node.column()]>=node.threshold())
1118 stack.push_back(std::pair<int,set_it>(node.child(1),range));
1121 if(range->max_boundaries[node.column()]<node.threshold())
1124 stack.push_back(std::pair<int,set_it>(node.child(0),range));
1128 SampleRange<T1> new_range=*range;
1129 new_range.min_boundaries[node.column()]=FLT_MAX;
1130 range->max_boundaries[node.column()]=-FLT_MAX;
1131 new_range.start=new_range.end=range->end;
1133 while(i!=range->end)
1136 if(predictionSet.features(predictionSet.indices[set_id][i],node.column())>=node.threshold())
1138 new_range.min_boundaries[node.column()]=std::min(new_range.min_boundaries[node.column()],
1139 predictionSet.features(predictionSet.indices[set_id][i],node.column()));
1142 std::swap(predictionSet.indices[set_id][i],predictionSet.indices[set_id][range->end]);
1147 range->max_boundaries[node.column()]=std::max(range->max_boundaries[node.column()],
1148 predictionSet.features(predictionSet.indices[set_id][i],node.column()));
1153 if(range->start==range->end)
1155 predictionSet.ranges[set_id].erase(range);
1159 stack.push_back(std::pair<int,set_it>(node.child(0),range));
1162 if(new_range.start!=new_range.end)
1164 std::pair<set_it,bool> new_it=predictionSet.ranges[set_id].insert(new_range);
1165 stack.push_back(std::pair<int,set_it>(node.child(1),new_it.first));
1169 predictionSet.cumulativePredTime[k]=num_decisions;
1171 for(
unsigned int i=0;i<totalWeights.size();++i)
1175 for(
int l=0; l<ext_param_.class_count_; ++l)
1178 prob(i, l) /= totalWeights[i];
1180 assert(test==totalWeights[i]);
1181 assert(totalWeights[i]>0.0);
1185 template <
class LabelType,
class PreprocessorTag>
1186 template <
class U,
class C1,
class T,
class C2,
class Stop_t>
1189 MultiArrayView<2, T, C2> & prob,
1190 Stop_t & stop_)
const
1196 "RandomForestn::predictProbabilities():"
1197 " Feature matrix and probability matrix size mismatch.");
1201 vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
1202 "RandomForestn::predictProbabilities():"
1203 " Too few columns in feature matrix.");
1206 "RandomForestn::predictProbabilities():"
1207 " Probability matrix must have as many columns as there are classes.");
1209 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
1210 Default_Stop_t default_stop(options_);
1211 typename RF_CHOOSER(Stop_t)::type & stop
1212 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
1214 stop.set_external_parameters(ext_param_, tree_count());
1215 prob.init(NumericTraits<T>::zero());
1225 for(
int row=0; row <
rowCount(features); ++row)
1227 ArrayVector<double>::const_iterator weights;
1230 double totalWeight = 0.0;
1233 for(
int k=0; k<options_.tree_count_; ++k)
1236 weights = trees_[k ].predict(
rowVector(features, row));
1239 int weighted = options_.predict_weighted_;
1240 for(
int l=0; l<ext_param_.class_count_; ++l)
1242 double cur_w = weights[l] * (weighted * (*(weights-1))
1244 prob(row, l) += (T)cur_w;
1246 totalWeight += cur_w;
1248 if(stop.after_prediction(weights,
1258 for(
int l=0; l< ext_param_.class_count_; ++l)
1260 prob(row, l) /= detail::RequiresExplicitCast<T>::cast(totalWeight);
1266 template <
class LabelType,
class PreprocessorTag>
1267 template <
class U,
class C1,
class T,
class C2>
1268 void RandomForest<LabelType, PreprocessorTag>
1269 ::predictRaw(MultiArrayView<2, U, C1>
const & features,
1270 MultiArrayView<2, T, C2> & prob)
const
1276 "RandomForestn::predictProbabilities():"
1277 " Feature matrix and probability matrix size mismatch.");
1281 vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
1282 "RandomForestn::predictProbabilities():"
1283 " Too few columns in feature matrix.");
1286 "RandomForestn::predictProbabilities():"
1287 " Probability matrix must have as many columns as there are classes.");
1289 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
1290 prob.init(NumericTraits<T>::zero());
1300 for(
int row=0; row <
rowCount(features); ++row)
1302 ArrayVector<double>::const_iterator weights;
1305 double totalWeight = 0.0;
1308 for(
int k=0; k<options_.tree_count_; ++k)
1311 weights = trees_[k ].predict(
rowVector(features, row));
1314 int weighted = options_.predict_weighted_;
1315 for(
int l=0; l<ext_param_.class_count_; ++l)
1317 double cur_w = weights[l] * (weighted * (*(weights-1))
1319 prob(row, l) += (T)cur_w;
1321 totalWeight += cur_w;
1325 prob/= options_.tree_count_;
1333 #include "random_forest/rf_algorithm.hxx"
1334 #endif // VIGRA_RANDOM_FOREST_HXX