mlpack  1.0.12
decision_stump.hpp
Go to the documentation of this file.
1 
14 #ifndef __MLPACK_METHODS_DECISION_STUMP_DECISION_STUMP_HPP
15 #define __MLPACK_METHODS_DECISION_STUMP_DECISION_STUMP_HPP
16 
17 #include <mlpack/core.hpp>
18 
19 namespace mlpack {
20 namespace decision_stump {
21 
// This class implements a decision stump.
//
// MatType is the matrix type used for training and test data; it defaults to
// arma::mat.
//
// NOTE(review): this listing omits the `class DecisionStump` header line
// (listing line 36) between the template line and the opening brace, and it
// also omits the documentation comments of the original header.
35 template <typename MatType = arma::mat>
37 {
38  public:
// Constructor.
//
// Parameters: data, labels, classes, and inpBucketSize.
// NOTE(review): presumably `classes` initializes numClass and `inpBucketSize`
// initializes bucketSize, and training happens here -- the definition is in
// decision_stump_impl.hpp and is not visible; confirm there.
48  DecisionStump(const MatType& data,
49  const arma::Row<size_t>& labels,
50  const size_t classes,
51  size_t inpBucketSize);
52 
// Classification function.
//
// `test` is read-only; `predictedLabels` is taken by non-const reference.
// NOTE(review): presumably predictedLabels receives one label per test point
// -- confirm against the implementation.
61  void Classify(const MatType& test, arma::Row<size_t>& predictedLabels);
62 
// Constructor taking an existing stump together with data, weights, and labels.
//
// `other` is typed DecisionStump<>, i.e. the default instantiation
// (MatType = arma::mat), not DecisionStump<MatType>.
// NOTE(review): presumably this builds a weighted stump reusing settings from
// `other`; if MatType-generic copying is intended, the parameter type looks
// too narrow -- confirm against the implementation and callers.
75  DecisionStump(const DecisionStump<>& other,
76  const MatType& data,
77  const arma::rowvec& weights,
78  const arma::Row<size_t>& labels);
79 
// Access the splitting attribute (returned by value).
81  int SplitAttribute() const { return splitAttribute; }
// Modify the splitting attribute (be careful!). Returns a mutable reference
// to the member.
83  int& SplitAttribute() { return splitAttribute; }
84 
// Access the splitting values (const reference to the member).
86  const arma::vec& Split() const { return split; }
// Modify the splitting values (be careful!).
88  arma::vec& Split() { return split; }
89 
// Access the labels for each split bin.
// NOTE(review): unlike Split() const, this overload returns
// `const arma::Col<size_t>` by value, so every call copies binLabels;
// a const reference was probably intended -- confirm before changing.
91  const arma::Col<size_t> BinLabels() const { return binLabels; }
// Modify the labels for each split bin (be careful!).
93  arma::Col<size_t>& BinLabels() { return binLabels; }
94 
95  private:
// Stores the number of classes.
97  size_t numClass;
98 
// NOTE(review): the declaration `int splitAttribute;` (stores the value of the
// attribute on which to split) belongs here but is missing from this listing
// (listing lines 99-100); the accessors above return it.
101 
// Size of bucket while determining splitting criterion.
103  size_t bucketSize;
104 
// Stores the splitting values after training.
106  arma::vec split;
107 
// Stores the labels for each splitting bin.
109  arma::Col<size_t> binLabels;
110 
// Sets up `attribute` as if it were splitting on it and finds the entropy when
// splitting on that attribute; the result is returned as a double.
// NOTE(review): presumably isWeight selects whether weightD is used --
// confirm against the implementation.
119  template <bool isWeight>
120  double SetupSplitAttribute(const arma::rowvec& attribute,
121  const arma::Row<size_t>& labels,
122  const arma::rowvec& weightD);
123 
// After having decided the attribute on which to split, train on that
// attribute.
131  template <typename rType> void TrainOnAtt(const arma::rowvec& attribute,
132  const arma::Row<size_t>& labels);
133 
// After the split values have been set up, merge ranges with identical class
// labels.
138  void MergeRanges();
139 
// Count the most frequently occurring element in subCols.
146  template <typename rType> rType CountMostFreq(const arma::Row<rType>&
147  subCols);
148 
// Returns 1 if the values of featureRow are not all the same.
// NOTE(review): the other return value is not shown here (presumably 0) --
// confirm against the implementation.
154  template <typename rType> int IsDistinct(const arma::Row<rType>& featureRow);
155 
// Calculate the entropy of the given attribute.
// NOTE(review): the meaning of `begin` and `tempD` is not visible in this
// header; presumably an offset into the labels and a weight vector (used when
// isWeight is true) -- confirm against the implementation.
163  template <typename LabelType, bool isWeight>
164  double CalculateEntropy(arma::subview_row<LabelType> labels, int begin,
165  const arma::rowvec& tempD);
166 
// Train the decision stump on the given data and labels.
// NOTE(review): presumably weightD is only consulted when isWeight is true --
// confirm against the implementation.
174  template <bool isWeight>
175  void Train(const MatType& data, const arma::Row<size_t>& labels,
176  const arma::rowvec& weightD);
177 
178 };
179 
180 }; // namespace decision_stump
181 }; // namespace mlpack
182 
183 #include "decision_stump_impl.hpp"
184 
185 #endif
void MergeRanges()
After the "split" vector has been set up, merge ranges with identical class labels.
void Classify(const MatType &test, arma::Row< size_t > &predictedLabels)
Classification function.
int splitAttribute
Stores the value of the attribute on which to split.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: load.hpp:23
arma::Col< size_t > & BinLabels()
Modify the labels for each split bin (be careful!).
void Train(const MatType &data, const arma::Row< size_t > &labels, const arma::rowvec &weightD)
Train the decision stump on the given data and labels.
size_t numClass
Stores the number of classes.
double SetupSplitAttribute(const arma::rowvec &attribute, const arma::Row< size_t > &labels, const arma::rowvec &weightD)
Sets up attribute as if it were splitting on it and finds entropy when splitting on attribute...
This class implements a decision stump.
const arma::vec & Split() const
Access the splitting values.
arma::Col< size_t > binLabels
Stores the labels for each splitting bin.
size_t bucketSize
Size of bucket while determining splitting criterion.
void TrainOnAtt(const arma::rowvec &attribute, const arma::Row< size_t > &labels)
After having decided the attribute on which to split, train on that attribute.
int SplitAttribute() const
Access the splitting attribute.
rType CountMostFreq(const arma::Row< rType > &subCols)
Count the most frequently occurring element in subCols.
int IsDistinct(const arma::Row< rType > &featureRow)
Returns 1 if the values of featureRow are not all the same.
arma::vec & Split()
Modify the splitting values (be careful!).
double CalculateEntropy(arma::subview_row< LabelType > labels, int begin, const arma::rowvec &tempD)
Calculate the entropy of the given attribute.
int & SplitAttribute()
Modify the splitting attribute (be careful!).
arma::vec split
Stores the splitting values after training.
DecisionStump(const MatType &data, const arma::Row< size_t > &labels, const size_t classes, size_t inpBucketSize)
Constructor.
const arma::Col< size_t > BinLabels() const
Access the labels for each split bin.