Point Cloud Library (PCL) 1.12.0
Loading...
Searching...
No Matches
fern_trainer.h
1/*
2 * Software License Agreement (BSD License)
3 *
4 * Point Cloud Library (PCL) - www.pointclouds.org
5 * Copyright (c) 2010-2011, Willow Garage, Inc.
6 *
7 * All rights reserved.
8 *
9 * Redistribution and use in source and binary forms, with or without
10 * modification, are permitted provided that the following conditions
11 * are met:
12 *
13 * * Redistributions of source code must retain the above copyright
14 * notice, this list of conditions and the following disclaimer.
15 * * Redistributions in binary form must reproduce the above
16 * copyright notice, this list of conditions and the following
17 * disclaimer in the documentation and/or other materials provided
18 * with the distribution.
19 * * Neither the name of Willow Garage, Inc. nor the names of its
20 * contributors may be used to endorse or promote products derived
21 * from this software without specific prior written permission.
22 *
23 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
24 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
25 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
26 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
27 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
28 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
29 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
30 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
31 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
32 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
33 * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
34 * POSSIBILITY OF SUCH DAMAGE.
35 *
36 */
37
38#pragma once
39
40#include <pcl/common/common.h>
41#include <pcl/ml/feature_handler.h>
42#include <pcl/ml/ferns/fern.h>
43#include <pcl/ml/stats_estimator.h>
44
45#include <vector>
46
47namespace pcl {
48
49/** Trainer for a Fern. */
50template <class FeatureType,
51 class DataSet,
52 class LabelType,
53 class ExampleIndex,
54 class NodeType>
55class PCL_EXPORTS FernTrainer {
56
57public:
58 /** Constructor. */
60
61 /** Destructor. */
62 virtual ~FernTrainer();
63
64 /** Sets the feature handler used to create and evaluate features.
65 *
66 * \param[in] feature_handler the feature handler
67 */
68 inline void
71 {
72 feature_handler_ = &feature_handler;
73 }
74
75 /** Sets the object for estimating the statistics for tree nodes.
76 *
77 * \param[in] stats_estimator the statistics estimator
78 */
79 inline void
82 {
83 stats_estimator_ = &stats_estimator;
84 }
85
86 /** Sets the maximum depth of the learned tree.
87 *
88 * \param[in] fern_depth maximum depth of the learned tree
89 */
90 inline void
91 setFernDepth(const std::size_t fern_depth)
92 {
93 fern_depth_ = fern_depth;
94 }
95
96 /** Sets the number of features used to find optimal decision features.
97 *
98 * \param[in] num_of_features the number of features
99 */
100 inline void
101 setNumOfFeatures(const std::size_t num_of_features)
102 {
103 num_of_features_ = num_of_features;
104 }
105
106 /** Sets the number of thresholds tested for finding the optimal decision
107 * threshold on the feature responses.
108 *
109 * \param[in] num_of_threshold the number of thresholds
110 */
111 inline void
112 setNumOfThresholds(const std::size_t num_of_threshold)
113 {
114 num_of_thresholds_ = num_of_threshold;
115 }
116
117 /** Sets the input data set used for training.
118 *
119 * \param[in] data_set the data set used for training
120 */
121 inline void
122 setTrainingDataSet(DataSet& data_set)
123 {
124 data_set_ = data_set;
125 }
126
127 /** Example indices that specify the data used for training.
128 *
129 * \param[in] examples the examples
130 */
131 inline void
132 setExamples(std::vector<ExampleIndex>& examples)
133 {
134 examples_ = examples;
135 }
136
137 /** Sets the label data corresponding to the example data.
138 *
139 * \param[in] label_data the label data
140 */
141 inline void
142 setLabelData(std::vector<LabelType>& label_data)
143 {
144 label_data_ = label_data;
145 }
146
147 /** Trains a decision tree using the set training data and settings.
148 *
149 * \param[out] fern destination for the trained tree
150 */
151 void
152 train(Fern<FeatureType, NodeType>& fern);
153
154protected:
155 /** Creates uniformely distrebuted thresholds over the range of the supplied
156 * values.
157 *
158 * \param[in] num_of_thresholds the number of thresholds to create
159 * \param[in] values the values for estimating the expected value range
160 * \param[out] thresholds the resulting thresholds
161 */
162 static void
163 createThresholdsUniform(const std::size_t num_of_thresholds,
164 std::vector<float>& values,
165 std::vector<float>& thresholds);
166
167private:
168 /** Desired depth of the learned fern. */
169 std::size_t fern_depth_;
170 /** Number of features used to find optimal decision features. */
171 std::size_t num_of_features_;
172 /** Number of thresholds. */
173 std::size_t num_of_thresholds_;
174
175 /** FeatureHandler instance, responsible for creating and evaluating features. */
177 /** StatsEstimator instance, responsible for gathering stats about a node. */
179
180 /** The training data set. */
181 DataSet data_set_;
182 /** The label data. */
183 std::vector<LabelType> label_data_;
184 /** The example data. */
185 std::vector<ExampleIndex> examples_;
186};
187
188} // namespace pcl
189
190#include <pcl/ml/impl/ferns/fern_trainer.hpp>
Utility class interface which is used for creating and evaluating features.
Class representing a Fern.
Definition fern.h:49
Trainer for a Fern.
void setTrainingDataSet(DataSet &data_set)
Sets the input data set used for training.
void setFernDepth(const std::size_t fern_depth)
Sets the maximum depth of the learned tree.
void setNumOfFeatures(const std::size_t num_of_features)
Sets the number of features used to find optimal decision features.
void setStatsEstimator(pcl::StatsEstimator< LabelType, NodeType, DataSet, ExampleIndex > &stats_estimator)
Sets the object for estimating the statistics for tree nodes.
void setNumOfThresholds(const std::size_t num_of_threshold)
Sets the number of thresholds tested for finding the optimal decision threshold on the feature responses.
void setFeatureHandler(pcl::FeatureHandler< FeatureType, DataSet, ExampleIndex > &feature_handler)
Sets the feature handler used to create and evaluate features.
void setLabelData(std::vector< LabelType > &label_data)
Sets the label data corresponding to the example data.
void setExamples(std::vector< ExampleIndex > &examples)
Example indices that specify the data used for training.
Class interface for gathering statistics for decision tree learning.
Define standard C methods and C++ classes that are common to all methods.