dtb_rules.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_EMST_DTB_RULES_HPP
13 #define MLPACK_METHODS_EMST_DTB_RULES_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 
18 
19 namespace mlpack {
20 namespace emst {
21 
/**
 * Rule set exposing BaseCase() / Score() / Rescore() for both a
 * (query point, reference node) pair and a (query node, reference node) pair.
 *
 * The object stores references (not copies) to the dataset, the UnionFind
 * structure, the per-component neighbor bookkeeping vectors and the metric
 * that are passed to the constructor, plus counters for base cases and scores.
 *
 * NOTE(review): judging by the name and include guard this is presumably the
 * rule set for the dual-tree Boruvka EMST method -- confirm against
 * dtb_rules_impl.hpp, which holds the member definitions.
 *
 * @tparam MetricType Type of the metric object held by reference.
 * @tparam TreeType Type of the tree nodes passed to Score() and Rescore().
 */
22 template<typename MetricType, typename TreeType>
23 class DTBRules
24 {
25  public:
  /**
   * Construct the rule set. Every argument is taken by reference and has a
   * reference member of the same name in this class.
   *
   * @param dataSet Dataset matrix (read-only).
   * @param connections UnionFind structure.
   * @param neighborsDistances Vector of neighbor distances.
   * @param neighborsInComponent Vector of indices; see the member of the
   *     same name.
   * @param neighborsOutComponent Vector of indices; see the member of the
   *     same name.
   * @param metric Metric object.
   */
26  DTBRules(const arma::mat& dataSet,
27  UnionFind& connections,
28  arma::vec& neighborsDistances,
29  arma::Col<size_t>& neighborsInComponent,
30  arma::Col<size_t>& neighborsOutComponent,
31  MetricType& metric);
32 
  /**
   * Base case for a pair of point indices.
   *
   * @param queryIndex Index of the query point.
   * @param referenceIndex Index of the reference point.
   * @return A double; presumably the distance between the two points --
   *     TODO confirm in dtb_rules_impl.hpp.
   */
33  double BaseCase(const size_t queryIndex, const size_t referenceIndex);
34 
  /**
   * Get the score for recursion order (single query point version).
   *
   * @param queryIndex Index of the query point.
   * @param referenceNode Candidate reference node.
   */
43  double Score(const size_t queryIndex, TreeType& referenceNode);
44 
  /**
   * Re-evaluate the score for recursion order (single query point version).
   *
   * @param queryIndex Index of the query point.
   * @param referenceNode Candidate reference node.
   * @param oldScore Score previously obtained for this combination.
   */
56  double Rescore(const size_t queryIndex,
57  TreeType& referenceNode,
58  const double oldScore);
59 
  /**
   * Overload of Score() that takes a query node instead of a query index.
   *
   * @param queryNode Candidate query node.
   * @param referenceNode Candidate reference node.
   */
68  double Score(TreeType& queryNode, TreeType& referenceNode);
69 
  /**
   * Overload of Rescore() that takes a query node instead of a query index.
   *
   * NOTE(review): this overload is const-qualified while the single-point
   * Rescore() above is not; confirm whether the difference is intentional.
   *
   * @param queryNode Candidate query node.
   * @param referenceNode Candidate reference node.
   * @param oldScore Score previously obtained for this combination.
   */
81  double Rescore(TreeType& queryNode,
82  TreeType& referenceNode,
83  const double oldScore) const;
84 
  // NOTE(review): TraversalInfoType is used below but its declaration
  // (source line 85) is absent from this listing; the symbol index at the end
  // of this page gives it as tree::TraversalInfo< TreeType >.
86 
  //! Get the stored traversal information (read-only).
87  const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
  //! Modify the stored traversal information.
88  TraversalInfoType& TraversalInfo() { return traversalInfo; }
89 
  //! Get the number of base cases performed.
91  size_t BaseCases() const { return baseCases; }
  //! Modify the number of base cases performed.
93  size_t& BaseCases() { return baseCases; }
94 
  //! Get the number of node combinations that have been scored.
96  size_t Scores() const { return scores; }
  //! Modify the number of node combinations that have been scored.
98  size_t& Scores() { return scores; }
99 
100  private:
  //! Reference to the dataset matrix (read-only).
102  const arma::mat& dataSet;
103 
  //! Reference to the caller-owned UnionFind structure.
105  UnionFind& connections;
106 
  //! Reference to the caller-owned vector of neighbor distances.
108  arma::vec& neighborsDistances;
109 
  //! Reference to a caller-owned index vector; presumably the in-component
  //! endpoint of each candidate edge -- TODO confirm.
112  arma::Col<size_t>& neighborsInComponent;
113 
  //! Reference to a caller-owned index vector; presumably the out-of-component
  //! endpoint of each candidate edge -- TODO confirm.
116  arma::Col<size_t>& neighborsOutComponent;
117 
  //! Reference to the caller-owned metric object.
119  MetricType& metric;
120 
  //! Const helper returning a double bound for the given query node; the
  //! computation itself is defined outside this listing.
124  inline double CalculateBound(TreeType& queryNode) const;
125 
  //! Traversal information exposed through TraversalInfo().
126  TraversalInfoType traversalInfo;
127 
  //! Counter exposed through BaseCases().
129  size_t baseCases;
  //! Counter exposed through Scores().
131  size_t scores;
132 }; // class DTBRules
133 
134 } // namespace emst
135 } // namespace mlpack
136 
137 #include "dtb_rules_impl.hpp"
138 
139 #endif
A Union-Find data structure.
Definition: union_find.hpp:30
size_t BaseCases() const
Get the number of base cases performed.
Definition: dtb_rules.hpp:91
DTBRules(const arma::mat &dataSet, UnionFind &connections, arma::vec &neighborsDistances, arma::Col< size_t > &neighborsInComponent, arma::Col< size_t > &neighborsOutComponent, MetricType &metric)
The TraversalInfo class holds traversal information which is used in dual-tree (and single-tree) traversals.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
size_t & Scores()
Modify the number of node combinations that have been scored.
Definition: dtb_rules.hpp:98
const TraversalInfoType & TraversalInfo() const
Definition: dtb_rules.hpp:87
tree::TraversalInfo< TreeType > TraversalInfoType
Definition: dtb_rules.hpp:85
size_t Scores() const
Get the number of node combinations that have been scored.
Definition: dtb_rules.hpp:96
double Rescore(const size_t queryIndex, TreeType &referenceNode, const double oldScore)
Re-evaluate the score for recursion order.
TraversalInfoType & TraversalInfo()
Definition: dtb_rules.hpp:88
double Score(const size_t queryIndex, TreeType &referenceNode)
Get the score for recursion order.
size_t & BaseCases()
Modify the number of base cases performed.
Definition: dtb_rules.hpp:93
double BaseCase(const size_t queryIndex, const size_t referenceIndex)