binary_space_tree.hpp
Go to the documentation of this file.
1 
11 #ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_BINARY_SPACE_TREE_HPP
12 #define MLPACK_CORE_TREE_BINARY_SPACE_TREE_BINARY_SPACE_TREE_HPP
13 
14 #include <mlpack/prereqs.hpp>
15 
16 #include "../statistic.hpp"
17 #include "midpoint_split.hpp"
18 
19 namespace mlpack {
20 namespace tree {
21 
47 template<typename MetricType,
48  typename StatisticType = EmptyStatistic,
49  typename MatType = arma::mat,
50  template<typename BoundMetricType, typename...> class BoundType =
52  template<typename SplitBoundType, typename SplitMatType>
53  class SplitType = MidpointSplit>
55 {
56  public:
58  typedef MatType Mat;
60  typedef typename MatType::elem_type ElemType;
61 
62  typedef SplitType<BoundType<MetricType>, MatType> Split;
63 
64  private:
66  BinarySpaceTree* left;
68  BinarySpaceTree* right;
70  BinarySpaceTree* parent;
73  size_t begin;
76  size_t count;
78  BoundType<MetricType> bound;
80  StatisticType stat;
82  ElemType parentDistance;
85  ElemType furthestDescendantDistance;
87  ElemType minimumBoundDistance;
90  MatType* dataset;
91 
92  public:
95  template<typename RuleType>
97 
99  template<typename RuleType>
101 
102  template<typename RuleType>
104 
113  BinarySpaceTree(const MatType& data, const size_t maxLeafSize = 20);
114 
127  BinarySpaceTree(const MatType& data,
128  std::vector<size_t>& oldFromNew,
129  const size_t maxLeafSize = 20);
130 
146  BinarySpaceTree(const MatType& data,
147  std::vector<size_t>& oldFromNew,
148  std::vector<size_t>& newFromOld,
149  const size_t maxLeafSize = 20);
150 
160  BinarySpaceTree(MatType&& data,
161  const size_t maxLeafSize = 20);
162 
175  BinarySpaceTree(MatType&& data,
176  std::vector<size_t>& oldFromNew,
177  const size_t maxLeafSize = 20);
178 
194  BinarySpaceTree(MatType&& data,
195  std::vector<size_t>& oldFromNew,
196  std::vector<size_t>& newFromOld,
197  const size_t maxLeafSize = 20);
198 
212  const size_t begin,
213  const size_t count,
214  SplitType<BoundType<MetricType>, MatType>& splitter,
215  const size_t maxLeafSize = 20);
216 
237  const size_t begin,
238  const size_t count,
239  std::vector<size_t>& oldFromNew,
240  SplitType<BoundType<MetricType>, MatType>& splitter,
241  const size_t maxLeafSize = 20);
242 
266  const size_t begin,
267  const size_t count,
268  std::vector<size_t>& oldFromNew,
269  std::vector<size_t>& newFromOld,
270  SplitType<BoundType<MetricType>, MatType>& splitter,
271  const size_t maxLeafSize = 20);
272 
279  BinarySpaceTree(const BinarySpaceTree& other);
280 
286 
293 
300 
306  template<typename Archive>
308  Archive& ar,
309  const typename std::enable_if_t<cereal::is_loading<Archive>()>* = 0);
310 
317 
319  const BoundType<MetricType>& Bound() const { return bound; }
321  BoundType<MetricType>& Bound() { return bound; }
322 
324  const StatisticType& Stat() const { return stat; }
326  StatisticType& Stat() { return stat; }
327 
329  bool IsLeaf() const;
330 
332  BinarySpaceTree* Left() const { return left; }
334  BinarySpaceTree*& Left() { return left; }
335 
337  BinarySpaceTree* Right() const { return right; }
339  BinarySpaceTree*& Right() { return right; }
340 
342  BinarySpaceTree* Parent() const { return parent; }
344  BinarySpaceTree*& Parent() { return parent; }
345 
347  const MatType& Dataset() const { return *dataset; }
349  MatType& Dataset() { return *dataset; }
350 
352  MetricType Metric() const { return MetricType(); }
353 
355  size_t NumChildren() const;
356 
361  template<typename VecType>
362  size_t GetNearestChild(
363  const VecType& point,
365 
370  template<typename VecType>
371  size_t GetFurthestChild(
372  const VecType& point,
374 
379  size_t GetNearestChild(const BinarySpaceTree& queryNode);
380 
385  size_t GetFurthestChild(const BinarySpaceTree& queryNode);
386 
391  ElemType FurthestPointDistance() const;
392 
400  ElemType FurthestDescendantDistance() const;
401 
403  ElemType MinimumBoundDistance() const;
404 
407  ElemType ParentDistance() const { return parentDistance; }
410  ElemType& ParentDistance() { return parentDistance; }
411 
418  BinarySpaceTree& Child(const size_t child) const;
419 
420  BinarySpaceTree*& ChildPtr(const size_t child)
421  { return (child == 0) ? left : right; }
422 
424  size_t NumPoints() const;
425 
431  size_t NumDescendants() const;
432 
440  size_t Descendant(const size_t index) const;
441 
450  size_t Point(const size_t index) const;
451 
453  ElemType MinDistance(const BinarySpaceTree& other) const
454  {
455  return bound.MinDistance(other.Bound());
456  }
457 
459  ElemType MaxDistance(const BinarySpaceTree& other) const
460  {
461  return bound.MaxDistance(other.Bound());
462  }
463 
466  {
467  return bound.RangeDistance(other.Bound());
468  }
469 
471  template<typename VecType>
472  ElemType MinDistance(const VecType& point,
474  const
475  {
476  return bound.MinDistance(point);
477  }
478 
480  template<typename VecType>
481  ElemType MaxDistance(const VecType& point,
483  const
484  {
485  return bound.MaxDistance(point);
486  }
487 
489  template<typename VecType>
491  RangeDistance(const VecType& point,
492  typename std::enable_if_t<IsVector<VecType>::value>* = 0) const
493  {
494  return bound.RangeDistance(point);
495  }
496 
498  size_t Begin() const { return begin; }
500  size_t& Begin() { return begin; }
501 
503  size_t Count() const { return count; }
505  size_t& Count() { return count; }
506 
508  void Center(arma::vec& center) const { bound.Center(center); }
509 
510  private:
517  void SplitNode(const size_t maxLeafSize,
518  SplitType<BoundType<MetricType>, MatType>& splitter);
519 
528  void SplitNode(std::vector<size_t>& oldFromNew,
529  const size_t maxLeafSize,
530  SplitType<BoundType<MetricType>, MatType>& splitter);
531 
538  template<typename BoundType2>
539  void UpdateBound(BoundType2& boundToUpdate);
540 
547  void UpdateBound(bound::HollowBallBound<MetricType>& boundToUpdate);
548 
549  protected:
556  BinarySpaceTree();
557 
559  friend class cereal::access;
560 
561  public:
565  template<typename Archive>
566  void serialize(Archive& ar, const uint32_t version);
567 };
568 
569 } // namespace tree
570 } // namespace mlpack
571 
572 // Include implementation.
573 #include "binary_space_tree_impl.hpp"
574 
575 // Include everything else, if necessary.
576 #include "../binary_space_tree.hpp"
577 
578 #endif
~BinarySpaceTree()
Deletes this node, deallocating the memory for the children and calling their destructors in turn...
const MatType & Dataset() const
Get the dataset which the tree is built on.
size_t NumPoints() const
Return the number of points in this node (0 if not a leaf).
BinarySpaceTree *& Left()
Modify the left child of this node.
void Center(arma::vec &center) const
Store the center of the bounding region in the given vector.
typename enable_if< B, T >::type enable_if_t
Definition: prereqs.hpp:70
size_t Descendant(const size_t index) const
Return the index (with reference to the dataset) of a particular descendant of this node...
A dual-tree traverser for binary space trees; see dual_tree_traverser.hpp.
Linear algebra utility functions, generally performed on matrices or vectors.
BinarySpaceTree * Right() const
Gets the right child of this node.
BinarySpaceTree *& ChildPtr(const size_t child)
BinarySpaceTree & Child(const size_t child) const
Return the specified child (0 will be left, 1 will be right).
bool IsLeaf() const
Return whether or not this node is a leaf (true if it has no children).
The core includes that mlpack expects; standard C++ includes and Armadillo.
BinarySpaceTree * Parent() const
Gets the parent of this node.
MatType Mat
So other classes can use TreeType::Mat.
ElemType MaxDistance(const BinarySpaceTree &other) const
Return the maximum distance to another node.
ElemType MinDistance(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0) const
Return the minimum distance to another point.
BinarySpaceTree *& Parent()
Modify the parent of this node.
A binary space partitioning tree, such as a KD-tree or a ball tree.
ElemType MaxDistance(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0) const
Return the maximum distance to another point.
A binary space partitioning tree node is split into its left and right children.
BinarySpaceTree()
A default constructor.
size_t NumDescendants() const
Return the number of descendants of this node.
BoundType< MetricType > & Bound()
Return the bound object for this node.
size_t NumChildren() const
Return the number of children in this node.
MetricType Metric() const
Get the metric that the tree uses.
BinarySpaceTree *& Right()
Modify the right child of this node.
Hyper-rectangle bound for an L-metric.
Definition: hrectbound.hpp:54
size_t GetNearestChild(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0)
Return the index of the nearest child node to the given query point.
size_t & Begin()
Modify the index of the beginning point of this subset.
MatType::elem_type ElemType
The type of element held in MatType.
MatType & Dataset()
Modify the dataset which the tree is built on. Be careful!
const BoundType< MetricType > & Bound() const
Return the bound object for this node.
ElemType MinDistance(const BinarySpaceTree &other) const
Return the minimum distance to another node.
size_t GetFurthestChild(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0)
Return the index of the furthest child node to the given query point.
A single-tree traverser for binary space trees; see single_tree_traverser.hpp for implementation...
BinarySpaceTree * Left() const
Gets the left child of this node.
SplitType< BoundType< MetricType >, MatType > Split
ElemType ParentDistance() const
Return the distance from the center of this node to the center of the parent node.
StatisticType & Stat()
Return the statistic object for this node.
math::RangeType< ElemType > RangeDistance(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0) const
Return the minimum and maximum distance to another point.
size_t Point(const size_t index) const
Return the index (with reference to the dataset) of a particular point in this node.
size_t & Count()
Modify the number of points in this subset.
size_t Begin() const
Return the index of the beginning point of this subset.
BinarySpaceTree & operator=(const BinarySpaceTree &other)
Copy the given BinarySpaceTree.
ElemType MinimumBoundDistance() const
Return the minimum distance from the center of the node to any bound edge.
ElemType & ParentDistance()
Modify the distance from the center of this node to the center of the parent node.
const StatisticType & Stat() const
Return the statistic object for this node.
Hollow ball bound encloses a set of points at a specific distance (radius) from a specific point (cen...
math::RangeType< ElemType > RangeDistance(const BinarySpaceTree &other) const
Return the minimum and maximum distance to another node.
If value == true, then VecType is some sort of Armadillo vector or subview.
Definition: arma_traits.hpp:35
ElemType FurthestPointDistance() const
Return the furthest distance to a point held in this node.
size_t Count() const
Return the number of points in this subset.
void serialize(Archive &ar, const uint32_t version)
Serialize the tree.
Empty statistic if you are not interested in storing statistics in your tree.
Definition: statistic.hpp:24
ElemType FurthestDescendantDistance() const
Return the furthest possible descendant distance.