dtree.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_DET_DTREE_HPP
14 #define MLPACK_METHODS_DET_DTREE_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 namespace mlpack {
19 namespace det {
20 
44 template<typename MatType = arma::mat,
45  typename TagType = int>
46 class DTree
47 {
48  public:
50  typedef typename MatType::elem_type ElemType;
52  typedef typename MatType::vec_type VecType;
54  typedef typename arma::Col<ElemType> StatType;
55 
59  DTree();
60 
66  DTree(const DTree& obj);
67 
73  DTree& operator=(const DTree& obj);
74 
80  DTree(DTree&& obj);
81 
87  DTree& operator=(DTree&& obj);
88 
97  DTree(const StatType& maxVals,
98  const StatType& minVals,
99  const size_t totalPoints);
100 
109  DTree(MatType& data);
110 
123  DTree(const StatType& maxVals,
124  const StatType& minVals,
125  const size_t start,
126  const size_t end,
127  const double logNegError);
128 
141  DTree(const StatType& maxVals,
142  const StatType& minVals,
143  const size_t totalPoints,
144  const size_t start,
145  const size_t end);
146 
148  ~DTree();
149 
160  double Grow(MatType& data,
161  arma::Col<size_t>& oldFromNew,
162  const bool useVolReg = false,
163  const size_t maxLeafSize = 10,
164  const size_t minLeafSize = 5);
165 
174  double PruneAndUpdate(const double oldAlpha,
175  const size_t points,
176  const bool useVolReg = false);
177 
183  double ComputeValue(const VecType& query) const;
184 
194  TagType TagTree(const TagType& tag = 0, bool everyNode = false);
195 
196 
203  TagType FindBucket(const VecType& query) const;
204 
205 
211  void ComputeVariableImportance(arma::vec& importances) const;
212 
219  double LogNegativeError(const size_t totalPoints) const;
220 
224  bool WithinRange(const VecType& query) const;
225 
226  private:
227  // The indices in the complete set of points
228  // (after all forms of swapping in the original data
229  // matrix to align all the points in a node
230  // consecutively in the matrix. The 'old_from_new' array
231  // maps the points back to their original indices.
232 
235  size_t start;
238  size_t end;
239 
241  StatType maxVals;
243  StatType minVals;
244 
246  size_t splitDim;
247 
249  ElemType splitValue;
250 
252  double logNegError;
253 
255  double subtreeLeavesLogNegError;
256 
258  size_t subtreeLeaves;
259 
261  bool root;
262 
264  double ratio;
265 
267  double logVolume;
268 
270  TagType bucketTag;
271 
273  double alphaUpper;
274 
276  DTree* left;
278  DTree* right;
279 
280  public:
282  size_t Start() const { return start; }
284  size_t End() const { return end; }
286  size_t SplitDim() const { return splitDim; }
288  ElemType SplitValue() const { return splitValue; }
290  double LogNegError() const { return logNegError; }
292  double SubtreeLeavesLogNegError() const { return subtreeLeavesLogNegError; }
294  size_t SubtreeLeaves() const { return subtreeLeaves; }
297  double Ratio() const { return ratio; }
299  double LogVolume() const { return logVolume; }
301  DTree* Left() const { return left; }
303  DTree* Right() const { return right; }
305  bool Root() const { return root; }
307  double AlphaUpper() const { return alphaUpper; }
309  TagType BucketTag() const { return bucketTag; }
311  size_t NumChildren() const { return !left ? 0 : 2; }
312 
319  DTree& Child(const size_t child) const { return !child ? *left : *right; }
320 
321  DTree*& ChildPtr(const size_t child) { return (!child) ? left : right; }
322 
324  const StatType& MaxVals() const { return maxVals; }
325 
327  const StatType& MinVals() const { return minVals; }
328 
332  template<typename Archive>
333  void serialize(Archive& ar, const uint32_t /* version */);
334 
335  private:
336  // Utility methods.
337 
341  bool FindSplit(const MatType& data,
342  size_t& splitDim,
343  ElemType& splitValue,
344  double& leftError,
345  double& rightError,
346  const size_t minLeafSize = 5) const;
347 
351  size_t SplitData(MatType& data,
352  const size_t splitDim,
353  const ElemType splitValue,
354  arma::Col<size_t>& oldFromNew) const;
355 
356  void FillMinMax(const StatType& mins,
357  const StatType& maxs);
358 };
359 
360 } // namespace det
361 } // namespace mlpack
362 
363 #include "dtree_impl.hpp"
364 
365 #endif // MLPACK_METHODS_DET_DTREE_HPP
size_t SubtreeLeaves() const
Return the number of leaves which are descendants of this node.
Definition: dtree.hpp:294
double ComputeValue(const VecType &query) const
Compute the logarithm of the density estimate of a given query point.
DTree & operator=(const DTree &obj)
Copy the given tree.
size_t SplitDim() const
Return the split dimension of this node.
Definition: dtree.hpp:286
double Grow(MatType &data, arma::Col< size_t > &oldFromNew, const bool useVolReg=false, const size_t maxLeafSize=10, const size_t minLeafSize=5)
Greedily expand the tree.
~DTree()
Clean up memory allocated by the tree.
size_t Start() const
Return the starting index of points contained in this node.
Definition: dtree.hpp:282
Linear algebra utility functions, generally performed on matrices or vectors.
MatType::elem_type ElemType
The actual, underlying type we're working with.
Definition: dtree.hpp:50
void serialize(Archive &ar, const uint32_t)
Serialize the density estimation tree.
bool WithinRange(const VecType &query) const
Return whether a query point is within the range of this node.
arma::Col< ElemType > StatType
The statistic type we are holding.
Definition: dtree.hpp:54
double LogNegError() const
Return the log negative error of this node.
Definition: dtree.hpp:290
The core includes that mlpack expects; standard C++ includes and Armadillo.
double Ratio() const
Return the ratio of points in this node to the points in the whole dataset.
Definition: dtree.hpp:297
DTree * Left() const
Return the left child.
Definition: dtree.hpp:301
DTree & Child(const size_t child) const
Return the specified child (0 will be left, 1 will be right).
Definition: dtree.hpp:319
size_t NumChildren() const
Return the number of children in this node.
Definition: dtree.hpp:311
size_t End() const
Return the first index of a point not contained in this node.
Definition: dtree.hpp:284
TagType TagTree(const TagType &tag=0, bool everyNode=false)
Index the buckets for possible usage later; this results in every leaf in the tree having a specific ...
ElemType SplitValue() const
Return the split value of this node.
Definition: dtree.hpp:288
const StatType & MaxVals() const
Return the maximum values.
Definition: dtree.hpp:324
DTree * Right() const
Return the right child.
Definition: dtree.hpp:303
double LogVolume() const
Return the logarithm of the volume of this node.
Definition: dtree.hpp:299
TagType FindBucket(const VecType &query) const
Return the tag of the leaf containing the query.
double SubtreeLeavesLogNegError() const
Return the log negative error of all descendants of this node.
Definition: dtree.hpp:292
MatType::vec_type VecType
The type of vector we are using.
Definition: dtree.hpp:52
double PruneAndUpdate(const double oldAlpha, const size_t points, const bool useVolReg=false)
Perform alpha pruning on a tree.
void ComputeVariableImportance(arma::vec &importances) const
Compute the variable importance of each dimension in the learned tree.
A density estimation tree is similar to both a decision tree and a space partitioning tree (like a kd...
Definition: dtree.hpp:46
double LogNegativeError(const size_t totalPoints) const
Compute the log-negative-error for this node, given the total number of points in the dataset...
const StatType & MinVals() const
Return the minimum values.
Definition: dtree.hpp:327
double AlphaUpper() const
Return the upper part of the alpha sum.
Definition: dtree.hpp:307
TagType BucketTag() const
Return the current bucket's ID, if leaf, or -1 otherwise.
Definition: dtree.hpp:309
bool Root() const
Return whether or not this is the root of the tree.
Definition: dtree.hpp:305
DTree *& ChildPtr(const size_t child)
Definition: dtree.hpp:321
DTree()
Create an empty density estimation tree.