cover_tree.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_HPP
13 #define MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_HPP
14 
15 #include <mlpack/prereqs.hpp>
17 
18 #include "../statistic.hpp"
19 #include "first_point_is_root.hpp"
20 
21 namespace mlpack {
22 namespace tree {
23 
95 template<typename MetricType = metric::LMetric<2, true>,
96  typename StatisticType = EmptyStatistic,
97  typename MatType = arma::mat,
98  typename RootPointPolicy = FirstPointIsRoot>
99 class CoverTree
100 {
101  public:
103  typedef MatType Mat;
105  typedef typename MatType::elem_type ElemType;
106 
118  CoverTree(const MatType& dataset,
119  const ElemType base = 2.0,
120  MetricType* metric = NULL);
121 
131  CoverTree(const MatType& dataset,
132  MetricType& metric,
133  const ElemType base = 2.0);
134 
142  CoverTree(MatType&& dataset,
143  const ElemType base = 2.0);
144 
153  CoverTree(MatType&& dataset,
154  MetricType& metric,
155  const ElemType base = 2.0);
156 
189  CoverTree(const MatType& dataset,
190  const ElemType base,
191  const size_t pointIndex,
192  const int scale,
193  CoverTree* parent,
194  const ElemType parentDistance,
195  arma::Col<size_t>& indices,
196  arma::vec& distances,
197  size_t nearSetSize,
198  size_t& farSetSize,
199  size_t& usedSetSize,
200  MetricType& metric = NULL);
201 
218  CoverTree(const MatType& dataset,
219  const ElemType base,
220  const size_t pointIndex,
221  const int scale,
222  CoverTree* parent,
223  const ElemType parentDistance,
224  const ElemType furthestDescendantDistance,
225  MetricType* metric = NULL);
226 
233  CoverTree(const CoverTree& other);
234 
241  CoverTree(CoverTree&& other);
242 
248  CoverTree& operator=(const CoverTree& other);
249 
255  CoverTree& operator=(CoverTree&& other);
256 
260  template<typename Archive>
261  CoverTree(
262  Archive& ar,
263  const typename std::enable_if_t<cereal::is_loading<Archive>()>* = 0);
264 
268  ~CoverTree();
269 
272  template<typename RuleType>
274 
276  template<typename RuleType>
278 
279  template<typename RuleType>
281 
283  const MatType& Dataset() const { return *dataset; }
284 
286  size_t Point() const { return point; }
288  size_t Point(const size_t) const { return point; }
289 
290  bool IsLeaf() const { return (children.size() == 0); }
291  size_t NumPoints() const { return 1; }
292 
294  const CoverTree& Child(const size_t index) const { return *children[index]; }
296  CoverTree& Child(const size_t index) { return *children[index]; }
297 
298  CoverTree*& ChildPtr(const size_t index) { return children[index]; }
299 
301  size_t NumChildren() const { return children.size(); }
302 
304  const std::vector<CoverTree*>& Children() const { return children; }
306  std::vector<CoverTree*>& Children() { return children; }
307 
309  size_t NumDescendants() const;
310 
312  size_t Descendant(const size_t index) const;
313 
315  int Scale() const { return scale; }
317  int& Scale() { return scale; }
318 
320  ElemType Base() const { return base; }
322  ElemType& Base() { return base; }
323 
325  const StatisticType& Stat() const { return stat; }
327  StatisticType& Stat() { return stat; }
328 
333  template<typename VecType>
334  size_t GetNearestChild(
335  const VecType& point,
337 
342  template<typename VecType>
343  size_t GetFurthestChild(
344  const VecType& point,
346 
351  size_t GetNearestChild(const CoverTree& queryNode);
352 
357  size_t GetFurthestChild(const CoverTree& queryNode);
358 
360  ElemType MinDistance(const CoverTree& other) const;
361 
364  ElemType MinDistance(const CoverTree& other, const ElemType distance) const;
365 
367  ElemType MinDistance(const arma::vec& other) const;
368 
371  ElemType MinDistance(const arma::vec& other, const ElemType distance) const;
372 
374  ElemType MaxDistance(const CoverTree& other) const;
375 
378  ElemType MaxDistance(const CoverTree& other, const ElemType distance) const;
379 
381  ElemType MaxDistance(const arma::vec& other) const;
382 
385  ElemType MaxDistance(const arma::vec& other, const ElemType distance) const;
386 
389 
393  const ElemType distance) const;
394 
396  math::RangeType<ElemType> RangeDistance(const arma::vec& other) const;
397 
400  math::RangeType<ElemType> RangeDistance(const arma::vec& other,
401  const ElemType distance) const;
402 
404  CoverTree* Parent() const { return parent; }
406  CoverTree*& Parent() { return parent; }
407 
409  ElemType ParentDistance() const { return parentDistance; }
411  ElemType& ParentDistance() { return parentDistance; }
412 
414  ElemType FurthestPointDistance() const { return 0.0; }
415 
417  ElemType FurthestDescendantDistance() const
418  { return furthestDescendantDistance; }
421  ElemType& FurthestDescendantDistance() { return furthestDescendantDistance; }
422 
425  ElemType MinimumBoundDistance() const { return furthestDescendantDistance; }
426 
428  void Center(arma::vec& center) const
429  {
430  center = arma::vec(dataset->col(point));
431  }
432 
434  MetricType& Metric() const { return *metric; }
435 
436  private:
438  const MatType* dataset;
440  size_t point;
442  std::vector<CoverTree*> children;
444  int scale;
446  ElemType base;
448  StatisticType stat;
450  size_t numDescendants;
452  CoverTree* parent;
454  ElemType parentDistance;
456  ElemType furthestDescendantDistance;
458  bool localMetric;
460  bool localDataset;
462  MetricType* metric;
463 
467  void CreateChildren(arma::Col<size_t>& indices,
468  arma::vec& distances,
469  size_t nearSetSize,
470  size_t& farSetSize,
471  size_t& usedSetSize);
472 
484  void ComputeDistances(const size_t pointIndex,
485  const arma::Col<size_t>& indices,
486  arma::vec& distances,
487  const size_t pointSetSize);
502  size_t SplitNearFar(arma::Col<size_t>& indices,
503  arma::vec& distances,
504  const ElemType bound,
505  const size_t pointSetSize);
506 
526  size_t SortPointSet(arma::Col<size_t>& indices,
527  arma::vec& distances,
528  const size_t childFarSetSize,
529  const size_t childUsedSetSize,
530  const size_t farSetSize);
531 
532  void MoveToUsedSet(arma::Col<size_t>& indices,
533  arma::vec& distances,
534  size_t& nearSetSize,
535  size_t& farSetSize,
536  size_t& usedSetSize,
537  arma::Col<size_t>& childIndices,
538  const size_t childFarSetSize,
539  const size_t childUsedSetSize);
540  size_t PruneFarSet(arma::Col<size_t>& indices,
541  arma::vec& distances,
542  const ElemType bound,
543  const size_t nearSetSize,
544  const size_t pointSetSize);
545 
550  void RemoveNewImplicitNodes();
551 
552  protected:
559  CoverTree();
560 
562  friend class cereal::access;
563 
564  public:
568  template<typename Archive>
569  void serialize(Archive& ar, const uint32_t /* version */);
570 
571  size_t DistanceComps() const { return distanceComps; }
572  size_t& DistanceComps() { return distanceComps; }
573 
574  private:
575  size_t distanceComps;
576 };
577 
578 } // namespace tree
579 } // namespace mlpack
580 
581 // Include implementation.
582 #include "cover_tree_impl.hpp"
583 
584 // Include the rest of the pieces, if necessary.
585 #include "../cover_tree.hpp"
586 
587 #endif
void serialize(Archive &ar, const uint32_t)
Serialize the tree.
size_t GetNearestChild(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0)
Return the index of the nearest child node to the given query point.
size_t DistanceComps() const
Definition: cover_tree.hpp:571
CoverTree & operator=(const CoverTree &other)
Copy the given Cover Tree.
size_t NumPoints() const
Definition: cover_tree.hpp:291
MatType Mat
So that other classes can access the matrix type.
Definition: cover_tree.hpp:103
void Center(arma::vec &center) const
Get the center of the node and store it in the given vector.
Definition: cover_tree.hpp:428
A dual-tree cover tree traverser; see dual_tree_traverser.hpp.
Definition: cover_tree.hpp:277
typename enable_if< B, T >::type enable_if_t
Definition: prereqs.hpp:70
ElemType Base() const
Get the base.
Definition: cover_tree.hpp:320
Linear algebra utility functions, generally performed on matrices or vectors.
size_t Point() const
Get the index of the point which this node represents.
Definition: cover_tree.hpp:286
MatType::elem_type ElemType
The type held by the matrix type.
Definition: cover_tree.hpp:105
ElemType MaxDistance(const CoverTree &other) const
Return the maximum distance to another node.
The core includes that mlpack expects; standard C++ includes and Armadillo.
size_t GetFurthestChild(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0)
Return the index of the furthest child node to the given query point.
ElemType & FurthestDescendantDistance()
Modify the distance from the center of the node to the furthest descendant.
Definition: cover_tree.hpp:421
const std::vector< CoverTree * > & Children() const
Get the children.
Definition: cover_tree.hpp:304
int & Scale()
Modify the scale of this node. Be careful...
Definition: cover_tree.hpp:317
StatisticType & Stat()
Modify the statistic for this node.
Definition: cover_tree.hpp:327
CoverTree()
A default constructor.
CoverTree *& Parent()
Modify the parent node.
Definition: cover_tree.hpp:406
CoverTree * Parent() const
Get the parent node.
Definition: cover_tree.hpp:404
std::vector< CoverTree * > & Children()
Modify the children manually (maybe not a great idea).
Definition: cover_tree.hpp:306
int Scale() const
Get the scale of this node.
Definition: cover_tree.hpp:315
~CoverTree()
Delete this cover tree node and its children.
const StatisticType & Stat() const
Get the statistic for this node.
Definition: cover_tree.hpp:325
CoverTree *& ChildPtr(const size_t index)
Definition: cover_tree.hpp:298
A single-tree cover tree traverser; see single_tree_traverser.hpp for implementation.
Definition: cover_tree.hpp:273
ElemType ParentDistance() const
Get the distance to the parent.
Definition: cover_tree.hpp:409
ElemType MinDistance(const CoverTree &other) const
Return the minimum distance to another node.
size_t Point(const size_t) const
For compatibility with other trees; the argument is ignored.
Definition: cover_tree.hpp:288
const MatType & Dataset() const
Get a reference to the dataset.
Definition: cover_tree.hpp:283
size_t NumChildren() const
Get the number of children.
Definition: cover_tree.hpp:301
ElemType FurthestPointDistance() const
Get the distance to the furthest point. This is always 0 for cover trees.
Definition: cover_tree.hpp:414
ElemType & Base()
Modify the base; don't do this, you'll break everything.
Definition: cover_tree.hpp:322
Definition of the Range class, which represents a simple range with a lower and upper bound...
CoverTree & Child(const size_t index)
Modify a particular child node.
Definition: cover_tree.hpp:296
ElemType MinimumBoundDistance() const
Get the minimum distance from the center to any bound edge (this is the same as furthestDescendantDistance).
Definition: cover_tree.hpp:425
MetricType & Metric() const
Get the instantiated metric.
Definition: cover_tree.hpp:434
size_t Descendant(const size_t index) const
Get the index of a particular descendant point.
ElemType FurthestDescendantDistance() const
Get the distance from the center of the node to the furthest descendant.
Definition: cover_tree.hpp:417
math::RangeType< ElemType > RangeDistance(const CoverTree &other) const
Return the minimum and maximum distance to another node.
A cover tree is a tree specifically designed to speed up nearest-neighbor computation in high-dimensional spaces.
Definition: cover_tree.hpp:99
ElemType & ParentDistance()
Modify the distance to the parent.
Definition: cover_tree.hpp:411
const CoverTree & Child(const size_t index) const
Get a particular child node.
Definition: cover_tree.hpp:294
If value == true, then VecType is some sort of Armadillo vector or subview.
Definition: arma_traits.hpp:35
size_t NumDescendants() const
Get the number of descendant points.