dual_tree_kmeans_statistic.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_KMEANS_DTNN_STATISTIC_HPP
13 #define MLPACK_METHODS_KMEANS_DTNN_STATISTIC_HPP
14 
16 
17 namespace mlpack {
18 namespace kmeans {
19 
21  neighbor::NeighborSearchStat<neighbor::NearestNeighborSort>
22 {
23  public:
25  neighbor::NeighborSearchStat<neighbor::NearestNeighborSort>(),
26  upperBound(DBL_MAX),
27  lowerBound(DBL_MAX),
28  owner(size_t(-1)),
29  pruned(size_t(-1)),
30  staticPruned(false),
31  staticUpperBoundMovement(0.0),
32  staticLowerBoundMovement(0.0),
33  centroid(),
34  trueParent(NULL)
35  {
36  // Nothing to do.
37  }
38 
39  template<typename TreeType>
40  DualTreeKMeansStatistic(TreeType& node) :
41  neighbor::NeighborSearchStat<neighbor::NearestNeighborSort>(),
42  upperBound(DBL_MAX),
43  lowerBound(DBL_MAX),
44  owner(size_t(-1)),
45  pruned(size_t(-1)),
46  staticPruned(false),
47  staticUpperBoundMovement(0.0),
48  staticLowerBoundMovement(0.0),
49  trueParent(node.Parent())
50  {
51  // Empirically calculate the centroid.
52  centroid.zeros(node.Dataset().n_rows);
53  for (size_t i = 0; i < node.NumPoints(); ++i)
54  {
55  // Correct handling of cover tree: don't double-count the point which
56  // appears in the children.
58  node.NumChildren() > 0)
59  continue;
60  centroid += node.Dataset().col(node.Point(i));
61  }
62 
63  for (size_t i = 0; i < node.NumChildren(); ++i)
64  centroid += node.Child(i).NumDescendants() *
65  node.Child(i).Stat().Centroid();
66 
67  centroid /= node.NumDescendants();
68 
69  // Set the true children correctly.
70  trueChildren.resize(node.NumChildren());
71  for (size_t i = 0; i < node.NumChildren(); ++i)
72  trueChildren[i] = &node.Child(i);
73  }
74 
75  double UpperBound() const { return upperBound; }
76  double& UpperBound() { return upperBound; }
77 
78  double LowerBound() const { return lowerBound; }
79  double& LowerBound() { return lowerBound; }
80 
81  const arma::vec& Centroid() const { return centroid; }
82  arma::vec& Centroid() { return centroid; }
83 
84  size_t Owner() const { return owner; }
85  size_t& Owner() { return owner; }
86 
87  size_t Pruned() const { return pruned; }
88  size_t& Pruned() { return pruned; }
89 
90  bool StaticPruned() const { return staticPruned; }
91  bool& StaticPruned() { return staticPruned; }
92 
93  double StaticUpperBoundMovement() const { return staticUpperBoundMovement; }
94  double& StaticUpperBoundMovement() { return staticUpperBoundMovement; }
95 
96  double StaticLowerBoundMovement() const { return staticLowerBoundMovement; }
97  double& StaticLowerBoundMovement() { return staticLowerBoundMovement; }
98 
99  void* TrueParent() const { return trueParent; }
100  void*& TrueParent() { return trueParent; }
101 
102  void* TrueChild(const size_t i) const { return trueChildren[i]; }
103  void*& TrueChild(const size_t i) { return trueChildren[i]; }
104 
105  size_t NumTrueChildren() const { return trueChildren.size(); }
106 
107  private:
108  double upperBound;
109  double lowerBound;
110  size_t owner;
111  size_t pruned;
112  bool staticPruned;
113  double staticUpperBoundMovement;
114  double staticLowerBoundMovement;
115  arma::vec centroid;
116  void* trueParent;
117  std::vector<void*> trueChildren;
118 };
119 
120 } // namespace kmeans
121 } // namespace mlpack
122 
123 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Extra data for each node in the tree.
The TreeTraits class provides compile-time information on the characteristics of a given tree type...
Definition: tree_traits.hpp:77
NeighborSearchStat()
Initialize the statistic with the worst possible distance according to our sorting policy...