random_forest_hdf5_impex.hxx
|
 |
36 #ifndef VIGRA_RANDOM_FOREST_IMPEX_HDF5_HXX
37 #define VIGRA_RANDOM_FOREST_IMPEX_HDF5_HXX
40 #include "random_forest.hxx"
41 #include "hdf5impex.hxx"
// HDF5 location names and format-version metadata shared by the random-forest
// HDF5 import/export functions in this header.
// Why some names start with '_': rf_import_HDF5 treats only those ls() entries
// that end in '/' and do NOT start with '_' as decision trees, so these
// entries are never mistaken for trees.
// Name passed to options_export_HDF5 / options_import_HDF5 for the
// RandomForestOptions.
47 static const char *
const rf_hdf5_options =
"_options";
// Presumably the location of the serialized ProblemSpec (rf.ext_param());
// the arguments that would pass it are on lines not visible here -- TODO confirm.
48 static const char *
const rf_hdf5_ext_param =
"_ext_param";
// Dataset holding ProblemSpec::classes (the class labels). The generic map
// import is given it as its ignored label, and it is read separately.
49 static const char *
const rf_hdf5_labels =
"labels";
// NOTE(review): rf_hdf5_topology and rf_hdf5_parameters are not referenced in
// the visible part of this header; presumably they are used by the
// dt_import_HDF5 / dt_export_HDF5 definitions -- verify.
50 static const char *
const rf_hdf5_topology =
"topology";
51 static const char *
const rf_hdf5_parameters =
"parameters";
// Prefix of each tree's name; the exporter appends a padded tree index.
52 static const char *
const rf_hdf5_tree =
"Tree_";
// Object the format-version attribute is attached to ("." = the current group
// in HDF5 path notation), and the attribute's name.
53 static const char *
const rf_hdf5_version_group =
".";
54 static const char *
const rf_hdf5_version_tag =
"vigra_random_forest_version";
// Format version written on export. On import, a stored version greater than
// this value fails a precondition.
55 static const double rf_hdf5_version = 0.1;
// Exported (VIGRA_EXPORT) helpers for the RandomForestOptions and single
// decision trees; only their declarations appear in this header.
// NOTE(review): the trailing parameter of each declaration is on a line not
// shown in this listing. Call sites in this header pass a location name
// (a C string or std::string) as the third argument.
// Import/export the options at the named location of the HDF5File.
60 VIGRA_EXPORT
void options_import_HDF5(HDF5File &, RandomForestOptions &,
63 VIGRA_EXPORT
void options_export_HDF5(HDF5File &,
const RandomForestOptions &,
// Import/export one decision tree at the named location of the HDF5File.
66 VIGRA_EXPORT
void dt_import_HDF5(HDF5File &, detail::DecisionTree &,
69 VIGRA_EXPORT
void dt_export_HDF5(HDF5File &,
const detail::DecisionTree &,
// Generic import of a parameter object of template type X from the current
// group of h5context (the template header is on a line not shown here).
// Each entry returned by h5context.ls() is read with readAndResize into a
// value of X::map_type, keyed by the entry name. The finished map is then
// handed to param.make_from_map().
// Requirements on X (from the uses below): a nested typedef map_type and a
// member make_from_map() callable with a map_type.
// ignored_label: optional entry name handled by the branch at the top of the
//   loop (its body is not visible here; presumably the entry is skipped).
//   If non-null, the vigra_precondition at the end requires that the entry
//   was seen.
73 void rf_import_HDF5_to_map(HDF5File & h5context, X & param,
74 const char *
const ignored_label = 0)
77 typedef typename X::map_type map_type;
78 typedef std::pair<typename map_type::iterator, bool> inserter_type;
79 typedef typename map_type::value_type value_type;
80 typedef typename map_type::mapped_type mapped_type;
82 map_type serialized_param;
// With no ignored_label there is nothing to look for, so the final check
// passes trivially.
83 bool ignored_seen = ignored_label == 0;
85 std::vector<std::string> names = h5context.ls();
86 std::vector<std::string>::const_iterator j;
87 for (j = names.begin(); j != names.end(); ++j)
// The body of this branch is on lines not shown in this listing. Presumably
// it sets ignored_seen and skips the entry -- TODO confirm.
89 if (ignored_label && *j == ignored_label)
// Insert a default-constructed value under the entry's name, then read the
// data into that element in place (readAndResize presumably resizes the
// target to fit the stored data).
95 inserter_type new_array
96 = serialized_param.insert(value_type(*j, mapped_type()));
98 h5context.readAndResize(*j, (*(new_array.first)).second);
// NOTE(review): the message always says "labels", whatever ignored_label is.
100 vigra_precondition(ignored_seen,
"rf_import_HDF5_to_map(): "
101 "labels are missing.");
// X rebuilds its state from the collected name -> value map.
102 param.make_from_map(serialized_param);
// Reads a ProblemSpec<T> from h5context. First the parameters come through
// the generic map import, with the labels entry (rf_hdf5_labels) passed as
// the ignored label. Then the class labels are read as an ArrayVector<T> and
// installed with param.classes_().
// NOTE(review): `name` is not used on any visible line; presumably an unshown
// line changes into that group first -- confirm.
106 void problemspec_import_HDF5(HDF5File & h5context, ProblemSpec<T> & param,
107 const std::string & name)
110 rf_import_HDF5_to_map(h5context, param, rf_hdf5_labels);
// The labels have element type T, so they are read here with their own type
// rather than through X::map_type.
112 ArrayVector<T> labels;
113 h5context.readAndResize(rf_hdf5_labels, labels);
114 param.classes_(labels.begin(), labels.end());
// Generic export of a parameter object of template type X. param.make_map()
// fills an X::map_type, and each (name, value) entry is written with
// h5context.write(), relative to h5context's current group.
// Requirements on X (from the uses below): a nested typedef map_type and a
// const member make_map() accepting a map_type lvalue.
119 void rf_export_map_to_HDF5(HDF5File & h5context,
const X & param)
121 typedef typename X::map_type map_type;
122 map_type serialized_param;
124 param.make_map(serialized_param);
// One HDF5 write per map entry; the map key is used as the entry name.
125 typename map_type::const_iterator j;
126 for (j = serialized_param.begin(); j != serialized_param.end(); ++j)
127 h5context.write(j->first, j->second);
// Writes a ProblemSpec<T> into group `name` of h5context. cd_mk(name) makes
// that group the working group (creating it if necessary). Then the
// map-serialized parameters are written, and finally the class labels
// (param.classes) under rf_hdf5_labels.
// NOTE(review): returning to the parent group is not visible here; confirm an
// unshown following line restores the working group.
131 void problemspec_export_HDF5(HDF5File & h5context, ProblemSpec<T>
const & param,
132 const std::string & name)
134 h5context.cd_mk(name);
135 rf_export_map_to_HDF5(h5context, param);
// Labels are written explicitly; problemspec_import_HDF5 likewise reads them
// separately from the generic map.
136 h5context.write(rf_hdf5_labels, param.classes);
// Formats integer indices as strings. The exporter uses it to build the
// per-tree names (rf_hdf5_tree + index). The state lives behind a pointer to
// padded_number_string_data, which is only forward-declared here.
140 struct padded_number_string_data;
141 class VIGRA_EXPORT padded_number_string
// Implementation state, presumably owned by this object and released in the
// destructor (defined elsewhere).
144 padded_number_string_data* padded_number;
// Copy construction and assignment are declared without visible definitions.
// Presumably they are private to make the class non-copyable (the access
// specifiers are on lines not shown).
146 padded_number_string(
const padded_number_string &);
147 void operator=(
const padded_number_string &);
// n: how many numbers will be formatted (the exporter passes the tree count);
// presumably used to choose the padding width.
149 padded_number_string(
int n);
// Returns the string for index k (the exporter calls it with 0 <= k < n).
150 std::string operator()(
int k)
const;
151 ~padded_number_string();
// Returns the absolute path of h5context's current working group
// (get_absolute_path applied to pwd()). The import/export functions save it
// before changing into a sub-group, presumably to restore it afterwards.
154 inline std::string get_cwd(HDF5File & h5context)
156 return h5context.get_absolute_path(h5context.pwd());
// Writes a RandomForest<T, Tag> `rf` into an open HDF5File `h5context`.
// The function name and leading parameters are on lines not shown;
// presumably this is rf_export_HDF5.
// What is written, in order:
//   - the format version as attribute rf_hdf5_version_tag on ".";
//   - the options under rf_hdf5_options;
//   - the ProblemSpec via problemspec_export_HDF5;
//   - one entry per tree, named rf_hdf5_tree + padded tree index.
// pathname: if non-empty, the forest is written into that group (entered with
//   cd_mk); the previous working group is first saved in `cwd`.
175 template<
class T,
class Tag>
178 const std::string & pathname =
"")
181 if (pathname.size()) {
182 cwd = detail::get_cwd(h5context);
183 h5context.
cd_mk(pathname);
// Tag the output with the format version so the importer can reject files
// written by a newer format.
186 h5context.
writeAttribute(rf_hdf5_version_group, rf_hdf5_version_tag,
189 detail::options_export_HDF5(h5context, rf.
options(), rf_hdf5_options);
191 detail::problemspec_export_HDF5(h5context, rf.
ext_param(),
// NOTE(review): the tree count comes from the options (options_.tree_count_),
// not from the stored trees; rf.tree(i) assumes the two agree.
194 int tree_count = rf.options_.tree_count_;
195 detail::padded_number_string tree_number(tree_count);
196 for (
int i = 0; i < tree_count; ++i)
197 detail::dt_export_HDF5(h5context, rf.
tree(i),
198 rf_hdf5_tree + tree_number(i));
// Convenience overload: opens file `filename` with HDF5File::Open, presumably
// to delegate to the HDF5File overload with `pathname`. The delegating call
// is on a line not shown here.
218 template<
class T,
class Tag>
220 const std::string & filename,
221 const std::string & pathname =
"")
223 HDF5File h5context(filename , HDF5File::Open);
// Overload taking an already-open file handle `fileHandle` (its parameter
// declaration is on a line not shown). It wraps the handle and `pathname`
// in an HDF5File, presumably before delegating to the HDF5File overload.
244 template<
class T,
class Tag>
247 const std::string & pathname =
"")
250 HDF5File h5context(fileHandle, pathname);
// Reads a RandomForest<T, Tag> `rf` from an open HDF5File `h5context`.
// The function name and leading parameters are on lines not shown; the error
// message identifies it as rf_import_HDF5.
// pathname: if non-empty, the group to read from. It is entered with cd (not
//   cd_mk), so it must already exist. The previous working group is first
//   saved in `cwd`.
267 template<
class T,
class Tag>
270 const std::string & pathname =
"")
273 if (pathname.size()) {
274 cwd = detail::get_cwd(h5context);
275 h5context.
cd(pathname);
// The version attribute is optional. When present, it must not exceed the
// version this code writes (rf_hdf5_version).
278 if (h5context.
existsAttribute(rf_hdf5_version_group, rf_hdf5_version_tag))
281 h5context.
readAttribute(rf_hdf5_version_group, rf_hdf5_version_tag,
283 vigra_precondition(read_version <= rf_hdf5_version,
284 "rf_import_HDF5(): unexpected file format version.");
// Options and ProblemSpec are read before the trees, because each tree is
// constructed from rf.ext_param_.
287 detail::options_import_HDF5(h5context, rf.options_, rf_hdf5_options);
289 detail::problemspec_import_HDF5(h5context, rf.ext_param_,
// Every ls() entry that ends in '/' (presumably a sub-group) and does not
// start with '_' is imported as one decision tree, in the order ls() returns.
// NOTE(review): clearing rf.trees_ is not visible in this listing; confirm it
// happens on an unshown line, otherwise trees are appended to existing ones.
// NOTE(review): *j->rbegin() / *j->begin() assume ls() never yields an empty name.
294 std::vector<std::string> names = h5context.
ls();
295 std::vector<std::string>::const_iterator j;
296 for (j = names.begin(); j != names.end(); ++j)
298 if ((*j->rbegin() ==
'/') && (*j->begin() !=
'_'))
300 rf.trees_.push_back(detail::DecisionTree(rf.ext_param_));
301 detail::dt_import_HDF5(h5context, rf.trees_.
back(), *j);
// Convenience overload: opens file `filename` read-only
// (HDF5File::OpenReadOnly), presumably to delegate to the HDF5File overload
// with `pathname`. The delegating call is on a line not shown here.
322 template<
class T,
class Tag>
324 const std::string & filename,
325 const std::string & pathname =
"")
327 HDF5File h5context(filename, HDF5File::OpenReadOnly);
// Overload taking an already-open file handle `fileHandle` (its parameter
// declaration is on a line not shown). It wraps the handle in an HDF5File
// with `pathname` and a third argument `true`, presumably requesting
// read-only access -- confirm against the HDF5File constructor.
348 template<
class T,
class Tag>
351 const std::string & pathname =
"")
354 HDF5File h5context(fileHandle, pathname,
true);
360 #endif // VIGRA_RANDOM_FOREST_IMPEX_HDF5_HXX