
vigra/random_forest/rf_visitors.hxx
00001 /************************************************************************/
00002 /*                                                                      */
00003 /*        Copyright 2008-2009 by  Ullrich Koethe and Rahul Nair         */
00004 /*                                                                      */
00005 /*    This file is part of the VIGRA computer vision library.           */
00006 /*    The VIGRA Website is                                              */
00007 /*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
00008 /*    Please direct questions, bug reports, and contributions to        */
00009 /*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
00010 /*        vigra@informatik.uni-hamburg.de                               */
00011 /*                                                                      */
00012 /*    Permission is hereby granted, free of charge, to any person       */
00013 /*    obtaining a copy of this software and associated documentation    */
00014 /*    files (the "Software"), to deal in the Software without           */
00015 /*    restriction, including without limitation the rights to use,      */
00016 /*    copy, modify, merge, publish, distribute, sublicense, and/or      */
00017 /*    sell copies of the Software, and to permit persons to whom the    */
00018 /*    Software is furnished to do so, subject to the following          */
00019 /*    conditions:                                                       */
00020 /*                                                                      */
00021 /*    The above copyright notice and this permission notice shall be    */
00022 /*    included in all copies or substantial portions of the             */
00023 /*    Software.                                                         */
00024 /*                                                                      */
00025 /*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
00026 /*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
00027 /*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
00028 /*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
00029 /*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
00030 /*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
00031 /*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
00032 /*    OTHER DEALINGS IN THE SOFTWARE.                                   */
00033 /*                                                                      */
00034 /************************************************************************/
00035 #ifndef RF_VISITORS_HXX
00036 #define RF_VISITORS_HXX
00037 
00038 #ifdef HasHDF5
00039 # include "vigra/hdf5impex.hxx"
00040 #endif // HasHDF5
00041 #include <vigra/windows.h>
00042 #include <iostream>
00043 #include <iomanip>
00044 #include <vigra/timing.hxx>
00045 
00046 namespace vigra
00047 {
00048 namespace rf
00049 {
00050 /** \addtogroup MachineLearning Machine Learning
00051 **/
00052 //@{
00053 
00054 /**
00055     This namespace contains all classes and methods related to extracting information during 
00056     learning of the random forest. All Visitors share the same interface defined in 
00057     visitors::VisitorBase. The member methods are invoked at certain points of the main code in 
00058     the order they were supplied.
00059     
00060     For the Random Forest, the Visitor concept is implemented as a statically linked list 
00061     (using templates). Each Visitor object is encapsulated in a detail::VisitorNode object. The 
00062     VisitorNode object calls the next Visitor after one of its visit() methods has terminated.
00063     
00064     To simplify usage create_visitor() factory methods are supplied.
00065     Use the create_visitor() method to supply visitor objects to the RandomForest::learn() method.
00066     It is possible to supply more than one visitor. They will then be invoked in serial order.
00067 
00068     The calculated information is stored as public data members of the respective visitor 
00069     class - see the documentation of the individual visitors.
00070     
00071     When creating a new visitor, the new class should publicly inherit from visitors::VisitorBase 
00072     (see e.g. visitors::OOB_Error).
00073 
00074     \code
00075 
00076       typedef xxx feature_t; // replace xxx with whichever type you use
00077       typedef yyy label_t;   // same here. 
00078       MultiArrayView<2, feature_t> f = get_some_features();
00079       MultiArrayView<2, label_t>   l = get_some_labels();
00080       RandomForest<> rf;
00081     
00082       //calculate OOB Error
00083       visitors::OOB_Error oob_v;
00084       //calculate Variable Importance
00085       visitors::VariableImportanceVisitor varimp_v;
00086 
00087       double oob_error = rf.learn(f, l, visitors::create_visitor(oob_v, varimp_v));
00088       //the data can be found in the attributes of oob_v and varimp_v now
00089       
00090     \endcode
00091 */
00092 namespace visitors
00093 {
00094     
00095     
00096 /** Base class from which all Visitors derive. Can be used as a template to create new 
00097  * Visitors (a minimal example sketch follows the class definition below).
00098  */
00099 class VisitorBase
00100 {
00101     public:
00102     bool active_;   
00103     bool is_active()
00104     {
00105         return active_;
00106     }
00107 
00108     bool has_value()
00109     {
00110         return false;
00111     }
00112 
00113     VisitorBase()
00114         : active_(true)
00115     {}
00116 
00117     void deactivate()
00118     {
00119         active_ = false;
00120     }
00121     void activate()
00122     {
00123         active_ = true;
00124     }
00125     
00126     /** do something after the Split has decided how to process the Region
00127      * (Stack entry)
00128      *
00129      * \param tree      reference to the tree that is currently being learned
00130      * \param split     reference to the split object
00131      * \param parent    current stack entry  which was used to decide the split
00132      * \param leftChild left stack entry that will be pushed
00133      * \param rightChild
00134      *                  right stack entry that will be pushed.
00135      * \param features  features matrix
00136      * \param labels    label matrix
00137      * \sa RF_Traits::StackEntry_t
00138      */
00139     template<class Tree, class Split, class Region, class Feature_t, class Label_t>
00140     void visit_after_split( Tree          & tree, 
00141                             Split         & split,
00142                             Region        & parent,
00143                             Region        & leftChild,
00144                             Region        & rightChild,
00145                             Feature_t     & features,
00146                             Label_t       & labels)
00147     {}
00148     
00149     /** do something after each tree has been learned
00150      *
00151      * \param rf        reference to the random forest object that called this
00152      *                  visitor
00153      * \param pr        reference to the preprocessor that processed the input
00154      * \param sm        reference to the sampler object
00155      * \param st        reference to the first stack entry
00156      * \param index     index of current tree
00157      */
00158     template<class RF, class PR, class SM, class ST>
00159     void visit_after_tree(RF& rf, PR & pr,  SM & sm, ST & st, int index)
00160     {}
00161     
00162     /** do something after all trees have been learned
00163      *
00164      * \param rf        reference to the random forest object that called this
00165      *                  visitor
00166      * \param pr        reference to the preprocessor that processed the input
00167      */
00168     template<class RF, class PR>
00169     void visit_at_end(RF const & rf, PR const & pr)
00170     {}
00171     
00172     /** do something before learning starts 
00173      *
00174      * \param rf        reference to the random forest object that called this
00175      *                  visitor
00176      * \param pr        reference to the Processor class used.
00177      */
00178     template<class RF, class PR>
00179     void visit_at_beginning(RF const & rf, PR const & pr)
00180     {}
00181     /** do something while traversing the tree after it has been learned 
00182      *  (external nodes)
00183      *
00184      * \param tr        reference to the tree object that called this visitor
00185      * \param index     index in the topology_ array we currently are at
00186      * \param node_t    type of node we have (will be e_.... - )
00187      * \param features  feature matrix
00188      * \sa  NodeTags
00189      *
00190      * You can create the node by switching on node_t and using the 
00191      * corresponding Node objects - or, if you do not care about the type, 
00192      * use the NodeBase class.
00193      */
00194     template<class TR, class IntT, class TopT,class Feat>
00195     void visit_external_node(TR & tr, IntT index, TopT node_t,Feat & features)
00196     {}
00197     
00198     /** do something when visiting an internal node after it has been learned
00199      *
00200      * \sa visit_external_node
00201      */
00202     template<class TR, class IntT, class TopT,class Feat>
00203     void visit_internal_node(TR & tr, IntT index, TopT node_t,Feat & features)
00204     {}
00205 
00206     /** return a double value.  The value of the first 
00207      * visitor encountered that has a return value is returned with the
00208      * RandomForest::learn() method - or -1.0 if no visitor with a return
00209      * value exists. This functionality mainly exists so that the 
00210      * OOB visitor can return the OOB error rate like in the old version 
00211      * of the random forest.
00212      */
00213     double return_val()
00214     {
00215         return -1.0;
00216     }
00217 };
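
/* Example (a minimal sketch, not part of the library): a custom visitor derived
 * from VisitorBase that simply counts how many trees have been learned. The name
 * TreeCountVisitor is purely illustrative.
 *
 * \code
 * class TreeCountVisitor : public vigra::rf::visitors::VisitorBase
 * {
 *   public:
 *     int tree_count;
 *
 *     TreeCountVisitor() : tree_count(0) {}
 *
 *     template<class RF, class PR, class SM, class ST>
 *     void visit_after_tree(RF & rf, PR & pr, SM & sm, ST & st, int index)
 *     {
 *         ++tree_count; // invoked once after each tree has been learned
 *     }
 * };
 * \endcode
 */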
00218 
00219 
00220 /** Last Visitor that should be called to stop the recursion.
00221  */
00222 class StopVisiting: public VisitorBase
00223 {
00224     public:
00225     bool has_value()
00226     {
00227         return true;
00228     }
00229     double return_val()
00230     {
00231         return -1.0;
00232     }
00233 };
00234 namespace detail
00235 {
00236 /** Container elements of the statically linked Visitor list.
00237  *
00238  * Use the create_visitor() factory functions to chain up to 10 visitors (a chaining sketch follows the detail namespace below).
00239  *
00240  */
00241 template <class Visitor, class Next = StopVisiting>
00242 class VisitorNode
00243 {
00244     public:
00245     
00246     StopVisiting    stop_;
00247     Next            next_;
00248     Visitor &       visitor_;   
00249     VisitorNode(Visitor & visitor, Next & next) 
00250     : 
00251         next_(next), visitor_(visitor)
00252     {}
00253 
00254     VisitorNode(Visitor &  visitor) 
00255     : 
00256         next_(stop_), visitor_(visitor)
00257     {}
00258 
00259     template<class Tree, class Split, class Region, class Feature_t, class Label_t>
00260     void visit_after_split( Tree          & tree, 
00261                             Split         & split,
00262                             Region        & parent,
00263                             Region        & leftChild,
00264                             Region        & rightChild,
00265                             Feature_t     & features,
00266                             Label_t       & labels)
00267     {
00268         if(visitor_.is_active())
00269             visitor_.visit_after_split(tree, split, 
00270                                        parent, leftChild, rightChild,
00271                                        features, labels);
00272         next_.visit_after_split(tree, split, parent, leftChild, rightChild,
00273                                 features, labels);
00274     }
00275 
00276     template<class RF, class PR, class SM, class ST>
00277     void visit_after_tree(RF& rf, PR & pr,  SM & sm, ST & st, int index)
00278     {
00279         if(visitor_.is_active())
00280             visitor_.visit_after_tree(rf, pr, sm, st, index);
00281         next_.visit_after_tree(rf, pr, sm, st, index);
00282     }
00283 
00284     template<class RF, class PR>
00285     void visit_at_beginning(RF & rf, PR & pr)
00286     {
00287         if(visitor_.is_active())
00288             visitor_.visit_at_beginning(rf, pr);
00289         next_.visit_at_beginning(rf, pr);
00290     }
00291     template<class RF, class PR>
00292     void visit_at_end(RF & rf, PR & pr)
00293     {
00294         if(visitor_.is_active())
00295             visitor_.visit_at_end(rf, pr);
00296         next_.visit_at_end(rf, pr);
00297     }
00298     
00299     template<class TR, class IntT, class TopT,class Feat>
00300     void visit_external_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
00301     {
00302         if(visitor_.is_active())
00303             visitor_.visit_external_node(tr, index, node_t,features);
00304         next_.visit_external_node(tr, index, node_t,features);
00305     }
00306     template<class TR, class IntT, class TopT,class Feat>
00307     void visit_internal_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
00308     {
00309         if(visitor_.is_active())
00310             visitor_.visit_internal_node(tr, index, node_t,features);
00311         next_.visit_internal_node(tr, index, node_t,features);
00312     }
00313 
00314     double return_val()
00315     {
00316         if(visitor_.is_active() && visitor_.has_value())
00317             return visitor_.return_val();
00318         return next_.return_val();
00319     }
00320 };
00321 
00322 } //namespace detail
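
/* Chaining illustration (a sketch; MyVisitorA and MyVisitorB are hypothetical
 * visitor types): create_visitor(a, b) builds the following statically linked
 * list, equivalent to constructing the nodes by hand.
 *
 * \code
 * MyVisitorA a;
 * MyVisitorB b;
 * detail::VisitorNode<MyVisitorB> node_b(b);   // next_ defaults to StopVisiting
 * detail::VisitorNode<MyVisitorA,
 *                     detail::VisitorNode<MyVisitorB> > node_a(a, node_b);
 * // node_a.visit_after_tree(...) invokes a's hook, then b's, then StopVisiting.
 * \endcode
 */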
00323 
00324 //////////////////////////////////////////////////////////////////////////////
00325 //  Visitor factory functions for up to 10 visitors                         //
00326 //////////////////////////////////////////////////////////////////////////////
00327 
00328 /** factory method to be used with RandomForest::learn()
00329  */
00330 template<class A>
00331 detail::VisitorNode<A>
00332 create_visitor(A & a)
00333 {
00334    typedef detail::VisitorNode<A> _0_t;
00335    _0_t _0(a);
00336    return _0;
00337 }
00338 
00339 
00340 /** factory method to be used with RandomForest::learn()
00341  */
00342 template<class A, class B>
00343 detail::VisitorNode<A, detail::VisitorNode<B> >
00344 create_visitor(A & a, B & b)
00345 {
00346    typedef detail::VisitorNode<B> _1_t;
00347    _1_t _1(b);
00348    typedef detail::VisitorNode<A, _1_t> _0_t;
00349    _0_t _0(a, _1);
00350    return _0;
00351 }
00352 
00353 
00354 /** factory method to be used with RandomForest::learn()
00355  */
00356 template<class A, class B, class C>
00357 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C> > >
00358 create_visitor(A & a, B & b, C & c)
00359 {
00360    typedef detail::VisitorNode<C> _2_t;
00361    _2_t _2(c);
00362    typedef detail::VisitorNode<B, _2_t> _1_t;
00363    _1_t _1(b, _2);
00364    typedef detail::VisitorNode<A, _1_t> _0_t;
00365    _0_t _0(a, _1);
00366    return _0;
00367 }
00368 
00369 
00370 /** factory method to be used with RandomForest::learn()
00371  */
00372 template<class A, class B, class C, class D>
00373 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C, 
00374     detail::VisitorNode<D> > > >
00375 create_visitor(A & a, B & b, C & c, D & d)
00376 {
00377    typedef detail::VisitorNode<D> _3_t;
00378    _3_t _3(d);
00379    typedef detail::VisitorNode<C, _3_t> _2_t;
00380    _2_t _2(c, _3);
00381    typedef detail::VisitorNode<B, _2_t> _1_t;
00382    _1_t _1(b, _2);
00383    typedef detail::VisitorNode<A, _1_t> _0_t;
00384    _0_t _0(a, _1);
00385    return _0;
00386 }
00387 
00388 
00389 /** factory method to be used with RandomForest::learn()
00390  */
00391 template<class A, class B, class C, class D, class E>
00392 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C, 
00393     detail::VisitorNode<D, detail::VisitorNode<E> > > > >
00394 create_visitor(A & a, B & b, C & c, 
00395                D & d, E & e)
00396 {
00397    typedef detail::VisitorNode<E> _4_t;
00398    _4_t _4(e);
00399    typedef detail::VisitorNode<D, _4_t> _3_t;
00400    _3_t _3(d, _4);
00401    typedef detail::VisitorNode<C, _3_t> _2_t;
00402    _2_t _2(c, _3);
00403    typedef detail::VisitorNode<B, _2_t> _1_t;
00404    _1_t _1(b, _2);
00405    typedef detail::VisitorNode<A, _1_t> _0_t;
00406    _0_t _0(a, _1);
00407    return _0;
00408 }
00409 
00410 
00411 /** factory method to be used with RandomForest::learn()
00412  */
00413 template<class A, class B, class C, class D, class E,
00414          class F>
00415 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C, 
00416     detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F> > > > > >
00417 create_visitor(A & a, B & b, C & c, 
00418                D & d, E & e, F & f)
00419 {
00420    typedef detail::VisitorNode<F> _5_t;
00421    _5_t _5(f);
00422    typedef detail::VisitorNode<E, _5_t> _4_t;
00423    _4_t _4(e, _5);
00424    typedef detail::VisitorNode<D, _4_t> _3_t;
00425    _3_t _3(d, _4);
00426    typedef detail::VisitorNode<C, _3_t> _2_t;
00427    _2_t _2(c, _3);
00428    typedef detail::VisitorNode<B, _2_t> _1_t;
00429    _1_t _1(b, _2);
00430    typedef detail::VisitorNode<A, _1_t> _0_t;
00431    _0_t _0(a, _1);
00432    return _0;
00433 }
00434 
00435 
00436 /** factory method to be used with RandomForest::learn()
00437  */
00438 template<class A, class B, class C, class D, class E,
00439          class F, class G>
00440 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C, 
00441     detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F, 
00442     detail::VisitorNode<G> > > > > > >
00443 create_visitor(A & a, B & b, C & c, 
00444                D & d, E & e, F & f, G & g)
00445 {
00446    typedef detail::VisitorNode<G> _6_t;
00447    _6_t _6(g);
00448    typedef detail::VisitorNode<F, _6_t> _5_t;
00449    _5_t _5(f, _6);
00450    typedef detail::VisitorNode<E, _5_t> _4_t;
00451    _4_t _4(e, _5);
00452    typedef detail::VisitorNode<D, _4_t> _3_t;
00453    _3_t _3(d, _4);
00454    typedef detail::VisitorNode<C, _3_t> _2_t;
00455    _2_t _2(c, _3);
00456    typedef detail::VisitorNode<B, _2_t> _1_t;
00457    _1_t _1(b, _2);
00458    typedef detail::VisitorNode<A, _1_t> _0_t;
00459    _0_t _0(a, _1);
00460    return _0;
00461 }
00462 
00463 
00464 /** factory method to be used with RandomForest::learn()
00465  */
00466 template<class A, class B, class C, class D, class E,
00467          class F, class G, class H>
00468 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C, 
00469     detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F, 
00470     detail::VisitorNode<G, detail::VisitorNode<H> > > > > > > >
00471 create_visitor(A & a, B & b, C & c, 
00472                D & d, E & e, F & f, 
00473                G & g, H & h)
00474 {
00475    typedef detail::VisitorNode<H> _7_t;
00476    _7_t _7(h);
00477    typedef detail::VisitorNode<G, _7_t> _6_t;
00478    _6_t _6(g, _7);
00479    typedef detail::VisitorNode<F, _6_t> _5_t;
00480    _5_t _5(f, _6);
00481    typedef detail::VisitorNode<E, _5_t> _4_t;
00482    _4_t _4(e, _5);
00483    typedef detail::VisitorNode<D, _4_t> _3_t;
00484    _3_t _3(d, _4);
00485    typedef detail::VisitorNode<C, _3_t> _2_t;
00486    _2_t _2(c, _3);
00487    typedef detail::VisitorNode<B, _2_t> _1_t;
00488    _1_t _1(b, _2);
00489    typedef detail::VisitorNode<A, _1_t> _0_t;
00490    _0_t _0(a, _1);
00491    return _0;
00492 }
00493 
00494 
00495 /** factory method to be used with RandomForest::learn()
00496  */
00497 template<class A, class B, class C, class D, class E,
00498          class F, class G, class H, class I>
00499 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C, 
00500     detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F, 
00501     detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I> > > > > > > > >
00502 create_visitor(A & a, B & b, C & c, 
00503                D & d, E & e, F & f, 
00504                G & g, H & h, I & i)
00505 {
00506    typedef detail::VisitorNode<I> _8_t;
00507    _8_t _8(i);
00508    typedef detail::VisitorNode<H, _8_t> _7_t;
00509    _7_t _7(h, _8);
00510    typedef detail::VisitorNode<G, _7_t> _6_t;
00511    _6_t _6(g, _7);
00512    typedef detail::VisitorNode<F, _6_t> _5_t;
00513    _5_t _5(f, _6);
00514    typedef detail::VisitorNode<E, _5_t> _4_t;
00515    _4_t _4(e, _5);
00516    typedef detail::VisitorNode<D, _4_t> _3_t;
00517    _3_t _3(d, _4);
00518    typedef detail::VisitorNode<C, _3_t> _2_t;
00519    _2_t _2(c, _3);
00520    typedef detail::VisitorNode<B, _2_t> _1_t;
00521    _1_t _1(b, _2);
00522    typedef detail::VisitorNode<A, _1_t> _0_t;
00523    _0_t _0(a, _1);
00524    return _0;
00525 }
00526 
00527 /** factory method to be used with RandomForest::learn()
00528  */
00529 template<class A, class B, class C, class D, class E,
00530          class F, class G, class H, class I, class J>
00531 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C, 
00532     detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F, 
00533     detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I,
00534     detail::VisitorNode<J> > > > > > > > > >
00535 create_visitor(A & a, B & b, C & c, 
00536                D & d, E & e, F & f, 
00537                G & g, H & h, I & i,
00538                J & j)
00539 {
00540    typedef detail::VisitorNode<J> _9_t;
00541    _9_t _9(j);
00542    typedef detail::VisitorNode<I, _9_t> _8_t;
00543    _8_t _8(i, _9);
00544    typedef detail::VisitorNode<H, _8_t> _7_t;
00545    _7_t _7(h, _8);
00546    typedef detail::VisitorNode<G, _7_t> _6_t;
00547    _6_t _6(g, _7);
00548    typedef detail::VisitorNode<F, _6_t> _5_t;
00549    _5_t _5(f, _6);
00550    typedef detail::VisitorNode<E, _5_t> _4_t;
00551    _4_t _4(e, _5);
00552    typedef detail::VisitorNode<D, _4_t> _3_t;
00553    _3_t _3(d, _4);
00554    typedef detail::VisitorNode<C, _3_t> _2_t;
00555    _2_t _2(c, _3);
00556    typedef detail::VisitorNode<B, _2_t> _1_t;
00557    _1_t _1(b, _2);
00558    typedef detail::VisitorNode<A, _1_t> _0_t;
00559    _0_t _0(a, _1);
00560    return _0;
00561 }
00562 
00563 //////////////////////////////////////////////////////////////////////////////
00564 // Visitors of general interest                                             //
00565 //////////////////////////////////////////////////////////////////////////////
00566 
00567 
00568 /** Visitor to gather information that is needed later for online learning.
00569  */
00570 
00571 class OnlineLearnVisitor: public VisitorBase
00572 {
00573 public:
00574     //Set if we adjust thresholds
00575     bool adjust_thresholds;
00576     //Current tree id
00577     int tree_id;
00578     //Last node id for finding parent
00579     int last_node_id;
00580     //Need to know the label for interior node visiting
00581     vigra::Int32 current_label;
00582     //marginal distribution for interior nodes
00583     struct MarginalDistribution
00584     {
00585         ArrayVector<Int32> leftCounts;
00586         Int32 leftTotalCounts;
00587         ArrayVector<Int32> rightCounts;
00588         Int32 rightTotalCounts;
00589         double gap_left;
00590         double gap_right;
00591     };
00592     typedef ArrayVector<vigra::Int32> IndexList;
00593 
00594     //All information for one tree
00595     struct TreeOnlineInformation
00596     {
00597         std::vector<MarginalDistribution> mag_distributions;
00598         std::vector<IndexList> index_lists;
00599         //map for linear index of mag_distributions
00600         std::map<int,int> interior_to_index;
00601         //map for linear index of index_lists
00602         std::map<int,int> exterior_to_index;
00603     };
00604 
00605     //All trees
00606     std::vector<TreeOnlineInformation> trees_online_information;
00607 
00608     /** Initialize, set the number of trees
00609      */
00610     template<class RF,class PR>
00611     void visit_at_beginning(RF & rf,const PR & pr)
00612     {
00613         tree_id=0;
00614         trees_online_information.resize(rf.options_.tree_count_);
00615     }
00616 
00617     /** Reset a tree
00618      */
00619     void reset_tree(int tree_id)
00620     {
00621         trees_online_information[tree_id].mag_distributions.clear();
00622         trees_online_information[tree_id].index_lists.clear();
00623         trees_online_information[tree_id].interior_to_index.clear();
00624         trees_online_information[tree_id].exterior_to_index.clear();
00625     }
00626 
00627     /** simply increase the tree count
00628     */
00629     template<class RF, class PR, class SM, class ST>
00630     void visit_after_tree(RF& rf, PR & pr,  SM & sm, ST & st, int index)
00631     {
00632         tree_id++;
00633     }
00634     
00635     template<class Tree, class Split, class Region, class Feature_t, class Label_t>
00636     void visit_after_split( Tree          & tree, 
00637                 Split         & split,
00638                             Region       & parent,
00639                             Region        & leftChild,
00640                             Region        & rightChild,
00641                             Feature_t     & features,
00642                             Label_t       & labels)
00643     {
00644         int linear_index;
00645         int addr=tree.topology_.size();
00646         if(split.createNode().typeID() == i_ThresholdNode)
00647         {
00648             if(adjust_thresholds)
00649             {
00650                 //Store marginal distribution
00651                 linear_index=trees_online_information[tree_id].mag_distributions.size();
00652                 trees_online_information[tree_id].interior_to_index[addr]=linear_index;
00653                 trees_online_information[tree_id].mag_distributions.push_back(MarginalDistribution());
00654 
00655                 trees_online_information[tree_id].mag_distributions.back().leftCounts=leftChild.classCounts_;
00656                 trees_online_information[tree_id].mag_distributions.back().rightCounts=rightChild.classCounts_;
00657 
00658                 trees_online_information[tree_id].mag_distributions.back().leftTotalCounts=leftChild.size_;
00659                 trees_online_information[tree_id].mag_distributions.back().rightTotalCounts=rightChild.size_;
00660                 //Store the gap
00661                 double gap_left,gap_right;
00662                 int i;
00663                 gap_left=features(leftChild[0],split.bestSplitColumn());
00664                 for(i=1;i<leftChild.size();++i)
00665                     if(features(leftChild[i],split.bestSplitColumn())>gap_left)
00666                         gap_left=features(leftChild[i],split.bestSplitColumn());
00667                 gap_right=features(rightChild[0],split.bestSplitColumn());
00668                 for(i=1;i<rightChild.size();++i)
00669                     if(features(rightChild[i],split.bestSplitColumn())<gap_right)
00670                         gap_right=features(rightChild[i],split.bestSplitColumn());
00671                 trees_online_information[tree_id].mag_distributions.back().gap_left=gap_left;
00672                 trees_online_information[tree_id].mag_distributions.back().gap_right=gap_right;
00673             }
00674         }
00675         else
00676         {
00677             //Store index list
00678             linear_index=trees_online_information[tree_id].index_lists.size();
00679             trees_online_information[tree_id].exterior_to_index[addr]=linear_index;
00680 
00681             trees_online_information[tree_id].index_lists.push_back(IndexList());
00682 
00683             trees_online_information[tree_id].index_lists.back().resize(parent.size_,0);
00684             std::copy(parent.begin_,parent.end_,trees_online_information[tree_id].index_lists.back().begin());
00685         }
00686     }
00687     void add_to_index_list(int tree,int node,int index)
00688     {
00689         if(!this->active_)
00690             return;
00691         TreeOnlineInformation &ti=trees_online_information[tree];
00692         ti.index_lists[ti.exterior_to_index[node]].push_back(index);
00693     }
00694     void move_exterior_node(int src_tree,int src_index,int dst_tree,int dst_index)
00695     {
00696         if(!this->active_)
00697             return;
00698         trees_online_information[dst_tree].exterior_to_index[dst_index]=trees_online_information[src_tree].exterior_to_index[src_index];
00699         trees_online_information[src_tree].exterior_to_index.erase(src_index);
00700     }
00701     /** do something when visiting an internal node during getToLeaf
00702      *
00703      * remember as last node id, for finding the parent of the last external node
00704      * also: adjust class counts and borders
00705      */
00706     template<class TR, class IntT, class TopT,class Feat>
00707         void visit_internal_node(TR & tr, IntT index, TopT node_t,Feat & features)
00708         {
00709             last_node_id=index;
00710             if(adjust_thresholds)
00711             {
00712                 vigra_assert(node_t==i_ThresholdNode,"We can only visit threshold nodes");
00713                 //Check if we are in the gap
00714                 double value=features(0, Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).column());
00715                 TreeOnlineInformation &ti=trees_online_information[tree_id];
00716                 MarginalDistribution &m=ti.mag_distributions[ti.interior_to_index[index]];
00717                 if(value>m.gap_left && value<m.gap_right)
00718                 {
00719                     //Check which side we want to go to
00720                     if(m.leftCounts[current_label]/double(m.leftTotalCounts)>m.rightCounts[current_label]/double(m.rightTotalCounts))
00721                     {
00722                         //We want to go left
00723                         m.gap_left=value;
00724                     }
00725                     else
00726                     {
00727                         //We want to go right
00728                         m.gap_right=value;
00729                     }
00730                     Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold()=(m.gap_right+m.gap_left)/2.0;
00731                 }
00732                 //Adjust class counts
00733                 if(value>Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold())
00734                 {
00735                     ++m.rightTotalCounts;
00736                     ++m.rightCounts[current_label];
00737                 }
00738                 else
00739                 {
00740                     ++m.leftTotalCounts;
00741                     ++m.leftCounts[current_label]; // left branch: update left class counts
00742                 }
00743             }
00744         }
00745     /** do something when visiting an external node during getToLeaf
00746      * 
00747      * Store the new index!
00748      */
00749 };
00750 
00751 //////////////////////////////////////////////////////////////////////////////
00752 // Out of Bag Error estimates                                               //
00753 //////////////////////////////////////////////////////////////////////////////
00754 
00755 
00756 /** Visitor that calculates the oob error of each individual randomized
00757  * decision tree. 
00758  *
00759  * After training a tree, all samples that are OOB for this particular tree
00760  * are put down the tree and the error is estimated. 
00761  * The per-tree OOB error is the average of these individual error estimates 
00762  * (oobError = average error of one randomized tree).
00763  * Note: this is not the OOB error estimate suggested by Breiman (see the 
00764  * OOB_Error visitor for that). A usage sketch follows the class definition below.
00765  */
00766 class OOB_PerTreeError:public VisitorBase
00767 {
00768 public:
00769     /** Average error of one randomized decision tree
00770      */
00771     double oobError;
00772 
00773     int totalOobCount;
00774     ArrayVector<int> oobCount,oobErrorCount;
00775 
00776     OOB_PerTreeError()
00777     : oobError(0.0),
00778       totalOobCount(0)
00779     {}
00780 
00781 
00782     bool has_value()
00783     {
00784         return true;
00785     }
00786 
00787 
00788     /** does the basic calculation per tree*/
00789     template<class RF, class PR, class SM, class ST>
00790     void visit_after_tree(    RF& rf, PR & pr,  SM & sm, ST & st, int index)
00791     {
00792         //do the first time called.
00793         if(int(oobCount.size()) != rf.ext_param_.row_count_)
00794         {
00795             oobCount.resize(rf.ext_param_.row_count_, 0);
00796             oobErrorCount.resize(rf.ext_param_.row_count_, 0);
00797         }
00798         // go through the samples
00799         for(int l = 0; l < rf.ext_param_.row_count_; ++l)
00800         {
00801             // if the lth sample is oob...
00802             if(!sm.is_used()[l])
00803             {
00804                 ++oobCount[l];
00805                 if(     rf.tree(index)
00806                             .predictLabel(rowVector(pr.features(), l)) 
00807                     !=  pr.response()(l,0))
00808                 {
00809                     ++oobErrorCount[l];
00810                 }
00811             }
00812 
00813         }
00814     }
00815 
00816     /** Does the normalisation
00817      */
00818     template<class RF, class PR>
00819     void visit_at_end(RF & rf, PR & pr)
00820     {
00821         // do some normalisation
00822         for(int l=0; l < (int)rf.ext_param_.row_count_; ++l)
00823         {
00824             if(oobCount[l])
00825             {
00826                 oobError += double(oobErrorCount[l]) / oobCount[l];
00827                 ++totalOobCount;
00828             }
00829         } 
00830         oobError/=totalOobCount;
00831     }
00832     
00833 };
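
/* Usage sketch (assuming feature matrix f, label matrix l and a default
 * RandomForest<> rf): the averaged per-tree error is read from the public
 * member oobError after learning.
 *
 * \code
 * visitors::OOB_PerTreeError per_tree_v;
 * rf.learn(f, l, visitors::create_visitor(per_tree_v));
 * std::cout << "mean per-tree OOB error: " << per_tree_v.oobError << std::endl;
 * \endcode
 */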
00834 
00835 /** Visitor that calculates the OOB error of the ensemble.
00836  *  This rate should be used to estimate the cross-validation 
00837  *  error rate. Here each sample is put down only those trees for which the
00838  *  sample is OOB, i.e. if sample #1 is OOB for trees 1, 3 and 5, we calculate
00839  *  its output using the ensemble consisting only of trees 1, 3 and 5. 
00840  *
00841  *  Using normal bagged sampling, each sample is OOB for roughly a third of 
00842  *  the trees. The error rate obtained this way therefore corresponds to the 
00843  *  cross-validation error rate of an ensemble containing a third of the trees.
00844  *  A usage sketch follows the class definition below.
00845  */
00846 class OOB_Error : public VisitorBase
00847 {
00848     typedef MultiArrayShape<2>::type Shp;
00849     int class_count;
00850     bool is_weighted;
00851     MultiArray<2,double> tmp_prob;
00852     public:
00853 
00854     MultiArray<2, double>       prob_oob; 
00855     /** Ensemble oob error rate
00856      */
00857     double                      oob_breiman;
00858 
00859     MultiArray<2, double>       oobCount;
00860     ArrayVector< int>           indices; 
00861     OOB_Error() : VisitorBase(), oob_breiman(0.0) {}
00862 
00863     void save(std::string filen, std::string pathn)
00864     {
00865         if(*(pathn.end()-1) != '/')
00866             pathn += "/";
00867         const char* filename = filen.c_str();
00868         MultiArray<2, double> temp(Shp(1,1), 0.0); 
00869         temp[0] = oob_breiman;
00870         writeHDF5(filename, (pathn + "breiman_error").c_str(), temp);
00871     }
00872     // negative value if sample was ib (in bag); the number indicates how often.
00873     //  value >= 0 if sample was oob; 0 means fail, 1 means correct
00874 
00875     template<class RF, class PR>
00876     void visit_at_beginning(RF & rf, PR & pr)
00877     {
00878         class_count = rf.class_count();
00879         tmp_prob.reshape(Shp(1, class_count), 0); 
00880         prob_oob.reshape(Shp(rf.ext_param().row_count_,class_count), 0);
00881         is_weighted = rf.options().predict_weighted_;
00882         indices.resize(rf.ext_param().row_count_);
00883         if(int(oobCount.size()) != rf.ext_param_.row_count_)
00884         {
00885             oobCount.reshape(Shp(rf.ext_param_.row_count_, 1), 0);
00886         }
00887         for(int ii = 0; ii < rf.ext_param().row_count_; ++ii)
00888         {
00889             indices[ii] = ii;
00890         }
00891     }
00892 
00893     template<class RF, class PR, class SM, class ST>
00894     void visit_after_tree(RF& rf, PR & pr,  SM & sm, ST & st, int index)
00895     {
00896         // go through the samples
00897         int total_oob =0;
00898         int wrong_oob =0;
00899         // FIXME: magic number 10000: invoke special treatment when msample << sample_count
00900         //                            (i.e. the OOB sample is very large)
00901         //                     40000: use at most 40000 OOB samples per class for OOB error estimate 
00902         if(rf.ext_param_.actual_msample_ < pr.features().shape(0) - 10000)
00903         {
00904             ArrayVector<int> oob_indices;
00905             ArrayVector<int> cts(class_count, 0);
00906             std::random_shuffle(indices.begin(), indices.end());
00907             for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
00908             {
00909                 if(!sm.is_used()[indices[ii]] && cts[pr.response()(indices[ii], 0)] < 40000)
00910                 {
00911                     oob_indices.push_back(indices[ii]);
00912                     ++cts[pr.response()(indices[ii], 0)];
00913                 }
00914             }
00915             for(unsigned int ll = 0; ll < oob_indices.size(); ++ll)
00916             {
00917                 // update number of trees in which current sample is oob
00918                 ++oobCount[oob_indices[ll]];
00919 
00920                 // update number of oob samples in this tree.
00921                 ++total_oob; 
00922                 // get the predicted votes ---> tmp_prob;
00923                 int pos =  rf.tree(index).getToLeaf(rowVector(pr.features(),oob_indices[ll]));
00924                 Node<e_ConstProbNode> node ( rf.tree(index).topology_, 
00925                                                     rf.tree(index).parameters_,
00926                                                     pos);
00927                 tmp_prob.init(0); 
00928                 for(int ii = 0; ii < class_count; ++ii)
00929                 {
00930                     tmp_prob[ii] = node.prob_begin()[ii];
00931                 }
00932                 if(is_weighted)
00933                 {
00934                     for(int ii = 0; ii < class_count; ++ii)
00935                         tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
00936                 }
00937                 rowVector(prob_oob, oob_indices[ll]) += tmp_prob;
00938                 int label = argMax(tmp_prob); 
00939                 
00940             }
00941         }else
00942         {
00943             for(int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
00944             {
00945                 // if the lth sample is oob...
00946                 if(!sm.is_used()[ll])
00947                 {
00948                     // update number of trees in which current sample is oob
00949                     ++oobCount[ll];
00950 
00951                     // update number of oob samples in this tree.
00952                     ++total_oob; 
00953                     // get the predicted votes ---> tmp_prob;
00954                     int pos =  rf.tree(index).getToLeaf(rowVector(pr.features(),ll));
00955                     Node<e_ConstProbNode> node ( rf.tree(index).topology_, 
00956                                                         rf.tree(index).parameters_,
00957                                                         pos);
00958                     tmp_prob.init(0); 
00959                     for(int ii = 0; ii < class_count; ++ii)
00960                     {
00961                         tmp_prob[ii] = node.prob_begin()[ii];
00962                     }
00963                     if(is_weighted)
00964                     {
00965                         for(int ii = 0; ii < class_count; ++ii)
00966                             tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
00967                     }
00968                     rowVector(prob_oob, ll) += tmp_prob;
00969                     int label = argMax(tmp_prob); 
00970                     
00971                 }
00972             }
00973         }
00974         // go through the ib samples; 
00975     }
00976 
00977     /** Calculate the final OOB error estimate after the number of trees is known.
00978      */
00979     template<class RF, class PR>
00980     void visit_at_end(RF & rf, PR & pr)
00981     {
00982         // Breiman-style ensemble OOB error
00983         int totalOobCount =0;
00984         int breimanstyle = 0;
00985         for(int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
00986         {
00987             if(oobCount[ll])
00988             {
00989                 if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
00990                    ++breimanstyle;
00991                 ++totalOobCount;
00992             }
00993         }
00994         oob_breiman = double(breimanstyle)/totalOobCount; 
00995     }
00996 };
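
/* Usage sketch (assuming feature matrix f, label matrix l and a default
 * RandomForest<> rf): the ensemble OOB estimate is available in the public
 * member oob_breiman after learning.
 *
 * \code
 * visitors::OOB_Error oob_v;
 * rf.learn(f, l, visitors::create_visitor(oob_v));
 * std::cout << "ensemble OOB error (Breiman): " << oob_v.oob_breiman << std::endl;
 * \endcode
 */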
00997 
00998 
00999 /** Visitor that calculates several different OOB error statistics.
01000  *  A usage sketch follows the class definition below. */
01001 class CompleteOOBInfo : public VisitorBase
01002 {
01003     typedef MultiArrayShape<2>::type Shp;
01004     int class_count;
01005     bool is_weighted;
01006     MultiArray<2,double> tmp_prob;
01007     public:
01008 
01009     /** OOB Error rate of each individual tree
01010      */
01011     MultiArray<2, double>       oob_per_tree;
01012     /** Mean of oob_per_tree
01013      */
01014     double                      oob_mean;
01015     /**Standard deviation of oob_per_tree
01016      */
01017     double                      oob_std;
01018     
01019     MultiArray<2, double>       prob_oob; 
01020     /** Ensemble OOB error
01021      *
01022      * \sa OOB_Error
01023      */
01024     double                      oob_breiman;
01025 
01026     MultiArray<2, double>       oobCount;
01027     MultiArray<2, double>       oobErrorCount;
01028     /** Per Tree OOB error calculated as in OOB_PerTreeError
01029      * (Ulli's version)
01030      */
01031     double                      oob_per_tree2;
01032 
01033     /**Column containing the development of the Ensemble
01034      * error rate with increasing number of trees
01035      */
01036     MultiArray<2, double>       breiman_per_tree;
01037     /** 4 dimensional array containing the development of confusion matrices 
01038      * with number of trees - can be used to estimate ROC curves etc.
01039      *
01040      * oobroc_per_tree(ii, jj, kk, ll) 
01041      * corresponds to true label = ii, 
01042      * predicted label = jj,
01043      * confusion matrix after ll trees.
01044      *
01045      * Explanation of the third index:
01046      *
01047      * Two-class case:
01048      * kk = 0 ... (treeCount-1);
01049      *         the threshold on the probability for class 0 is kk/(treeCount-1).
01050      * More classes:
01051      * kk = 0; the threshold on the probability is set by argMax of the probability array.
01052      */
01053     MultiArray<4, double>       oobroc_per_tree;
01054     
01055     CompleteOOBInfo() : VisitorBase(), oob_mean(0), oob_std(0), oob_per_tree2(0)  {}
01056 
01057     /** save to HDF5 file
01058      */
01059     void save(std::string filen, std::string pathn)
01060     {
01061         if(*(pathn.end()-1) != '/')
01062             pathn += "/";
01063         const char* filename = filen.c_str();
01064         MultiArray<2, double> temp(Shp(1,1), 0.0); 
01065         writeHDF5(filename, (pathn + "oob_per_tree").c_str(), oob_per_tree);
01066         writeHDF5(filename, (pathn + "oobroc_per_tree").c_str(), oobroc_per_tree);
01067         writeHDF5(filename, (pathn + "breiman_per_tree").c_str(), breiman_per_tree);
01068         temp[0] = oob_mean;
01069         writeHDF5(filename, (pathn + "per_tree_error").c_str(), temp);
01070         temp[0] = oob_std;
01071         writeHDF5(filename, (pathn + "per_tree_error_std").c_str(), temp);
01072         temp[0] = oob_breiman;
01073         writeHDF5(filename, (pathn + "breiman_error").c_str(), temp);
01074         temp[0] = oob_per_tree2;
01075         writeHDF5(filename, (pathn + "ulli_error").c_str(), temp);
01076     }
01077     // negative value if sample was ib (in bag); the number indicates how often.
01078     //  value >= 0 if sample was oob; 0 means fail, 1 means correct
01079 
01080     template<class RF, class PR>
01081     void visit_at_beginning(RF & rf, PR & pr)
01082     {
01083         class_count = rf.class_count();
01084         if(class_count == 2)
01085             oobroc_per_tree.reshape(MultiArrayShape<4>::type(2,2,rf.tree_count(), rf.tree_count()));
01086         else
01087             oobroc_per_tree.reshape(MultiArrayShape<4>::type(rf.class_count(),rf.class_count(),1, rf.tree_count()));
01088         tmp_prob.reshape(Shp(1, class_count), 0); 
01089         prob_oob.reshape(Shp(rf.ext_param().row_count_,class_count), 0);
01090         is_weighted = rf.options().predict_weighted_;
01091         oob_per_tree.reshape(Shp(1, rf.tree_count()), 0);
01092         breiman_per_tree.reshape(Shp(1, rf.tree_count()), 0);
01093         //do the first time called.
01094         if(int(oobCount.size()) != rf.ext_param_.row_count_)
01095         {
01096             oobCount.reshape(Shp(rf.ext_param_.row_count_, 1), 0);
01097             oobErrorCount.reshape(Shp(rf.ext_param_.row_count_,1), 0);
01098         }
01099     }
01100 
01101     template<class RF, class PR, class SM, class ST>
01102     void visit_after_tree(RF& rf, PR & pr,  SM & sm, ST & st, int index)
01103     {
01104         // go through the samples
01105         int total_oob =0;
01106         int wrong_oob =0;
01107         for(int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
01108         {
01109             // if the lth sample is oob...
01110             if(!sm.is_used()[ll])
01111             {
01112                 // update number of trees in which current sample is oob
01113                 ++oobCount[ll];
01114 
01115                 // update number of oob samples in this tree.
01116                 ++total_oob; 
01117                 // get the predicted votes ---> tmp_prob;
01118                 int pos =  rf.tree(index).getToLeaf(rowVector(pr.features(),ll));
01119                 Node<e_ConstProbNode> node ( rf.tree(index).topology_, 
01120                                                     rf.tree(index).parameters_,
01121                                                     pos);
01122                 tmp_prob.init(0); 
01123                 for(int ii = 0; ii < class_count; ++ii)
01124                 {
01125                     tmp_prob[ii] = node.prob_begin()[ii];
01126                 }
01127                 if(is_weighted)
01128                 {
01129                     for(int ii = 0; ii < class_count; ++ii)
01130                         tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
01131                 }
01132                 rowVector(prob_oob, ll) += tmp_prob;
01133                 int label = argMax(tmp_prob); 
01134                 
01135                 if(label != pr.response()(ll, 0))
01136                 {
01137                     // update number of wrong oob samples in this tree.
01138                     ++wrong_oob;
01139                     // update number of trees in which current sample is wrong oob
01140                     ++oobErrorCount[ll];
01141                 }
01142             }
01143         }
01144         int breimanstyle = 0;
01145         int totalOobCount = 0;
01146         for(int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
01147         {
01148             if(oobCount[ll])
01149             {
01150                 if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
01151                    ++breimanstyle;
01152                 ++totalOobCount;
01153                 if(oobroc_per_tree.shape(2) == 1)
01154                 {
01155                     oobroc_per_tree(pr.response()(ll,0), argMax(rowVector(prob_oob, ll)),0 ,index)++;
01156                 }
01157             }
01158         }
01159         if(oobroc_per_tree.shape(2) == 1)
01160             oobroc_per_tree.bindOuter(index)/=totalOobCount;
01161         if(oobroc_per_tree.shape(2) > 1)
01162         {
01163             MultiArrayView<3, double> current_roc 
01164                     = oobroc_per_tree.bindOuter(index);
01165             for(int gg = 0; gg < current_roc.shape(2); ++gg)
01166             {
01167                 for(int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
01168                 {
01169                     if(oobCount[ll])
01170                     {
01171                         int pred = prob_oob(ll, 1) > (double(gg)/double(current_roc.shape(2)))?
01172                                         1 : 0; 
01173                         current_roc(pr.response()(ll, 0), pred, gg)+= 1; 
01174                     }
01175                 }
01176                 current_roc.bindOuter(gg)/= totalOobCount;
01177             }
01178         }
01179         breiman_per_tree[index] = double(breimanstyle)/double(totalOobCount);
01180         oob_per_tree[index] = double(wrong_oob)/double(total_oob);
01181         // go through the ib samples; 
01182     }
01183 
01184     /** Calculate the final OOB statistics after the number of trees is known.
01185      */
01186     template<class RF, class PR>
01187     void visit_at_end(RF & rf, PR & pr)
01188     {
01189         // Ulli's original metric and Breiman-style ensemble error
01190         oob_per_tree2 = 0; 
01191         int totalOobCount =0;
01192         int breimanstyle = 0;
01193         for(int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
01194         {
01195             if(oobCount[ll])
01196             {
01197                 if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
01198                    ++breimanstyle;
01199                 oob_per_tree2 += double(oobErrorCount[ll]) / oobCount[ll];
01200                 ++totalOobCount;
01201             }
01202         }
01203         oob_per_tree2 /= totalOobCount; 
01204         oob_breiman = double(breimanstyle)/totalOobCount; 
01205         // mean error of each tree
01206         MultiArrayView<2, double> mean(Shp(1,1), &oob_mean);
01207         MultiArrayView<2, double> stdDev(Shp(1,1), &oob_std);
01208         rowStatistics(oob_per_tree, mean, stdDev);
01209     }
01210 };
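
/* Usage sketch (assuming feature matrix f, label matrix l and a default
 * RandomForest<> rf): CompleteOOBInfo collects several statistics at once;
 * after learning they can be read directly or, if HDF5 support is compiled in,
 * written to a file via save(). The file name "rf_results.h5" is illustrative.
 *
 * \code
 * visitors::CompleteOOBInfo oob_info_v;
 * rf.learn(f, l, visitors::create_visitor(oob_info_v));
 * std::cout << "per-tree OOB error (mean/std): " << oob_info_v.oob_mean << " / "
 *           << oob_info_v.oob_std << "\n"
 *           << "ensemble OOB error (Breiman):  " << oob_info_v.oob_breiman << std::endl;
 * oob_info_v.save("rf_results.h5", "oob"); // writes oob/oob_per_tree, oob/breiman_error, ...
 * \endcode
 */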
01211 
01212 /** calculate variable importance while learning.
01213  */
01214 class VariableImportanceVisitor : public VisitorBase
01215 {
01216     public:
01217 
01218     /** This array has the same entries as the variable importance of the
01219      *  R randomForest package.
01220      *  The matrix is featureCount by (classCount + 2).
01221      *  variable_importance_(ii,jj) is the variable importance measure of 
01222      *  the ii-th variable according to:
01223      *  jj = 0 - (classCount-1)
01224      *      classwise permutation importance 
01225      *  jj = columnCount(variable_importance_) - 2
01226      *      permutation importance
01227      *  jj = columnCount(variable_importance_) - 1
01228      *      gini decrease importance.
01229      *
01230      *  permutation importance:
01231      *  The difference between the fraction of OOB samples classified correctly
01232      *  before and after permuting (randomizing) the ii-th column is calculated.
01233      *  The ii-th column is permuted rep_cnt times.
01234      *
01235      *  classwise permutation importance:
01236      *  same as permutation importance, but we only look at those OOB samples whose 
01237      *  response corresponds to class jj.
01238      *
01239      *  gini decrease importance:
01240      *  row ii corresponds to the sum of all gini decreases induced by variable ii 
01241      *  in each node of the random forest.
01242      */
01243     MultiArray<2, double>       variable_importance_;
01244     int                         repetition_count_;
01245     bool                        in_place_;
01246 
01247 #ifdef HasHDF5
01248     void save(std::string filename, std::string prefix)
01249     {
01250         prefix = "variable_importance_" + prefix;
01251         writeHDF5(filename.c_str(), 
01252                         prefix.c_str(), 
01253                         variable_importance_);
01254     }
01255 #endif
01256     /** Constructor
01257      * \param rep_cnt (default: 10) how often the 
01258      * permutation should take place. Set to 1 to make the calculation faster (but
01259      * possibly less stable)
01260      */
01261     VariableImportanceVisitor(int rep_cnt = 10) 
01262 :   repetition_count_(rep_cnt), in_place_(false)
01263 
01264     {}
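
    /* Usage sketch (assuming feature matrix f, label matrix l and a default
     * RandomForest<> rf; class_count denotes the number of label classes):
     *
     * \code
     * visitors::VariableImportanceVisitor varimp_v;       // 10 permutation repetitions
     * rf.learn(f, l, visitors::create_visitor(varimp_v));
     * // permutation importance of feature ii:   varimp_v.variable_importance_(ii, class_count)
     * // gini decrease importance of feature ii: varimp_v.variable_importance_(ii, class_count + 1)
     * \endcode
     */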
01265 
01266     /** calculates impurity decrease based variable importance after every
01267      * split.  
01268      */
01269     template<class Tree, class Split, class Region, class Feature_t, class Label_t>
01270     void visit_after_split( Tree          & tree, 
01271                             Split         & split,
01272                             Region        & parent,
01273                             Region        & leftChild,
01274                             Region        & rightChild,
01275                             Feature_t     & features,
01276                             Label_t       & labels)
01277     {
01278         //resize to right size when called the first time
01279         
01280         Int32 const  class_count = tree.ext_param_.class_count_;
01281         Int32 const  column_count = tree.ext_param_.column_count_;
01282         if(variable_importance_.size() == 0)
01283         {
01284             
01285             variable_importance_
01286                 .reshape(MultiArrayShape<2>::type(column_count, 
01287                                                  class_count+2));
01288         }
01289 
01290         if(split.createNode().typeID() == i_ThresholdNode)
01291         {
01292             Node<i_ThresholdNode> node(split.createNode());
01293             variable_importance_(node.column(),class_count+1) 
01294                 += split.region_gini_ - split.minGini();
01295         }
01296     }
01297 
01298     /** compute permutation based variable importance. 
01299      * (Only an array of size oob_sample_count x 1 is created -
01300      *  as opposed to oob_sample_count x feature_count in the other method.)
01301      * 
01302      * \sa FieldProxy
01303      */
01304     template<class RF, class PR, class SM, class ST>
01305     void after_tree_ip_impl(RF& rf, PR & pr,  SM & sm, ST & st, int index)
01306     {
01307         typedef MultiArrayShape<2>::type Shp_t;
01308         Int32                   column_count = rf.ext_param_.column_count_;
01309         Int32                   class_count  = rf.ext_param_.class_count_;  
01310         
01311         /* This solution saves memory but is not multithreading
01312          * compatible
01313          */
01314         // remove the const cast on the features (yep , I know what I am 
01315         // doing here.) data is not destroyed.
01316         //typename PR::Feature_t & features 
01317         //    = const_cast<typename PR::Feature_t &>(pr.features());
01318 
01319         typename PR::FeatureWithMemory_t features = pr.features();
01320 
01321         //find the oob indices of current tree. 
01322         ArrayVector<Int32>      oob_indices;
01323         ArrayVector<Int32>::iterator
01324                                 iter;
01325         for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
01326             if(!sm.is_used()[ii])
01327                 oob_indices.push_back(ii);
01328 
01329         //create space to back up a column      
01330         std::vector<double>     backup_column;
01331 
01332         // random number generator setup
01333 #ifdef CLASSIFIER_TEST
01334         RandomMT19937           random(1);
01335 #else 
01336         RandomMT19937           random(RandomSeed);
01337 #endif
01338         UniformIntRandomFunctor<RandomMT19937>  
01339                                 randint(random);
01340 
01341 
01342         //make some space for the results
01343         MultiArray<2, double>
01344                     oob_right(Shp_t(1, class_count + 1)); 
01345         MultiArray<2, double>
01346                     perm_oob_right (Shp_t(1, class_count + 1)); 
01347             
01348         
01349         // get the oob success rate with the original samples
01350         for(iter = oob_indices.begin(); 
01351             iter != oob_indices.end(); 
01352             ++iter)
01353         {
01354             if(rf.tree(index)
01355                     .predictLabel(rowVector(features, *iter)) 
01356                 ==  pr.response()(*iter, 0))
01357             {
01358                 //per class
01359                 ++oob_right[pr.response()(*iter,0)];
01360                 //total
01361                 ++oob_right[class_count];
01362             }
01363         }
01364         //for each column ii: get the oob rate after permuting that column
01365         for(int ii = 0; ii < column_count; ++ii)
01366         {
01367             perm_oob_right.init(0.0); 
01368             //make a backup of the original column
01369             backup_column.clear();
01370             for(iter = oob_indices.begin(); 
01371                 iter != oob_indices.end(); 
01372                 ++iter)
01373             {
01374                 backup_column.push_back(features(*iter,ii));
01375             }
01376             
01377             //get the oob rate after permuting the ii'th dimension.
01378             for(int rr = 0; rr < repetition_count_; ++rr)
01379             {               
01380                 //permute dimension. 
01381                 int n = oob_indices.size();
01382                 for(int jj = 1; jj < n; ++jj)
01383                     std::swap(features(oob_indices[jj], ii), 
01384                               features(oob_indices[randint(jj+1)], ii));
01385 
01386                 //get the oob success rate after permuting
01387                 for(iter = oob_indices.begin(); 
01388                     iter != oob_indices.end(); 
01389                     ++iter)
01390                 {
01391                     if(rf.tree(index)
01392                             .predictLabel(rowVector(features, *iter)) 
01393                         ==  pr.response()(*iter, 0))
01394                     {
01395                         //per class
01396                         ++perm_oob_right[pr.response()(*iter, 0)];
01397                         //total
01398                         ++perm_oob_right[class_count];
01399                     }
01400                 }
01401             }
01402             
01403             
01404             //normalise and add to the variable_importance array.
01405             perm_oob_right /= repetition_count_;
01406             perm_oob_right -= oob_right;
01407             perm_oob_right *= -1;
01408             perm_oob_right /= oob_indices.size();
01409             variable_importance_
01410                 .subarray(Shp_t(ii,0), 
01411                           Shp_t(ii+1,class_count+1)) += perm_oob_right;
01412             //copy back permuted dimension
01413             for(int jj = 0; jj < int(oob_indices.size()); ++jj)
01414                 features(oob_indices[jj], ii) = backup_column[jj];
01415         }
01416     }
01417 
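    /* Editorial note (not part of the original header): a compact summary of the
     * quantity accumulated by after_tree_ip_impl() above, in the notation of this
     * file. For one tree and feature ii, with R = repetition_count_ and
     * n_oob = oob_indices.size(),
     *
     *     contribution(ii) = ( oob_right - (1/R) * sum_r perm_oob_right_r ) / n_oob
     *
     * i.e. the drop in out-of-bag accuracy caused by permuting column ii, averaged
     * over R permutations. visit_at_end() later divides the accumulated sum by the
     * number of trees.
     */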
01418     /** Calculate permutation-based variable importance after every tree has
01419      *  been learned. The default behaviour is that this happens out of place.
01420      *  If you have very big data sets and want to avoid copying the data,
01421      *  set the in_place_ flag to true.
01422      */
01423     template<class RF, class PR, class SM, class ST>
01424     void visit_after_tree(RF& rf, PR & pr,  SM & sm, ST & st, int index)
01425     {
01426         after_tree_ip_impl(rf, pr, sm, st, index);
01427     }
01428 
01429     /** Normalise variable importance after the number of trees is known.
01430      */
01431     template<class RF, class PR>
01432     void visit_at_end(RF & rf, PR & pr)
01433     {
01434         variable_importance_ /= rf.trees_.size();
01435     }
01436 };
01437 
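/* Usage sketch (editorial addition, not part of the original header): how the
 * visitor above is typically attached during learning. The arrays "features" and
 * "labels" and the option values are placeholders; check the RandomForest
 * documentation for the exact learn() overloads and label types.
 *
 *     vigra::RandomForest<> rf(vigra::RandomForestOptions().tree_count(100));
 *     vigra::rf::visitors::VariableImportanceVisitor var_imp(10); // 10 permutation repetitions
 *     rf.learn(features, labels, vigra::rf::visitors::create_visitor(var_imp));
 *
 *     // var_imp.variable_importance_ has shape (column_count, class_count + 2):
 *     //   columns 0 .. class_count-1 : per-class permutation importance
 *     //   column  class_count        : total permutation importance
 *     //   column  class_count + 1    : Gini (impurity decrease) importance
 */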
01438 /** Verbose output: prints progress to std::cout while the forest is learned.
01439  */
01440 class RandomForestProgressVisitor : public VisitorBase {
01441     public:
01442     RandomForestProgressVisitor() : VisitorBase() {}
01443 
01444     template<class RF, class PR, class SM, class ST>
01445     void visit_after_tree(RF& rf, PR & pr,  SM & sm, ST & st, int index){
01446         if(index != rf.options().tree_count_-1) {
01447             std::cout << "\r[" << std::setw(10) << (index+1)/static_cast<double>(rf.options().tree_count_)*100 << "%]"
01448                       << " (" << index+1 << " of " << rf.options().tree_count_ << ") done" << std::flush;
01449         }
01450         else {
01451             std::cout << "\r[" << std::setw(10) << 100.0 << "%]" << std::endl;
01452         }
01453     }
01454     
01455     template<class RF, class PR>
01456     void visit_at_end(RF const & rf, PR const & pr) {
01457         std::string a = TOCS;
01458         std::cout << "all " << rf.options().tree_count_ << " trees have been learned in " << a  << std::endl;
01459     }
01460     
01461     template<class RF, class PR>
01462     void visit_at_beginning(RF const & rf, PR const & pr) {
01463         TIC;
01464         std::cout << "growing random forest, which will have " << rf.options().tree_count_ << " trees" << std::endl;
01465     }
01466     
01467     private:
01468     USETICTOC;
01469 };
01470 
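/* Usage sketch (editorial addition): the progress visitor is passed to learn()
 * like any other visitor and can be combined with further visitors in a single
 * create_visitor() call; "rf", "features" and "labels" are placeholders as in
 * the sketch above.
 *
 *     vigra::rf::visitors::RandomForestProgressVisitor progress;
 *     rf.learn(features, labels, vigra::rf::visitors::create_visitor(progress));
 */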
01471 
01472 /** Computes a correlation/similarity matrix of the features while learning
01473  * the random forest.
01474  */
01475 class CorrelationVisitor : public VisitorBase
01476 {
01477     public:
01478     /** gini_missc(ii, jj) describes how well variable jj can describe a
01479      * partition created on variable ii (when variable ii was chosen).
01480      */
01481     MultiArray<2, double>   gini_missc;
01482     MultiArray<2, int>      tmp_labels;
01483     /** additional noise features. 
01484      */
01485     MultiArray<2, double>   noise;
01486     MultiArray<2, double>   noise_l;
01487     /** How well a noise column can describe a partition created on variable ii.
01488      */
01489     MultiArray<2, double>   corr_noise;
01490     MultiArray<2, double>   corr_l;
01491 
01492     /** Similarity Matrix
01493      *
01494      * (numberOfFeatures + 1) by (numberOfFeatures + 1) matrix:
01495      * gini_missc
01496      *  - row-normalized by the number of times the column was chosen,
01497      *  - with the mean of corr_noise subtracted,
01498      *  - and symmetrised.
01499      *
01500      */
01501     MultiArray<2, double>   similarity;
01502     /** Distance matrix: 1 - similarity
01503      */
01504     MultiArray<2, double>   distance;
01505     ArrayVector<int>        tmp_cc;
01506     
01507     /** How often variable ii was chosen.
01508      */
01509     ArrayVector<int>        numChoices;
01510     typedef BestGiniOfColumn<GiniCriterion> ColumnDecisionFunctor;
01511     BestGiniOfColumn<GiniCriterion>         bgfunc;
01512     void save(std::string file, std::string prefix)
01513     {
01514         /*
01515         std::string tmp;
01516 #define VAR_WRITE(NAME) \
01517         tmp = #NAME;\
01518         tmp += "_";\
01519         tmp += prefix;\
01520         vigra::writeToHDF5File(file.c_str(), tmp.c_str(), NAME);
01521         VAR_WRITE(gini_missc);
01522         VAR_WRITE(corr_noise);
01523         VAR_WRITE(distance);
01524         VAR_WRITE(similarity);
01525         vigra::writeToHDF5File(file.c_str(), "nChoices", MultiArrayView<2, int>(MultiArrayShape<2>::type(numChoices.size(),1), numChoices.data()));
01526 #undef VAR_WRITE
01527 */
01528     }
01529     template<class RF, class PR>
01530     void visit_at_beginning(RF const & rf, PR  & pr)
01531     {
01532         typedef MultiArrayShape<2>::type Shp;
01533         int n = rf.ext_param_.column_count_;
01534         gini_missc.reshape(Shp(n+1, n+1));
01535         corr_noise.reshape(Shp(n+1, 10));
01536         corr_l.reshape(Shp(n+1, 10));
01537 
01538         noise.reshape(Shp(pr.features().shape(0), 10));
01539         noise_l.reshape(Shp(pr.features().shape(0), 10));
01540         RandomMT19937 random(RandomSeed);
01541         for(int ii = 0; ii < noise.size(); ++ii)
01542         {
01543             noise[ii]   = random.uniform53();
01544             noise_l[ii] = random.uniform53()  > 0.5;
01545         }
01546         bgfunc = ColumnDecisionFunctor( rf.ext_param_);
01547         tmp_labels.reshape(pr.response().shape()); 
01548         tmp_cc.resize(2);
01549         numChoices.resize(n+1);
01550         // look at all axes
01551     }
01552     template<class RF, class PR>
01553     void visit_at_end(RF const & rf, PR const & pr)
01554     {
01555         typedef MultiArrayShape<2>::type Shp;
01556         similarity.reshape(gini_missc.shape());
01557         similarity = gini_missc;
01558         MultiArray<2, double> mean_noise(Shp(corr_noise.shape(0), 1));
01559         rowStatistics(corr_noise, mean_noise);
01560         mean_noise /= MultiArrayView<2, int>(mean_noise.shape(), numChoices.data());
01561         int rC = similarity.shape(0);
01562         for(int jj = 0; jj < rC-1; ++jj)
01563         {
01564             rowVector(similarity, jj) /= numChoices[jj];
01565             rowVector(similarity, jj) -= mean_noise(jj, 0);
01566         }
01567         for(int jj = 0; jj < rC; ++jj)
01568         {
01569             similarity(rC -1, jj) /= numChoices[jj];
01570         }
01571         rowVector(similarity, rC -  1) -= mean_noise(rC-1, 0);
01572         similarity = abs(similarity);
01573         FindMinMax<double> minmax;
01574         inspectMultiArray(srcMultiArrayRange(similarity), minmax);
01575         
01576         for(int jj = 0; jj < rC; ++jj)
01577             similarity(jj, jj) = minmax.max;
01578         
01579         similarity.subarray(Shp(0,0), Shp(rC-1, rC-1)) 
01580             += similarity.subarray(Shp(0,0), Shp(rC-1, rC-1)).transpose();
01581         similarity.subarray(Shp(0,0), Shp(rC-1, rC-1)) /= 2;
01582         columnVector(similarity, rC-1) = rowVector(similarity, rC-1).transpose();
01583         for(int jj = 0; jj < rC; ++jj)
01584             similarity(jj, jj) = 0;
01585         
01586         FindMinMax<double> minmax2;
01587         inspectMultiArray(srcMultiArrayRange(similarity), minmax2);
01588         for(int jj = 0; jj < rC; ++jj)
01589             similarity(jj, jj) = minmax2.max;
01590         distance.reshape(gini_missc.shape(), minmax2.max);
01591         distance -= similarity; 
01592     }
01593 
01594     template<class Tree, class Split, class Region, class Feature_t, class Label_t>
01595     void visit_after_split( Tree          & tree, 
01596                             Split         & split,
01597                             Region        & parent,
01598                             Region        & leftChild,
01599                             Region        & rightChild,
01600                             Feature_t     & features,
01601                             Label_t       & labels)
01602     {
01603         if(split.createNode().typeID() == i_ThresholdNode)
01604         {
01605             double wgini;
01606             tmp_cc.init(0); 
01607             for(int ii = 0; ii < parent.size(); ++ii)
01608             {
01609                 tmp_labels[parent[ii]] 
01610                     = (features(parent[ii], split.bestSplitColumn()) < split.bestSplitThreshold());
01611                 ++tmp_cc[tmp_labels[parent[ii]]];
01612             }
01613             double region_gini = bgfunc.loss_of_region(tmp_labels, 
01614                                                        parent.begin(),
01615                                                        parent.end(),
01616                                                        tmp_cc);
01617 
01618             int n = split.bestSplitColumn(); 
01619             ++numChoices[n];
01620             ++(*(numChoices.end()-1));
01621             //this functor does all the work
01622             for(int k = 0; k < features.shape(1); ++k)
01623             {
01624                 bgfunc(columnVector(features, k),
01625                        0,
01626                        tmp_labels, 
01627                        parent.begin(), parent.end(), 
01628                        tmp_cc);
01629                 wgini = (region_gini - bgfunc.min_gini_);
01630                 gini_missc(n, k) 
01631                     += wgini;
01632             }
01633             for(int k = 0; k < 10; ++k)
01634             {
01635                 bgfunc(columnVector(noise, k),
01636                        0,
01637                        tmp_labels, 
01638                        parent.begin(), parent.end(), 
01639                        tmp_cc);
01640                 wgini = (region_gini - bgfunc.min_gini_);
01641                 corr_noise(n, k) 
01642                     += wgini;
01643             }
01644             
01645             for(int k = 0; k < 10; ++k)
01646             {
01647                 bgfunc(columnVector(noise_l, k),
01648                        0,
01649                        tmp_labels, 
01650                        parent.begin(), parent.end(), 
01651                        tmp_cc);
01652                 wgini = (region_gini - bgfunc.min_gini_);
01653                 corr_l(n, k) 
01654                     += wgini;
01655             }
01656             bgfunc(labels,0,  tmp_labels, parent.begin(), parent.end(),tmp_cc);
01657             wgini = (region_gini - bgfunc.min_gini_);
01658             gini_missc(n, columnCount(gini_missc)-1) 
01659                 += wgini;
01660             
01661             region_gini = split.region_gini_;
01662 #if 1 
01663             Node<i_ThresholdNode> node(split.createNode());
01664             gini_missc(rowCount(gini_missc)-1, 
01665                        node.column()) 
01666                 += split.region_gini_ - split.minGini();
01667 #endif
01668             for(int k = 0; k < 10; ++k)
01669             {
01670                 split.bgfunc(columnVector(noise, k),
01671                              0,
01672                              labels, 
01673                              parent.begin(), parent.end(), 
01674                              parent.classCounts());
01675                 wgini = region_gini - split.bgfunc.min_gini_; // recompute for this noise column
01676                 corr_noise(rowCount(gini_missc)-1, k) 
01677                      += wgini;
01678             }
01679 #if 0
01680             for(int k = 0; k < tree.ext_param_.actual_mtry_; ++k)
01681             {
01682                 wgini = region_gini - split.min_gini_[k];
01683                 
01684                 gini_missc(rowCount(gini_missc)-1, 
01685                                       split.splitColumns[k]) 
01686                      += wgini;
01687             }
01688             
01689             for(int k=tree.ext_param_.actual_mtry_; k<features.shape(1); ++k)
01690             {
01691                 split.bgfunc(columnVector(features, split.splitColumns[k]),
01692                              labels, 
01693                              parent.begin(), parent.end(), 
01694                              parent.classCounts());
01695                 wgini = region_gini - split.bgfunc.min_gini_;
01696                 gini_missc(rowCount(gini_missc)-1, 
01697                                       split.splitColumns[k]) += wgini;
01698             }
01699 #endif
01700             // remember to partition the data according to the best split.
01701             gini_missc(rowCount(gini_missc)-1, 
01702                        columnCount(gini_missc)-1) 
01703                 += region_gini;
01704             SortSamplesByDimensions<Feature_t> 
01705                 sorter(features, split.bestSplitColumn(), split.bestSplitThreshold());
01706             std::partition(parent.begin(), parent.end(), sorter);
01707         }
01708     }
01709 };
01710 
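/* Usage sketch (editorial addition): collecting the feature similarity and
 * distance matrices while the forest is learned; names are placeholders as in
 * the earlier sketches.
 *
 *     vigra::rf::visitors::CorrelationVisitor corr;
 *     rf.learn(features, labels, vigra::rf::visitors::create_visitor(corr));
 *
 *     // corr.similarity and corr.distance are (n+1) x (n+1) matrices, where n
 *     // is the number of features (see the member documentation above).
 */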
01711 
01712 } // namespace visitors
01713 } // namespace rf
01714 } // namespace vigra
01715 
01716 //@}
01717 #endif // RF_VISITORS_HXX
