38 #ifndef PCL_ML_REGRESSION_VARIANCE_STATS_ESTIMATOR_H_
39 #define PCL_ML_REGRESSION_VARIANCE_STATS_ESTIMATOR_H_
42 #include <pcl/ml/stats_estimator.h>
43 #include <pcl/ml/branch_estimator.h>
// Template parameter list with two type parameters. The declaration it
// introduces is not visible in this excerpt.
// NOTE(review): presumably this heads the node type whose serialize /
// deserialize bodies follow — confirm against the full header.
52 template <
class FeatureType,
class LabelType>
// Body fragment that writes this object's state to `stream`. The enclosing
// signature is not visible in this excerpt (presumably a serialize method —
// confirm).
//
// Write order: feature, threshold, value, variance, sub-node count, then
// every sub-node. A reader has to consume the fields in the same order.

// The feature writes itself through its own serialize().
67 feature.serialize (stream);
// The scalar members are emitted as their raw object bytes (address of the
// member reinterpreted as char*, sizeof bytes long).
69 stream.write (
reinterpret_cast<const char*
> (&threshold),
sizeof (threshold));
71 stream.write (
reinterpret_cast<const char*
> (&value),
sizeof (value));
72 stream.write (
reinterpret_cast<const char*
> (&variance),
sizeof (variance));
// The child count is narrowed to int before being written, so the on-stream
// count field is sizeof (int) bytes.
74 const int num_of_sub_nodes =
static_cast<int> (sub_nodes.size ());
75 stream.write (
reinterpret_cast<const char*
> (&num_of_sub_nodes),
sizeof (num_of_sub_nodes));
// Each sub-node is appended in index order through its own serialize().
76 for (
int sub_node_index = 0; sub_node_index < num_of_sub_nodes; ++sub_node_index)
78 sub_nodes[sub_node_index].serialize (stream);
// Body fragment that restores this object's state from `stream`. The
// enclosing signature is not visible in this excerpt (presumably a
// deserialize method — confirm).
//
// Read order: feature, threshold, value, variance, sub-node count, then
// every sub-node.

// The feature reads itself through its own deserialize().
88 feature.deserialize (stream);
// The scalar members are filled by reading sizeof bytes directly into the
// member's storage.
90 stream.read (
reinterpret_cast<char*
> (&threshold),
sizeof (threshold));
92 stream.read (
reinterpret_cast<char*
> (&value),
sizeof (value));
93 stream.read (
reinterpret_cast<char*
> (&variance),
sizeof (variance));
// NOTE(review): the declaration of `num_of_sub_nodes` is not visible in this
// excerpt; presumably a local int matching the write path — confirm.
96 stream.read (
reinterpret_cast<char*
> (&num_of_sub_nodes),
sizeof (num_of_sub_nodes));
// NOTE(review): the count read from the stream is passed to resize() with no
// range check visible here. If it is a signed int, a corrupt or negative
// value would turn into a very large size request — confirm whether the
// stream is always trusted.
97 sub_nodes.resize (num_of_sub_nodes);
// The `> 0` test precedes the loop. Assuming a signed count, the loop
// condition alone would already skip zero or negative values.
99 if (num_of_sub_nodes > 0)
// Each sub-node is restored in index order through its own deserialize().
101 for (
int sub_node_index = 0; sub_node_index < num_of_sub_nodes; ++sub_node_index)
103 sub_nodes[sub_node_index].deserialize (stream);
// Template parameter list with four type parameters. The declaration it
// introduces is not visible in this excerpt.
// LabelDataType, NodeType and ExampleIndex appear in the function fragments
// that follow; DataSet is not used in any visible line.
124 template <
class LabelDataType,
class NodeType,
class DataSet,
class ExampleIndex>
// Member-initializer fragment: stores the `branch_estimator` argument in
// `branch_estimator_`, which later fragments dereference with `->`.
// The constructor signature and body are not visible in this excerpt.
132 : branch_estimator_ (branch_estimator)
// Returns the branch count reported by the stored branch estimator.
// NOTE(review): the enclosing signature is not visible; presumably this is
// the getNumOfBranches() that later fragments call — confirm.
142 return branch_estimator_->getNumOfBranches ();
// Tail of a const member-function signature whose last parameter is a
// non-const NodeType reference. The function name, return type and body are
// not visible in this excerpt.
150 NodeType & node)
const
// Signature tail and body of a const member function that returns a
// variance-based information gain for one candidate split. The function name
// is not visible in this excerpt (presumably computeInformationGain —
// confirm).
//
// Visible parameters:
//   examples   - only its size is used, as the number of examples.
//   label_data - indexed by example position; supplies each example's label.
//   results    - indexed by example position; passed to computeBranchIndex.
//   flags      - indexed by example position; passed to computeBranchIndex.
//   threshold  - passed unchanged to computeBranchIndex.
//
// Returns the variance over all examples minus the count-weighted sum of the
// per-branch variances.
166 std::vector<ExampleIndex> & examples,
167 std::vector<LabelDataType> & label_data,
168 std::vector<float> & results,
169 std::vector<unsigned char> & flags,
170 const float threshold)
const
172 const size_t num_of_examples = examples.size ();
173 const size_t num_of_branches = getNumOfBranches();
// Accumulators have one slot per branch plus a final slot, at index
// num_of_branches, that collects the totals over all examples.
176 std::vector<LabelDataType> sums (num_of_branches+1, 0);
177 std::vector<LabelDataType> sqr_sums (num_of_branches+1, 0);
178 std::vector<size_t> branch_element_count (num_of_branches+1, 0);
// Counts are seeded rather than left at zero, so the divisions further down
// never divide by an empty branch. Each branch starts at 1 and the total
// slot is bumped alongside it.
// NOTE(review): braces are not visible; presumably both statements form the
// loop body — confirm.
180 for (
size_t branch_index = 0; branch_index < num_of_branches; ++branch_index)
182 branch_element_count[branch_index] = 1;
183 ++branch_element_count[num_of_branches];
// Assign every example to a branch and add its label, squared label and a
// count of one to both that branch's slot and the total slot.
186 for (
size_t example_index = 0; example_index < num_of_examples; ++example_index)
188 unsigned char branch_index;
189 computeBranchIndex (results[example_index], flags[example_index], threshold, branch_index);
191 LabelDataType label = label_data[example_index];
193 sums[branch_index] += label;
194 sums[num_of_branches] += label;
196 sqr_sums[branch_index] += label*label;
197 sqr_sums[num_of_branches] += label*label;
199 ++branch_element_count[branch_index];
200 ++branch_element_count[num_of_branches];
// Per-slot variance as mean of squares minus squared mean. The divisor is the
// seeded count, i.e. the number of examples in the slot plus the seed.
203 std::vector<float> variances (num_of_branches+1, 0);
204 for (
size_t branch_index = 0; branch_index < num_of_branches+1; ++branch_index)
206 const float mean_sum =
static_cast<float>(sums[branch_index]) / branch_element_count[branch_index];
207 const float mean_sqr_sum =
static_cast<float>(sqr_sums[branch_index]) / branch_element_count[branch_index];
208 variances[branch_index] = mean_sqr_sum - mean_sum*mean_sum;
// Start from the overall variance and subtract each branch's variance,
// weighted by that branch's share of the total (seeded) count.
211 float information_gain = variances[num_of_branches];
212 for (
size_t branch_index = 0; branch_index < num_of_branches; ++branch_index)
215 const float weight =
static_cast<float>(branch_element_count[branch_index]) /
static_cast<float>(branch_element_count[num_of_branches]);
216 information_gain -= weight*variances[branch_index];
219 return information_gain;
// Signature tail and body of a const member function that computes a branch
// index for every entry of `results`. The function name is not visible in
// this excerpt (presumably computeBranchIndices — confirm).
//
// Visible parameters:
//   results        - one value per entry; its size determines the output size.
//   flags          - indexed in parallel with `results`.
//   threshold      - passed unchanged to computeBranchIndex.
//   branch_indices - output; resized to results.size () and overwritten.
230 std::vector<float> & results,
231 std::vector<unsigned char> & flags,
232 const float threshold,
233 std::vector<unsigned char> & branch_indices)
const
235 const size_t num_of_results = results.size ();
// NOTE(review): `num_of_branches` is not read anywhere in the visible part of
// this function — possibly an unused local; confirm against the full source.
236 const size_t num_of_branches = getNumOfBranches();
238 branch_indices.resize (num_of_results);
// One computeBranchIndex call per entry; its output is stored at the same
// position in `branch_indices`.
239 for (
size_t result_index = 0; result_index < num_of_results; ++result_index)
241 unsigned char branch_index;
242 computeBranchIndex (results[result_index], flags[result_index], threshold, branch_index);
243 branch_indices[result_index] = branch_index;
// Signature tail and body of a const member function that forwards a single
// branch-index computation to the stored branch estimator. The leading part
// of the signature, including the declaration of `result`, is not visible in
// this excerpt.
//
// `branch_index` is an output reference, passed through to the estimator's
// own computeBranchIndex.
256 const unsigned char flag,
257 const float threshold,
258 unsigned char & branch_index)
const
260 branch_estimator_->computeBranchIndex (result, flag, threshold, branch_index);
// Signature tail and body of a const member function that derives label
// statistics for a set of examples and stores the variance on `node`. The
// function name is not visible in this excerpt (presumably
// computeAndSetNodeStats — confirm).
//
// Visible parameters:
//   examples   - only its size is used, as the number of examples.
//   label_data - indexed by example position; supplies each example's label.
//   node       - output; its `variance` member is assigned.
273 std::vector<ExampleIndex> & examples,
274 std::vector<LabelDataType> & label_data,
275 NodeType & node)
const
277 const size_t num_of_examples = examples.size ();
279 LabelDataType sum = 0.0f;
280 LabelDataType sqr_sum = 0.0f;
// NOTE(review): only the accumulation into `sqr_sum` is visible in this loop.
// Several original lines are absent from this excerpt, so the accumulation
// into `sum` is presumably on one of them — confirm, since otherwise `sum`
// stays at zero.
281 for (
size_t example_index = 0; example_index < num_of_examples; ++example_index)
283 const LabelDataType label = label_data[example_index];
286 sqr_sum += label*label;
// Turn the running totals into means.
// NOTE(review): no empty-input guard is visible here; with zero examples
// these divide by zero — confirm whether callers guarantee non-empty input.
289 sum /= num_of_examples;
290 sqr_sum /= num_of_examples;
// Variance as mean of squares minus squared mean.
292 const float variance = sqr_sum - sum*sum;
295 node.variance = variance;
// Signature tail and body of a const member function that writes a fixed
// error message to `stream` and produces no generated code. The function
// name is not visible in this excerpt; the message names
// generateCodeForBranchIndex.
305 std::ostream & stream)
const
307 stream <<
"ERROR: RegressionVarianceStatsEstimator does not implement generateCodeForBranchIndex(...)";
// Signature tail and body of a second const member function that writes a
// fixed error message to `stream` and produces no generated code. The
// function name is not visible in this excerpt.
// NOTE(review): the message is character-for-character the same as the one
// emitted by the preceding stub and names generateCodeForBranchIndex. If this
// is a different function (for example an output-code generator), the name in
// the message is likely a copy-paste slip — confirm.
317 std::ostream & stream)
const
319 stream <<
"ERROR: RegressionVarianceStatsEstimator does not implement generateCodeForBranchIndex(...)";