[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

vigra/random_forest_hdf5_impex.hxx
00001 /************************************************************************/
00002 /*                                                                      */
00003 /*       Copyright 2009 by Rahul Nair and  Ullrich Koethe               */
00004 /*                                                                      */
00005 /*    This file is part of the VIGRA computer vision library.           */
00006 /*    The VIGRA Website is                                              */
00007 /*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
00008 /*    Please direct questions, bug reports, and contributions to        */
00009 /*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
00010 /*        vigra@informatik.uni-hamburg.de                               */
00011 /*                                                                      */
00012 /*    Permission is hereby granted, free of charge, to any person       */
00013 /*    obtaining a copy of this software and associated documentation    */
00014 /*    files (the "Software"), to deal in the Software without           */
00015 /*    restriction, including without limitation the rights to use,      */
00016 /*    copy, modify, merge, publish, distribute, sublicense, and/or      */
00017 /*    sell copies of the Software, and to permit persons to whom the    */
00018 /*    Software is furnished to do so, subject to the following          */
00019 /*    conditions:                                                       */
00020 /*                                                                      */
00021 /*    The above copyright notice and this permission notice shall be    */
00022 /*    included in all copies or substantial portions of the             */
00023 /*    Software.                                                         */
00024 /*                                                                      */
00025 /*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
00026 /*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
00027 /*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
00028 /*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
00029 /*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
00030 /*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
00031 /*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
00032 /*    OTHER DEALINGS IN THE SOFTWARE.                                   */
00033 /*                                                                      */
00034 /************************************************************************/
00035 
00036 
00037 #ifndef VIGRA_RANDOM_FOREST_IMPEX_HDF5_HXX
00038 #define VIGRA_RANDOM_FOREST_IMPEX_HDF5_HXX
00039 
00040 #include "config.hxx"
00041 #include "random_forest.hxx"
00042 #include "hdf5impex.hxx"
00043 #include <cstdio>
00044 #include <string>
00045 
00046 #ifdef HasHDF5
00047 
00048 namespace vigra 
00049 {
00050 
00051 namespace detail
00052 {
00053 
00054 
00055 /** shallow search the hdf5 group for containing elements
00056  * returns negative value if unsuccessful
00057  * \param grp_id    hid_t containing path to group.
00058  * \param cont        reference to container that supports
00059  *                     insert(). valuetype of cont must be
00060  *                     std::string
00061  */
00062 template<class Container>
00063 bool find_groups_hdf5(hid_t grp_id, Container &cont)
00064 {
00065     
00066     //get group info
00067 #if (H5_VERS_MAJOR == 1 && H5_VERS_MINOR <= 6)
00068     hsize_t size;
00069     H5Gget_num_objs(grp_id, &size);
00070 #else
00071     hsize_t size;
00072     H5G_info_t ginfo;
00073     herr_t         status;    
00074     status = H5Gget_info (grp_id , &ginfo);
00075     if(status < 0)
00076         std::runtime_error("find_groups_hdf5():"
00077                            "problem while getting group info");
00078     size = ginfo.nlinks;
00079 #endif
00080     for(hsize_t ii = 0; ii < size; ++ii)
00081     {
00082 #if (H5_VERS_MAJOR == 1 && H5_VERS_MINOR <= 6)
00083         ssize_t buffer_size = 
00084                 H5Gget_objname_by_idx(grp_id, 
00085                                       ii, NULL, 0 ) + 1;
00086 #else
00087         std::ptrdiff_t buffer_size =
00088                 H5Lget_name_by_idx(grp_id, ".",
00089                                    H5_INDEX_NAME,
00090                                    H5_ITER_INC,
00091                                    ii, 0, 0, H5P_DEFAULT)+1;
00092 #endif
00093         ArrayVector<char> buffer(buffer_size);
00094 #if (H5_VERS_MAJOR == 1 && H5_VERS_MINOR <= 6)
00095         buffer_size = 
00096                 H5Gget_objname_by_idx(grp_id, 
00097                                       ii, buffer.data(), 
00098                                       (size_t)buffer_size );
00099 #else
00100         buffer_size =
00101                 H5Lget_name_by_idx(grp_id, ".",
00102                                    H5_INDEX_NAME,
00103                                    H5_ITER_INC,
00104                                    ii, buffer.data(),
00105                                    (size_t)buffer_size,
00106                                    H5P_DEFAULT);
00107 #endif
00108         cont.insert(cont.end(), std::string(buffer.data()));
00109     }
00110     return true;
00111 }
00112 
00113 
00114 /** shallow search the hdf5 group for containing elements
00115  * returns negative value if unsuccessful
00116  * \param filename name of hdf5 file
00117  * \param groupname path in hdf5 file
00118  * \param cont        reference to container that supports
00119  *                     insert(). valuetype of cont must be
00120  *                     std::string
00121  */
00122 template<class Container>
00123 bool find_groups_hdf5(std::string filename, 
00124                               std::string groupname, 
00125                               Container &cont)
00126 {
00127     //check if file exists
00128     FILE* pFile;
00129     pFile = std::fopen ( filename.c_str(), "r" );
00130     if ( pFile == NULL)
00131     {    
00132         return 0;
00133     }
00134     std::fclose(pFile);
00135     //open the file
00136     HDF5Handle file_id(H5Fopen(filename.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT),
00137                        &H5Fclose, "Unable to open HDF5 file");
00138     HDF5Handle grp_id;
00139     if(groupname == "")
00140     {
00141         grp_id = HDF5Handle(file_id, 0, "");
00142     }
00143     else
00144     {
00145         grp_id = HDF5Handle(H5Gopen(file_id, groupname.c_str(), H5P_DEFAULT),
00146                             &H5Gclose, "Unable to open group");
00147 
00148     }
00149     bool res =  find_groups_hdf5(grp_id, cont); 
00150     return res; 
00151 }
00152 
00153 VIGRA_EXPORT int get_number_of_digits(int in);
00154 
00155 VIGRA_EXPORT std::string make_padded_number(int number, int max_number);
00156 
00157 /** write a ArrayVector to a hdf5 dataset.
00158  */
00159 template<class U, class T>
00160 void write_array_2_hdf5(hid_t & id, 
00161                         ArrayVector<U> const & arr, 
00162                         std::string    const & name, 
00163                         T  type) 
00164 {
00165     hsize_t size = arr.size(); 
00166     vigra_postcondition(H5LTmake_dataset (id, 
00167                                           name.c_str(), 
00168                                           1, 
00169                                           &size, 
00170                                           type, 
00171                                           arr.begin()) 
00172                         >= 0,
00173                         "write_array_2_hdf5():"
00174                         "unable to write dataset");
00175 }
00176 
00177 
00178 template<class U, class T>
00179 void write_hdf5_2_array(hid_t & id, 
00180                         ArrayVector<U>       & arr, 
00181                         std::string    const & name, 
00182                         T  type) 
00183 {    
00184     // The last three values of get_dataset_info can be NULL
00185     // my EFFING FOOT! that is valid for HDF5 1.8 but not for
00186     // 1.6 - but documented the other way around AAARRHGHGHH
00187     hsize_t size; 
00188     H5T_class_t a; 
00189     size_t b;
00190     vigra_postcondition(H5LTget_dataset_info(id, 
00191                                              name.c_str(), 
00192                                              &size, 
00193                                              &a, 
00194                                              &b) >= 0,
00195                         "write_hdf5_2_array(): "
00196                         "Unable to locate dataset");
00197     arr.resize((typename ArrayVector<U>::size_type)size);
00198     vigra_postcondition(H5LTread_dataset (id, 
00199                                           name.c_str(),
00200                                           type, 
00201                                           arr.data()) >= 0,
00202                         "write_array_2_hdf5():"
00203                         "unable to read dataset");
00204 }
00205 
00206 /*
00207 inline void options_import_HDF5(hid_t & group_id,
00208                          RandomForestOptions & opt, 
00209                          std::string name)
00210 {
00211     ArrayVector<double> serialized_options;
00212     write_hdf5_2_array(group_id, serialized_options,
00213                           name, H5T_NATIVE_DOUBLE); 
00214     opt.unserialize(serialized_options.begin(),
00215                       serialized_options.end());
00216 }
00217 
00218 inline void options_export_HDF5(hid_t & group_id,
00219                          RandomForestOptions const & opt, 
00220                          std::string name)
00221 {
00222     ArrayVector<double> serialized_options(opt.serialized_size());
00223     opt.serialize(serialized_options.begin(),
00224                   serialized_options.end());
00225     write_array_2_hdf5(group_id, serialized_options,
00226                       name, H5T_NATIVE_DOUBLE); 
00227 }
00228 */
00229 
/** Tags for the element types distinguished when class labels are
 *  stored in or loaded from HDF5. The type_of() overloads and
 *  type_of_hid_t() return these values.
 */
struct MyT
{
    enum type
    {
        INT8   = 1,
        INT16  = 2,
        INT32  = 3,
        INT64  = 4,
        UINT8  = 5,
        UINT16 = 6,
        UINT32 = 7,
        UINT64 = 8,
        FLOAT  = 9,
        DOUBLE = 10,
        OTHER  = 3294
    };
};
00236 
00237 
00238 
00239 #define create_type_of(TYPE, ENUM) \
00240 inline MyT::type type_of(TYPE)\
00241 {\
00242     return MyT::ENUM; \
00243 }
00244 create_type_of(Int8, INT8)
00245 create_type_of(Int16, INT16)
00246 create_type_of(Int32, INT32)
00247 create_type_of(Int64, INT64)
00248 create_type_of(UInt8, UINT8)
00249 create_type_of(UInt16, UINT16)
00250 create_type_of(UInt32, UINT32)
00251 create_type_of(UInt64, UINT64)
00252 create_type_of(float, FLOAT)
00253 create_type_of(double, DOUBLE)
00254 #undef create_type_of
00255 
00256 VIGRA_EXPORT MyT::type type_of_hid_t(hid_t group_id, std::string name);
00257 
00258 VIGRA_EXPORT void options_import_HDF5(hid_t & group_id, 
00259                         RandomForestOptions  & opt, 
00260                         std::string name);
00261 
00262 VIGRA_EXPORT void options_export_HDF5(hid_t & group_id, 
00263                          RandomForestOptions const & opt, 
00264                          std::string name);
00265 
00266 template<class T>
00267 void problemspec_import_HDF5(hid_t & group_id, 
00268                              ProblemSpec<T>  & param, 
00269                              std::string name)
00270 {
00271     hid_t param_id = H5Gopen (group_id, 
00272                               name.c_str(), 
00273                               H5P_DEFAULT);
00274 
00275     vigra_postcondition(param_id >= 0, 
00276                         "problemspec_import_HDF5():"
00277                         " Unable to open external parameters");
00278 
00279     //get a map containing all the double fields
00280     std::set<std::string> ext_set;
00281     find_groups_hdf5(param_id, ext_set);
00282     std::map<std::string, ArrayVector <double> > ext_map;
00283     std::set<std::string>::iterator iter;
00284     if(ext_set.find(std::string("labels")) == ext_set.end())
00285         std::runtime_error("labels are missing");
00286     for(iter = ext_set.begin(); iter != ext_set.end(); ++ iter)
00287     {
00288         if(*iter != std::string("labels"))
00289         {
00290             ext_map[*iter] = ArrayVector<double>();
00291             write_hdf5_2_array(param_id, ext_map[*iter], 
00292                                *iter, H5T_NATIVE_DOUBLE);
00293         }
00294     }
00295     param.make_from_map(ext_map);
00296     //load_class_labels
00297     switch(type_of_hid_t(param_id,"labels" ))
00298     {
00299         #define SOME_CASE(type_, enum_) \
00300       case MyT::enum_ :\
00301         {\
00302             ArrayVector<type_> tmp;\
00303             write_hdf5_2_array(param_id, tmp, "labels", H5T_NATIVE_##enum_);\
00304             param.classes_(tmp.begin(), tmp.end());\
00305         }\
00306             break;
00307         SOME_CASE(UInt8,     UINT8);
00308         SOME_CASE(UInt16,     UINT16);
00309         SOME_CASE(UInt32,     UINT32);
00310         SOME_CASE(UInt64,     UINT64);
00311         SOME_CASE(Int8,      INT8);
00312         SOME_CASE(Int16,     INT16);
00313         SOME_CASE(Int32,     INT32);
00314         SOME_CASE(Int64,     INT64);
00315         SOME_CASE(double,     DOUBLE);
00316         SOME_CASE(float,     FLOAT);
00317         default:
00318             std::runtime_error("exportRF_HDF5(): unknown class type"); 
00319         #undef SOME_CASE
00320     }
00321     H5Gclose(param_id);
00322 }
00323 
00324 template<class T>
00325 void problemspec_export_HDF5(hid_t & group_id, 
00326                              ProblemSpec<T> const & param, 
00327                              std::string name)
00328 {
00329     hid_t param_id = H5Gcreate(group_id, name.c_str(), 
00330                                            H5P_DEFAULT, 
00331                                            H5P_DEFAULT, 
00332                                         H5P_DEFAULT);
00333     vigra_postcondition(param_id >= 0, 
00334                         "problemspec_export_HDF5():"
00335                         " Unable to create external parameters");
00336 
00337     //get a map containing all the double fields
00338     std::map<std::string, ArrayVector<double> > serialized_param;
00339     param.make_map(serialized_param);
00340     std::map<std::string, ArrayVector<double> >::iterator iter;
00341     for(iter = serialized_param.begin(); iter != serialized_param.end(); ++iter)
00342         write_array_2_hdf5(param_id, iter->second, iter->first, H5T_NATIVE_DOUBLE);
00343     
00344     //save class_labels
00345     switch(type_of(param.classes[0]))
00346     {
00347         #define SOME_CASE(type) \
00348         case MyT::type:\
00349             write_array_2_hdf5(param_id, param.classes, "labels", H5T_NATIVE_##type);\
00350             break;
00351         SOME_CASE(UINT8);
00352         SOME_CASE(UINT16);
00353         SOME_CASE(UINT32);
00354         SOME_CASE(UINT64);
00355         SOME_CASE(INT8);
00356         SOME_CASE(INT16);
00357         SOME_CASE(INT32);
00358         SOME_CASE(INT64);
00359         SOME_CASE(DOUBLE);
00360         SOME_CASE(FLOAT);
00361         default:
00362             std::runtime_error("exportRF_HDF5(): unknown class type"); 
00363         #undef SOME_CASE
00364     }
00365     H5Gclose(param_id);
00366 }
00367 
00368 VIGRA_EXPORT void dt_import_HDF5(hid_t & group_id,
00369                     detail::DecisionTree & tree,
00370                     std::string name);
00371 
00372 
00373 VIGRA_EXPORT void dt_export_HDF5(hid_t & group_id,
00374                     detail::DecisionTree const & tree,
00375                     std::string name);
00376                     
00377 } //namespace detail
00378 
00379 template<class T>
00380 bool rf_export_HDF5(RandomForest<T> const &rf, 
00381                     std::string filename, 
00382                     std::string pathname = "",
00383                     bool overwriteflag = false)
00384 { 
00385     using detail::make_padded_number;
00386     using detail::options_export_HDF5;
00387     using detail::problemspec_export_HDF5;
00388     using detail::dt_export_HDF5;
00389     
00390     hid_t file_id;
00391     //if file exists load it.
00392     FILE* pFile = std::fopen ( filename.c_str(), "r" );
00393     if ( pFile != NULL)
00394     {    
00395         std::fclose(pFile);
00396         file_id = H5Fopen(filename.c_str(), H5F_ACC_RDWR, 
00397                                                 H5P_DEFAULT);
00398     }
00399     else
00400     {
00401         //create a new file.
00402         file_id = H5Fcreate(filename.c_str(), H5F_ACC_TRUNC, 
00403                                                     H5P_DEFAULT, 
00404                                                     H5P_DEFAULT);
00405     }
00406     vigra_postcondition(file_id >= 0, 
00407                         "rf_export_HDF5(): Unable to open file.");
00408     //std::cerr << pathname.c_str()
00409 
00410     //if the group already exists this will cause an error
00411     //we will have to use the overwriteflag to check for 
00412     //this, but i dont know how to delete groups...
00413 
00414     hid_t group_id = pathname== "" ?
00415                         file_id
00416                     :    H5Gcreate(file_id, pathname.c_str(), 
00417                                               H5P_DEFAULT, 
00418                                               H5P_DEFAULT, 
00419                                            H5P_DEFAULT);
00420 
00421     vigra_postcondition(group_id >= 0, 
00422                         "rf_export_HDF5(): Unable to create group");
00423 
00424     //save serialized options
00425         options_export_HDF5(group_id, rf.options(), "_options"); 
00426     //save external parameters
00427         problemspec_export_HDF5(group_id, rf.ext_param(), "_ext_param");
00428     //save trees
00429     
00430     int tree_count = rf.options_.tree_count_;
00431     for(int ii = 0; ii < tree_count; ++ii)
00432     {
00433         std::string treename =     "Tree_"  + 
00434                                 make_padded_number(ii, tree_count -1);
00435         dt_export_HDF5(group_id, rf.tree(ii), treename); 
00436     }
00437     
00438     //clean up the mess
00439     if(pathname != "")
00440         H5Gclose(group_id);
00441     H5Fclose(file_id);
00442 
00443     return 1;
00444 }
00445 
00446 
00447 template<class T>
00448 bool rf_import_HDF5(RandomForest<T> &rf, 
00449                     std::string filename, 
00450                     std::string pathname = "")
00451 { 
00452     using detail::find_groups_hdf5;
00453     using detail::options_import_HDF5;
00454     using detail::problemspec_import_HDF5;
00455     using detail::dt_export_HDF5;
00456     // check if file exists
00457     FILE* pFile = std::fopen ( filename.c_str(), "r" );
00458     if ( pFile == NULL)
00459         return 0;
00460     std::fclose(pFile);
00461     //open file
00462     hid_t file_id = H5Fopen (filename.c_str(), 
00463                              H5F_ACC_RDONLY, 
00464                              H5P_DEFAULT);
00465     
00466     vigra_postcondition(file_id >= 0, 
00467                         "rf_import_HDF5(): Unable to open file.");
00468     hid_t group_id = pathname== "" ?
00469                         file_id
00470                     :    H5Gopen (file_id, 
00471                                  pathname.c_str(), 
00472                                  H5P_DEFAULT);
00473     
00474     vigra_postcondition(group_id >= 0, 
00475                         "rf_export_HDF5(): Unable to create group");
00476 
00477     //get serialized options
00478         options_import_HDF5(group_id, rf.options_, "_options"); 
00479     //save external parameters
00480         problemspec_import_HDF5(group_id, rf.ext_param_, "_ext_param");
00481     // TREE SAVING TIME
00482     // get all groups in base path
00483     
00484     std::set<std::string> tree_set;
00485     std::set<std::string>::iterator iter; 
00486     find_groups_hdf5(filename, pathname, tree_set);
00487     
00488     for(iter = tree_set.begin(); iter != tree_set.end(); ++iter)
00489     {
00490         if((*iter)[0] != '_')
00491         {
00492             rf.trees_.push_back(detail::DecisionTree(rf.ext_param_));
00493             dt_import_HDF5(group_id, rf.trees_.back(), *iter); 
00494         }
00495     }
00496     
00497     //clean up the mess
00498     if(pathname != "")
00499         H5Gclose(group_id);
00500     H5Fclose(file_id);
00501     /*rf.tree_indices_.resize(rf.tree_count());
00502     for(int ii = 0; ii < rf.tree_count(); ++ii)
00503         rf.tree_indices_[ii] = ii; */
00504     return 1;
00505 }
00506 } // namespace vigra
00507 
00508 #endif // HasHDF5
00509 
00510 #endif // VIGRA_RANDOM_FOREST_HDF5_IMPEX_HXX
00511 

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.7.1 (Mon Apr 16 2012)