[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

rf_earlystopping.hxx
1 #ifndef RF_EARLY_STOPPING_P_HXX
2 #define RF_EARLY_STOPPING_P_HXX
3 #include <cmath>
4 #include "rf_common.hxx"
5 
6 namespace vigra
7 {
8 
9 #if 0
10 namespace es_detail
11 {
12  template<class T>
13  T power(T const & in, int n)
14  {
15  T result = NumericTraits<T>::one();
16  for(int ii = 0; ii < n ;++ii)
17  result *= in;
18  return result;
19  }
20 }
21 #endif
22 
23 /**Base class from which all EarlyStopping Functors derive.
24  */
25 class StopBase
26 {
27 protected:
28  ProblemSpec<> ext_param_;
29  int tree_count_ ;
30  bool is_weighted_;
31 
32 public:
33  template<class T>
34  void set_external_parameters(ProblemSpec<T> const &prob, int tree_count = 0, bool is_weighted = false)
35  {
36  ext_param_ = prob;
37  is_weighted_ = is_weighted;
38  tree_count_ = tree_count;
39  }
40 
41 #ifdef DOXYGEN
42  /** called after the prediction of a tree was added to the total prediction
43  * \param weightIter Iterator to the weights delivered by current tree.
44  * \param k after kth tree
45  * \param prob Total probability array
46  * \param totalCt sum of probability array.
47  */
48  template<class WeightIter, class T, class C>
49  bool after_prediction(WeightIter weightIter, int k, MultiArrayView<2, T, C> const & prob , double totalCt)
50 #else
51  template<class WeightIter, class T, class C>
52  bool after_prediction(WeightIter, int /* k */, MultiArrayView<2, T, C> const & /* prob */, double /* totalCt */)
53  {return false;}
54 #endif //DOXYGEN
55 };
56 
57 
58 /**Stop predicting after a set number of trees
59  */
60 class StopAfterTree : public StopBase
61 {
62 public:
63  double max_tree_p;
64  int max_tree_;
65  typedef StopBase SB;
66 
67  ArrayVector<double> depths;
68 
69  /** Constructor
70  * \param max_tree number of trees to be used for prediction
71  */
72  StopAfterTree(double max_tree)
73  :
74  max_tree_p(max_tree)
75  {}
76 
77  template<class T>
78  void set_external_parameters(ProblemSpec<T> const &prob, int tree_count = 0, bool is_weighted = false)
79  {
80  max_tree_ = ceil(max_tree_p * tree_count);
81  SB::set_external_parameters(prob, tree_count, is_weighted);
82  }
83 
84  template<class WeightIter, class T, class C>
85  bool after_prediction(WeightIter, int k, MultiArrayView<2, T, C> const & /* prob */, double /* totalCt */)
86  {
87  if(k == SB::tree_count_ -1)
88  {
89  depths.push_back(double(k+1)/double(SB::tree_count_));
90  return false;
91  }
92  if(k < max_tree_)
93  return false;
94  depths.push_back(double(k+1)/double(SB::tree_count_));
95  return true;
96  }
97 };
98 
99 /** Stop predicting after a certain amount of votes exceed certain proportion.
100  * case unweighted voting: stop if the leading class exceeds proportion * SB::tree_count_
101  * case weighted voting: stop if the leading class exceeds proportion * msample_ * SB::tree_count_ ;
102  * (maximal number of votes possible in both cases)
103  */
105 {
106 public:
107  double proportion_;
108  typedef StopBase SB;
109  ArrayVector<double> depths;
110 
111  /** Constructor
112  * \param proportion specify proportion to be used.
113  */
114  StopAfterVoteCount(double proportion)
115  :
116  proportion_(proportion)
117  {}
118 
119  template<class WeightIter, class T, class C>
120  bool after_prediction(WeightIter, int k, MultiArrayView<2, T, C> const & prob, double /* totalCt */)
121  {
122  if(k == SB::tree_count_ -1)
123  {
124  depths.push_back(double(k+1)/double(SB::tree_count_));
125  return false;
126  }
127 
128 
129  if(SB::is_weighted_)
130  {
131  if(prob[argMax(prob)] > proportion_ *SB::ext_param_.actual_msample_* SB::tree_count_)
132  {
133  depths.push_back(double(k+1)/double(SB::tree_count_));
134  return true;
135  }
136  }
137  else
138  {
139  if(prob[argMax(prob)] > proportion_ * SB::tree_count_)
140  {
141  depths.push_back(double(k+1)/double(SB::tree_count_));
142  return true;
143  }
144  }
145  return false;
146  }
147 
148 };
149 
150 
151 /** Stop predicting if the 2norm of the probabilities does not change*/
153 
154 {
155 public:
156  double thresh_;
157  int num_;
158  MultiArray<2, double> last_;
160  ArrayVector<double> depths;
161  typedef StopBase SB;
162 
163  /** Constructor
164  * \param thresh: If the two norm of the probabilities changes less then thresh then stop
165  * \param num : look at atleast num trees before stopping
166  */
167  StopIfConverging(double thresh, int num = 10)
168  :
169  thresh_(thresh),
170  num_(num)
171  {}
172 
173  template<class T>
174  void set_external_parameters(ProblemSpec<T> const &prob, int tree_count = 0, bool is_weighted = false)
175  {
176  last_.reshape(MultiArrayShape<2>::type(1, prob.class_count_), 0);
177  cur_.reshape(MultiArrayShape<2>::type(1, prob.class_count_), 0);
178  SB::set_external_parameters(prob, tree_count, is_weighted);
179  }
180  template<class WeightIter, class T, class C>
181  bool after_prediction(WeightIter iter, int k, MultiArrayView<2, T, C> const & prob, double totalCt)
182  {
183  if(k == SB::tree_count_ -1)
184  {
185  depths.push_back(double(k+1)/double(SB::tree_count_));
186  return false;
187  }
188  if(k <= num_)
189  {
190  last_ = prob;
191  last_/= last_.norm(1);
192  return false;
193  }
194  else
195  {
196  cur_ = prob;
197  cur_ /= cur_.norm(1);
198  last_ -= cur_;
199  double nrm = last_.norm();
200  if(nrm < thresh_)
201  {
202  depths.push_back(double(k+1)/double(SB::tree_count_));
203  return true;
204  }
205  else
206  {
207  last_ = cur_;
208  }
209  }
210  return false;
211  }
212 };
213 
214 
215 /** Stop predicting if the margin prob(leading class) - prob(second class) exceeds a proportion
216  * case unweighted voting: stop if margin exceeds proportion * SB::tree_count_
217  * case weighted voting: stop if margin exceeds proportion * msample_ * SB::tree_count_ ;
218  * (maximal number of votes possible in both cases)
219  */
220 class StopIfMargin : public StopBase
221 {
222 public:
223  double proportion_;
224  typedef StopBase SB;
225  ArrayVector<double> depths;
226 
227  /** Constructor
228  * \param proportion specify proportion to be used.
229  */
230  StopIfMargin(double proportion)
231  :
232  proportion_(proportion)
233  {}
234 
235  template<class WeightIter, class T, class C>
236  bool after_prediction(WeightIter, int k, MultiArrayView<2, T, C> prob, double /* totalCt */)
237  {
238  if(k == SB::tree_count_ -1)
239  {
240  depths.push_back(double(k+1)/double(SB::tree_count_));
241  return false;
242  }
243  int index = argMax(prob);
244  double a = prob[argMax(prob)];
245  prob[argMax(prob)] = 0;
246  double b = prob[argMax(prob)];
247  prob[index] = a;
248  double margin = a - b;
249  if(SB::is_weighted_)
250  {
251  if(margin > proportion_ *SB::ext_param_.actual_msample_ * SB::tree_count_)
252  {
253  depths.push_back(double(k+1)/double(SB::tree_count_));
254  return true;
255  }
256  }
257  else
258  {
259  if(prob[argMax(prob)] > proportion_ * SB::tree_count_)
260  {
261  depths.push_back(double(k+1)/double(SB::tree_count_));
262  return true;
263  }
264  }
265  return false;
266  }
267 };
268 
269 
270 /**Probabilistic Stopping criterion (binomial test)
271  *
272  * Can only be used in a two class setting
273  *
274  * Stop if the Parameters estimated for the underlying binomial distribution
275  * can be estimated with certainty over 1-alpha.
276  * (Thesis, Rahul Nair Page 80 onwards: called the "binomial" criterion
277  */
278 class StopIfBinTest : public StopBase
279 {
280 public:
281  double alpha_;
282  MultiArrayView<2, double> n_choose_k;
283  /** Constructor
284  * \param alpha specify alpha (=proportion) value for binomial test.
285  * \param nck_ Matrix with precomputed values for n choose k
286  * nck_(n, k) is n choose k.
287  */
289  :
290  alpha_(alpha),
291  n_choose_k(nck_)
292  {}
293  typedef StopBase SB;
294 
295  /**ArrayVector that will contain the fraction of trees that was visited before terminating
296  */
298 
299  double binomial(int N, int k, double p)
300  {
301 // return n_choose_k(N, k) * es_detail::power(p, k) *es_detail::power(1 - p, N-k);
302  return n_choose_k(N, k) * std::pow(p, k) * std::pow(1 - p, N-k);
303  }
304 
305  template<class WeightIter, class T, class C>
306  bool after_prediction(WeightIter iter, int k, MultiArrayView<2, T, C> prob, double totalCt)
307  {
308  if(k == SB::tree_count_ -1)
309  {
310  depths.push_back(double(k+1)/double(SB::tree_count_));
311  return false;
312  }
313  if(k < 10)
314  {
315  return false;
316  }
317  int index = argMax(prob);
318  int n_a = prob[index];
319  int n_b = prob[(index+1)%2];
320  int n_tilde = (SB::tree_count_ - n_a + n_b);
321  double p_a = double(n_b - n_a + n_tilde)/double(2* n_tilde);
322  vigra_precondition(p_a <= 1, "probability should be smaller than 1");
323  double cum_val = 0;
324  int c = 0;
325  // std::cerr << "prob: " << p_a << std::endl;
326  if(n_a <= 0)n_a = 0;
327  if(n_b <= 0)n_b = 0;
328  for(int ii = 0; ii <= n_b + n_a;++ii)
329  {
330 // std::cerr << "nb +ba " << n_b + n_a << " " << ii <<std::endl;
331  cum_val += binomial(n_b + n_a, ii, p_a);
332  if(cum_val >= 1 -alpha_)
333  {
334  c = ii;
335  break;
336  }
337  }
338 // std::cerr << c << " " << n_a << " " << n_b << " " << p_a << alpha_ << std::endl;
339  if(c < n_a)
340  {
341  depths.push_back(double(k+1)/double(SB::tree_count_));
342  return true;
343  }
344 
345  return false;
346  }
347 };
348 
349 /**Probabilistic Stopping criteria. (toChange)
350  *
351  * Can only be used in a two class setting
352  *
353  * Stop if the probability that the decision will change after seeing all trees falls under
354  * a specified value alpha.
355  * (Thesis, Rahul Nair Page 80 onwards: called the "toChange" criterion
356  */
357 class StopIfProb : public StopBase
358 {
359 public:
360  double alpha_;
361  MultiArrayView<2, double> n_choose_k;
362 
363 
364  /** Constructor
365  * \param alpha specify alpha (=proportion) value
366  * \param nck_ Matrix with precomputed values for n choose k
367  * nck_(n, k) is n choose k.
368  */
370  :
371  alpha_(alpha),
372  n_choose_k(nck_)
373  {}
374  typedef StopBase SB;
375  /**ArrayVector that will contain the fraction of trees that was visited before terminating
376  */
378 
379  double binomial(int N, int k, double p)
380  {
381 // return n_choose_k(N, k) * es_detail::power(p, k) *es_detail::power(1 - p, N-k);
382  return n_choose_k(N, k) * std::pow(p, k) * std::pow(1 - p, N-k);
383  }
384 
385  template<class WeightIter, class T, class C>
386  bool after_prediction(WeightIter iter, int k, MultiArrayView<2, T, C> prob, double totalCt)
387  {
388  if(k == SB::tree_count_ -1)
389  {
390  depths.push_back(double(k+1)/double(SB::tree_count_));
391  return false;
392  }
393  if(k <= 10)
394  {
395  return false;
396  }
397  int index = argMax(prob);
398  int n_a = prob[index];
399  int n_b = prob[(index+1)%2];
400  int n_needed = ceil(double(SB::tree_count_)/2.0)-n_a;
401  int n_tilde = SB::tree_count_ - (n_a +n_b);
402  if(n_tilde <= 0) n_tilde = 0;
403  if(n_needed <= 0) n_needed = 0;
404  double p = 0;
405  for(int ii = n_needed; ii < n_tilde; ++ii)
406  p += binomial(n_tilde, ii, 0.5);
407 
408  if(p >= 1-alpha_)
409  {
410  depths.push_back(double(k+1)/double(SB::tree_count_));
411  return true;
412  }
413 
414  return false;
415  }
416 };
417 } //namespace vigra;
418 #endif //RF_EARLY_STOPPING_P_HXX

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.9.0 (Sat Oct 5 2013)