dtb.hpp
Go to the documentation of this file.
1 
25 #ifndef MLPACK_METHODS_EMST_DTB_HPP
26 #define MLPACK_METHODS_EMST_DTB_HPP
27 
28 #include "dtb_stat.hpp"
29 #include "edge_pair.hpp"
30 
31 #include <mlpack/prereqs.hpp>
33 
35 
36 namespace mlpack {
37 namespace emst {
38 
// Class template that performs the minimum spanning tree (MST) calculation
// using the Dual-Tree Boruvka algorithm, using any type of tree.
// NOTE(review): the line that names the class (listing line 83) is missing
// from this extract; the closing brace below identifies it as DualTreeBoruvka.
//
// Template parameters:
//   MetricType - metric type; defaults to metric::EuclideanDistance.
//   MatType    - matrix type of the dataset; defaults to arma::mat.
//   TreeType   - tree class template parameterized on (metric, statistic,
//                matrix) types; defaults to tree::KDTree.
76 template<
77  typename MetricType = metric::EuclideanDistance,
78  typename MatType = arma::mat,
79  template<typename TreeMetricType,
80  typename TreeStatType,
81  typename TreeMatType> class TreeType = tree::KDTree
82 >
84 {
85  public:
    // Convenience typedef: TreeType instantiated with this class's metric
    // type, the DTBStat statistic, and this class's matrix type.
87  typedef TreeType<MetricType, DTBStat, MatType> Tree;
88 
89  private:
    // Index permutation; presumably maps tree-reordered point indices back to
    // the original ones -- TODO confirm against dtb_impl.hpp.
91  std::vector<size_t> oldFromNew;
    // Pointer to the tree this object works on.
93  Tree* tree;
    // Reference to the dataset (held by reference, not copied).
95  const MatType& data;
    // Presumably true when the tree was created inside this object and must
    // be deleted by it -- TODO confirm against dtb_impl.hpp.
97  bool ownTree;
98 
    // Presumably mirrors the constructor's `naive` argument; its exact effect
    // is defined in dtb_impl.hpp -- TODO confirm.
100  bool naive;
101 
    // Edge storage; an EdgePair is two indices and a distance.
103  std::vector<EdgePair> edges; // We must use vector with non-numerical types.
104 
    // Union-Find data structure; presumably tracks which component each point
    // currently belongs to -- TODO confirm.
106  UnionFind connections;
107 
    // Per-component candidate-neighbor bookkeeping. NOTE(review): from the
    // names these look like, for each component, the endpoint inside the
    // component, the endpoint outside it, and the distance between them --
    // confirm against dtb_impl.hpp.
109  arma::Col<size_t> neighborsInComponent;
111  arma::Col<size_t> neighborsOutComponent;
113  arma::vec neighborsDistances;
114 
    // Presumably the running total length of the edges added -- TODO confirm.
116  double totalDist;
117 
    // Instance of the metric type.
119  MetricType metric;
120 
    // Comparison functor for EdgePair objects: returns true when pairA's
    // Distance() is strictly less than pairB's. SortFun is the member
    // instance of this functor.
122  struct SortEdgesHelper
123  {
124  bool operator()(const EdgePair& pairA, const EdgePair& pairB)
125  {
126  return (pairA.Distance() < pairB.Distance());
127  }
128  } SortFun;
129 
130  public:
    // Create the tree from the given dataset.
    //   dataset - data to build the tree on.
    //   naive   - defaults to false; see dtb_impl.hpp for its effect.
    //   metric  - metric instance; defaults to a default-constructed MetricType.
139  DualTreeBoruvka(const MatType& dataset,
140  const bool naive = false,
141  const MetricType metric = MetricType());
142 
    // Construct from an already-built tree.
    //   tree   - pointer to the tree to use. NOTE(review): ownership presumably
    //            stays with the caller -- confirm against dtb_impl.hpp.
    //   metric - metric instance; defaults to a default-constructed MetricType.
159  DualTreeBoruvka(Tree* tree,
160  const MetricType metric = MetricType());
161 
    // NOTE(review): listing lines 162-165 are missing from this extract. The
    // generated index lists ~DualTreeBoruvka() ("Delete the tree, if it was
    // created inside the object"), which presumably is declared here.
166 
    // Iteratively find the nearest neighbor of each component until the MST
    // is complete.
    //   results - output matrix filled by this call; its layout is not shown
    //             in this header extract.
176  void ComputeMST(arma::mat& results);
177 
178  private:
    // Add a single edge between points e1 and e2 with the given distance
    // (presumably appended to `edges` -- TODO confirm).
182  void AddEdge(const size_t e1, const size_t e2, const double distance);
183 
    // Presumably adds the current best candidate edge of every component --
    // TODO confirm against dtb_impl.hpp.
187  void AddAllEdges();
188 
    // Writes the computed result into `results`; format defined in
    // dtb_impl.hpp.
192  void EmitResults(arma::mat& results);
193 
    // Helper for Cleanup() operating on the given tree node; presumably
    // recursive over the node's children -- TODO confirm.
198  void CleanupHelper(Tree* tree);
199 
    // Presumably resets per-iteration state between rounds of the algorithm --
    // TODO confirm against dtb_impl.hpp.
203  void Cleanup();
204 }; // class DualTreeBoruvka
205 
206 } // namespace emst
207 } // namespace mlpack
208 
209 #include "dtb_impl.hpp"
210 
211 #endif // MLPACK_METHODS_EMST_DTB_HPP
A Union-Find data structure.
Definition: union_find.hpp:30
An edge pair is simply two indices and a distance.
Definition: edge_pair.hpp:28
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
A binary space partitioning tree, such as a KD-tree or a ball tree.
DualTreeBoruvka(const MatType &dataset, const bool naive=false, const MetricType metric=MetricType())
Create the tree from the given dataset.
TreeType< MetricType, DTBStat, MatType > Tree
Convenience typedef.
Definition: dtb.hpp:87
double Distance() const
Get the distance.
Definition: edge_pair.hpp:63
void ComputeMST(arma::mat &results)
Iteratively find the nearest neighbor of each component until the MST is complete.
LMetric< 2, true > EuclideanDistance
The Euclidean (L2) distance.
Definition: lmetric.hpp:112
~DualTreeBoruvka()
Delete the tree, if it was created inside the object.
Performs the MST calculation using the Dual-Tree Boruvka algorithm, using any type of tree.
Definition: dtb.hpp:83