[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]
00001 /************************************************************************/ 00002 /* */ 00003 /* Copyright 2009 by Rahul Nair and Ullrich Koethe */ 00004 /* */ 00005 /* This file is part of the VIGRA computer vision library. */ 00006 /* The VIGRA Website is */ 00007 /* http://hci.iwr.uni-heidelberg.de/vigra/ */ 00008 /* Please direct questions, bug reports, and contributions to */ 00009 /* ullrich.koethe@iwr.uni-heidelberg.de or */ 00010 /* vigra@informatik.uni-hamburg.de */ 00011 /* */ 00012 /* Permission is hereby granted, free of charge, to any person */ 00013 /* obtaining a copy of this software and associated documentation */ 00014 /* files (the "Software"), to deal in the Software without */ 00015 /* restriction, including without limitation the rights to use, */ 00016 /* copy, modify, merge, publish, distribute, sublicense, and/or */ 00017 /* sell copies of the Software, and to permit persons to whom the */ 00018 /* Software is furnished to do so, subject to the following */ 00019 /* conditions: */ 00020 /* */ 00021 /* The above copyright notice and this permission notice shall be */ 00022 /* included in all copies or substantial portions of the */ 00023 /* Software. */ 00024 /* */ 00025 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */ 00026 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */ 00027 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */ 00028 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */ 00029 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */ 00030 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */ 00031 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */ 00032 /* OTHER DEALINGS IN THE SOFTWARE. */ 00033 /* */ 00034 /************************************************************************/ 00035 00036 00037 #ifndef VIGRA_RANDOM_FOREST_IMPEX_HDF5_HXX 00038 #define VIGRA_RANDOM_FOREST_IMPEX_HDF5_HXX 00039 00040 #include "config.hxx" 00041 #include "random_forest.hxx" 00042 #include "hdf5impex.hxx" 00043 #include <cstdio> 00044 #include <string> 00045 00046 #ifdef HasHDF5 00047 00048 namespace vigra 00049 { 00050 00051 namespace detail 00052 { 00053 00054 00055 /** shallow search the hdf5 group for containing elements 00056 * returns negative value if unsuccessful 00057 * \param grp_id hid_t containing path to group. 00058 * \param cont reference to container that supports 00059 * insert(). valuetype of cont must be 00060 * std::string 00061 */ 00062 template<class Container> 00063 bool find_groups_hdf5(hid_t grp_id, Container &cont) 00064 { 00065 00066 //get group info 00067 #if (H5_VERS_MAJOR == 1 && H5_VERS_MINOR <= 6) 00068 hsize_t size; 00069 H5Gget_num_objs(grp_id, &size); 00070 #else 00071 hsize_t size; 00072 H5G_info_t ginfo; 00073 herr_t status; 00074 status = H5Gget_info (grp_id , &ginfo); 00075 if(status < 0) 00076 std::runtime_error("find_groups_hdf5():" 00077 "problem while getting group info"); 00078 size = ginfo.nlinks; 00079 #endif 00080 for(hsize_t ii = 0; ii < size; ++ii) 00081 { 00082 #if (H5_VERS_MAJOR == 1 && H5_VERS_MINOR <= 6) 00083 ssize_t buffer_size = 00084 H5Gget_objname_by_idx(grp_id, 00085 ii, NULL, 0 ) + 1; 00086 #else 00087 std::ptrdiff_t buffer_size = 00088 H5Lget_name_by_idx(grp_id, ".", 00089 H5_INDEX_NAME, 00090 H5_ITER_INC, 00091 ii, 0, 0, H5P_DEFAULT)+1; 00092 #endif 00093 ArrayVector<char> buffer(buffer_size); 00094 #if (H5_VERS_MAJOR == 1 && H5_VERS_MINOR <= 6) 00095 buffer_size = 00096 H5Gget_objname_by_idx(grp_id, 00097 ii, buffer.data(), 00098 (size_t)buffer_size ); 00099 #else 00100 buffer_size = 00101 H5Lget_name_by_idx(grp_id, ".", 00102 H5_INDEX_NAME, 00103 H5_ITER_INC, 00104 ii, buffer.data(), 00105 (size_t)buffer_size, 00106 H5P_DEFAULT); 00107 #endif 00108 cont.insert(cont.end(), std::string(buffer.data())); 00109 } 00110 return true; 00111 } 00112 00113 00114 /** shallow search the hdf5 group for containing elements 00115 * returns negative value if unsuccessful 00116 * \param filename name of hdf5 file 00117 * \param groupname path in hdf5 file 00118 * \param cont reference to container that supports 00119 * insert(). valuetype of cont must be 00120 * std::string 00121 */ 00122 template<class Container> 00123 bool find_groups_hdf5(std::string filename, 00124 std::string groupname, 00125 Container &cont) 00126 { 00127 //check if file exists 00128 FILE* pFile; 00129 pFile = std::fopen ( filename.c_str(), "r" ); 00130 if ( pFile == NULL) 00131 { 00132 return 0; 00133 } 00134 std::fclose(pFile); 00135 //open the file 00136 HDF5Handle file_id(H5Fopen(filename.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT), 00137 &H5Fclose, "Unable to open HDF5 file"); 00138 HDF5Handle grp_id; 00139 if(groupname == "") 00140 { 00141 grp_id = HDF5Handle(file_id, 0, ""); 00142 } 00143 else 00144 { 00145 grp_id = HDF5Handle(H5Gopen(file_id, groupname.c_str(), H5P_DEFAULT), 00146 &H5Gclose, "Unable to open group"); 00147 00148 } 00149 bool res = find_groups_hdf5(grp_id, cont); 00150 return res; 00151 } 00152 00153 VIGRA_EXPORT int get_number_of_digits(int in); 00154 00155 VIGRA_EXPORT std::string make_padded_number(int number, int max_number); 00156 00157 /** write a ArrayVector to a hdf5 dataset. 00158 */ 00159 template<class U, class T> 00160 void write_array_2_hdf5(hid_t & id, 00161 ArrayVector<U> const & arr, 00162 std::string const & name, 00163 T type) 00164 { 00165 hsize_t size = arr.size(); 00166 vigra_postcondition(H5LTmake_dataset (id, 00167 name.c_str(), 00168 1, 00169 &size, 00170 type, 00171 arr.begin()) 00172 >= 0, 00173 "write_array_2_hdf5():" 00174 "unable to write dataset"); 00175 } 00176 00177 00178 template<class U, class T> 00179 void write_hdf5_2_array(hid_t & id, 00180 ArrayVector<U> & arr, 00181 std::string const & name, 00182 T type) 00183 { 00184 // The last three values of get_dataset_info can be NULL 00185 // my EFFING FOOT! that is valid for HDF5 1.8 but not for 00186 // 1.6 - but documented the other way around AAARRHGHGHH 00187 hsize_t size; 00188 H5T_class_t a; 00189 size_t b; 00190 vigra_postcondition(H5LTget_dataset_info(id, 00191 name.c_str(), 00192 &size, 00193 &a, 00194 &b) >= 0, 00195 "write_hdf5_2_array(): " 00196 "Unable to locate dataset"); 00197 arr.resize((typename ArrayVector<U>::size_type)size); 00198 vigra_postcondition(H5LTread_dataset (id, 00199 name.c_str(), 00200 type, 00201 arr.data()) >= 0, 00202 "write_array_2_hdf5():" 00203 "unable to read dataset"); 00204 } 00205 00206 /* 00207 inline void options_import_HDF5(hid_t & group_id, 00208 RandomForestOptions & opt, 00209 std::string name) 00210 { 00211 ArrayVector<double> serialized_options; 00212 write_hdf5_2_array(group_id, serialized_options, 00213 name, H5T_NATIVE_DOUBLE); 00214 opt.unserialize(serialized_options.begin(), 00215 serialized_options.end()); 00216 } 00217 00218 inline void options_export_HDF5(hid_t & group_id, 00219 RandomForestOptions const & opt, 00220 std::string name) 00221 { 00222 ArrayVector<double> serialized_options(opt.serialized_size()); 00223 opt.serialize(serialized_options.begin(), 00224 serialized_options.end()); 00225 write_array_2_hdf5(group_id, serialized_options, 00226 name, H5T_NATIVE_DOUBLE); 00227 } 00228 */ 00229 00230 struct MyT 00231 { 00232 enum type { INT8 = 1, INT16 = 2, INT32 =3, INT64=4, 00233 UINT8 = 5, UINT16 = 6, UINT32= 7, UINT64= 8, 00234 FLOAT = 9, DOUBLE = 10, OTHER = 3294}; 00235 }; 00236 00237 00238 00239 #define create_type_of(TYPE, ENUM) \ 00240 inline MyT::type type_of(TYPE)\ 00241 {\ 00242 return MyT::ENUM; \ 00243 } 00244 create_type_of(Int8, INT8) 00245 create_type_of(Int16, INT16) 00246 create_type_of(Int32, INT32) 00247 create_type_of(Int64, INT64) 00248 create_type_of(UInt8, UINT8) 00249 create_type_of(UInt16, UINT16) 00250 create_type_of(UInt32, UINT32) 00251 create_type_of(UInt64, UINT64) 00252 create_type_of(float, FLOAT) 00253 create_type_of(double, DOUBLE) 00254 #undef create_type_of 00255 00256 VIGRA_EXPORT MyT::type type_of_hid_t(hid_t group_id, std::string name); 00257 00258 VIGRA_EXPORT void options_import_HDF5(hid_t & group_id, 00259 RandomForestOptions & opt, 00260 std::string name); 00261 00262 VIGRA_EXPORT void options_export_HDF5(hid_t & group_id, 00263 RandomForestOptions const & opt, 00264 std::string name); 00265 00266 template<class T> 00267 void problemspec_import_HDF5(hid_t & group_id, 00268 ProblemSpec<T> & param, 00269 std::string name) 00270 { 00271 hid_t param_id = H5Gopen (group_id, 00272 name.c_str(), 00273 H5P_DEFAULT); 00274 00275 vigra_postcondition(param_id >= 0, 00276 "problemspec_import_HDF5():" 00277 " Unable to open external parameters"); 00278 00279 //get a map containing all the double fields 00280 std::set<std::string> ext_set; 00281 find_groups_hdf5(param_id, ext_set); 00282 std::map<std::string, ArrayVector <double> > ext_map; 00283 std::set<std::string>::iterator iter; 00284 if(ext_set.find(std::string("labels")) == ext_set.end()) 00285 std::runtime_error("labels are missing"); 00286 for(iter = ext_set.begin(); iter != ext_set.end(); ++ iter) 00287 { 00288 if(*iter != std::string("labels")) 00289 { 00290 ext_map[*iter] = ArrayVector<double>(); 00291 write_hdf5_2_array(param_id, ext_map[*iter], 00292 *iter, H5T_NATIVE_DOUBLE); 00293 } 00294 } 00295 param.make_from_map(ext_map); 00296 //load_class_labels 00297 switch(type_of_hid_t(param_id,"labels" )) 00298 { 00299 #define SOME_CASE(type_, enum_) \ 00300 case MyT::enum_ :\ 00301 {\ 00302 ArrayVector<type_> tmp;\ 00303 write_hdf5_2_array(param_id, tmp, "labels", H5T_NATIVE_##enum_);\ 00304 param.classes_(tmp.begin(), tmp.end());\ 00305 }\ 00306 break; 00307 SOME_CASE(UInt8, UINT8); 00308 SOME_CASE(UInt16, UINT16); 00309 SOME_CASE(UInt32, UINT32); 00310 SOME_CASE(UInt64, UINT64); 00311 SOME_CASE(Int8, INT8); 00312 SOME_CASE(Int16, INT16); 00313 SOME_CASE(Int32, INT32); 00314 SOME_CASE(Int64, INT64); 00315 SOME_CASE(double, DOUBLE); 00316 SOME_CASE(float, FLOAT); 00317 default: 00318 std::runtime_error("exportRF_HDF5(): unknown class type"); 00319 #undef SOME_CASE 00320 } 00321 H5Gclose(param_id); 00322 } 00323 00324 template<class T> 00325 void problemspec_export_HDF5(hid_t & group_id, 00326 ProblemSpec<T> const & param, 00327 std::string name) 00328 { 00329 hid_t param_id = H5Gcreate(group_id, name.c_str(), 00330 H5P_DEFAULT, 00331 H5P_DEFAULT, 00332 H5P_DEFAULT); 00333 vigra_postcondition(param_id >= 0, 00334 "problemspec_export_HDF5():" 00335 " Unable to create external parameters"); 00336 00337 //get a map containing all the double fields 00338 std::map<std::string, ArrayVector<double> > serialized_param; 00339 param.make_map(serialized_param); 00340 std::map<std::string, ArrayVector<double> >::iterator iter; 00341 for(iter = serialized_param.begin(); iter != serialized_param.end(); ++iter) 00342 write_array_2_hdf5(param_id, iter->second, iter->first, H5T_NATIVE_DOUBLE); 00343 00344 //save class_labels 00345 switch(type_of(param.classes[0])) 00346 { 00347 #define SOME_CASE(type) \ 00348 case MyT::type:\ 00349 write_array_2_hdf5(param_id, param.classes, "labels", H5T_NATIVE_##type);\ 00350 break; 00351 SOME_CASE(UINT8); 00352 SOME_CASE(UINT16); 00353 SOME_CASE(UINT32); 00354 SOME_CASE(UINT64); 00355 SOME_CASE(INT8); 00356 SOME_CASE(INT16); 00357 SOME_CASE(INT32); 00358 SOME_CASE(INT64); 00359 SOME_CASE(DOUBLE); 00360 SOME_CASE(FLOAT); 00361 default: 00362 std::runtime_error("exportRF_HDF5(): unknown class type"); 00363 #undef SOME_CASE 00364 } 00365 H5Gclose(param_id); 00366 } 00367 00368 VIGRA_EXPORT void dt_import_HDF5(hid_t & group_id, 00369 detail::DecisionTree & tree, 00370 std::string name); 00371 00372 00373 VIGRA_EXPORT void dt_export_HDF5(hid_t & group_id, 00374 detail::DecisionTree const & tree, 00375 std::string name); 00376 00377 } //namespace detail 00378 00379 template<class T> 00380 bool rf_export_HDF5(RandomForest<T> const &rf, 00381 std::string filename, 00382 std::string pathname = "", 00383 bool overwriteflag = false) 00384 { 00385 using detail::make_padded_number; 00386 using detail::options_export_HDF5; 00387 using detail::problemspec_export_HDF5; 00388 using detail::dt_export_HDF5; 00389 00390 hid_t file_id; 00391 //if file exists load it. 00392 FILE* pFile = std::fopen ( filename.c_str(), "r" ); 00393 if ( pFile != NULL) 00394 { 00395 std::fclose(pFile); 00396 file_id = H5Fopen(filename.c_str(), H5F_ACC_RDWR, 00397 H5P_DEFAULT); 00398 } 00399 else 00400 { 00401 //create a new file. 00402 file_id = H5Fcreate(filename.c_str(), H5F_ACC_TRUNC, 00403 H5P_DEFAULT, 00404 H5P_DEFAULT); 00405 } 00406 vigra_postcondition(file_id >= 0, 00407 "rf_export_HDF5(): Unable to open file."); 00408 //std::cerr << pathname.c_str() 00409 00410 //if the group already exists this will cause an error 00411 //we will have to use the overwriteflag to check for 00412 //this, but i dont know how to delete groups... 00413 00414 hid_t group_id = pathname== "" ? 00415 file_id 00416 : H5Gcreate(file_id, pathname.c_str(), 00417 H5P_DEFAULT, 00418 H5P_DEFAULT, 00419 H5P_DEFAULT); 00420 00421 vigra_postcondition(group_id >= 0, 00422 "rf_export_HDF5(): Unable to create group"); 00423 00424 //save serialized options 00425 options_export_HDF5(group_id, rf.options(), "_options"); 00426 //save external parameters 00427 problemspec_export_HDF5(group_id, rf.ext_param(), "_ext_param"); 00428 //save trees 00429 00430 int tree_count = rf.options_.tree_count_; 00431 for(int ii = 0; ii < tree_count; ++ii) 00432 { 00433 std::string treename = "Tree_" + 00434 make_padded_number(ii, tree_count -1); 00435 dt_export_HDF5(group_id, rf.tree(ii), treename); 00436 } 00437 00438 //clean up the mess 00439 if(pathname != "") 00440 H5Gclose(group_id); 00441 H5Fclose(file_id); 00442 00443 return 1; 00444 } 00445 00446 00447 template<class T> 00448 bool rf_import_HDF5(RandomForest<T> &rf, 00449 std::string filename, 00450 std::string pathname = "") 00451 { 00452 using detail::find_groups_hdf5; 00453 using detail::options_import_HDF5; 00454 using detail::problemspec_import_HDF5; 00455 using detail::dt_export_HDF5; 00456 // check if file exists 00457 FILE* pFile = std::fopen ( filename.c_str(), "r" ); 00458 if ( pFile == NULL) 00459 return 0; 00460 std::fclose(pFile); 00461 //open file 00462 hid_t file_id = H5Fopen (filename.c_str(), 00463 H5F_ACC_RDONLY, 00464 H5P_DEFAULT); 00465 00466 vigra_postcondition(file_id >= 0, 00467 "rf_import_HDF5(): Unable to open file."); 00468 hid_t group_id = pathname== "" ? 00469 file_id 00470 : H5Gopen (file_id, 00471 pathname.c_str(), 00472 H5P_DEFAULT); 00473 00474 vigra_postcondition(group_id >= 0, 00475 "rf_export_HDF5(): Unable to create group"); 00476 00477 //get serialized options 00478 options_import_HDF5(group_id, rf.options_, "_options"); 00479 //save external parameters 00480 problemspec_import_HDF5(group_id, rf.ext_param_, "_ext_param"); 00481 // TREE SAVING TIME 00482 // get all groups in base path 00483 00484 std::set<std::string> tree_set; 00485 std::set<std::string>::iterator iter; 00486 find_groups_hdf5(filename, pathname, tree_set); 00487 00488 for(iter = tree_set.begin(); iter != tree_set.end(); ++iter) 00489 { 00490 if((*iter)[0] != '_') 00491 { 00492 rf.trees_.push_back(detail::DecisionTree(rf.ext_param_)); 00493 dt_import_HDF5(group_id, rf.trees_.back(), *iter); 00494 } 00495 } 00496 00497 //clean up the mess 00498 if(pathname != "") 00499 H5Gclose(group_id); 00500 H5Fclose(file_id); 00501 /*rf.tree_indices_.resize(rf.tree_count()); 00502 for(int ii = 0; ii < rf.tree_count(); ++ii) 00503 rf.tree_indices_[ii] = ii; */ 00504 return 1; 00505 } 00506 } // namespace vigra 00507 00508 #endif // HasHDF5 00509 00510 #endif // VIGRA_RANDOM_FOREST_HDF5_IMPEX_HXX 00511
© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de) |
html generated using doxygen and Python
|