ra_search_rules.hpp
Go to the documentation of this file.
1 
14 #ifndef MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP
15 #define MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP
16 
18 
19 #include <queue>
20 
21 namespace mlpack {
22 namespace neighbor {
23 
32 template<typename SortPolicy, typename MetricType, typename TreeType>
34 {
35  public:
57  RASearchRules(const arma::mat& referenceSet,
58  const arma::mat& querySet,
59  const size_t k,
60  MetricType& metric,
61  const double tau = 5,
62  const double alpha = 0.95,
63  const bool naive = false,
64  const bool sampleAtLeaves = false,
65  const bool firstLeafExact = false,
66  const size_t singleSampleLimit = 20,
67  const bool sameSet = false);
68 
76  void GetResults(arma::Mat<size_t>& neighbors, arma::mat& distances);
77 
85  double BaseCase(const size_t queryIndex, const size_t referenceIndex);
86 
109  double Score(const size_t queryIndex, TreeType& referenceNode);
110 
134  double Score(const size_t queryIndex,
135  TreeType& referenceNode,
136  const double baseCaseResult);
137 
155  double Rescore(const size_t queryIndex,
156  TreeType& referenceNode,
157  const double oldScore);
158 
177  double Score(TreeType& queryNode, TreeType& referenceNode);
178 
199  double Score(TreeType& queryNode,
200  TreeType& referenceNode,
201  const double baseCaseResult);
202 
225  double Rescore(TreeType& queryNode,
226  TreeType& referenceNode,
227  const double oldScore);
228 
229 
230  size_t NumDistComputations() { return numDistComputations; }
232  {
233  if (numSamplesMade.n_elem == 0)
234  return 0;
235  else
236  return arma::sum(numSamplesMade);
237  }
238 
240 
241  const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
242  TraversalInfoType& TraversalInfo() { return traversalInfo; }
243 
247  size_t MinimumBaseCases() const { return k; }
248 
249  private:
251  const arma::mat& referenceSet;
252 
254  const arma::mat& querySet;
255 
257  typedef std::pair<double, size_t> Candidate;
258 
260  struct CandidateCmp {
261  bool operator()(const Candidate& c1, const Candidate& c2)
262  {
263  return !SortPolicy::IsBetter(c2.first, c1.first);
264  };
265  };
266 
268  typedef std::priority_queue<Candidate, std::vector<Candidate>, CandidateCmp>
269  CandidateList;
270 
272  std::vector<CandidateList> candidates;
273 
275  const size_t k;
276 
278  MetricType& metric;
279 
281  bool sampleAtLeaves;
282 
284  bool firstLeafExact;
285 
287  size_t singleSampleLimit;
288 
290  size_t numSamplesReqd;
291 
293  arma::Col<size_t> numSamplesMade;
294 
296  double samplingRatio;
297 
299  size_t numDistComputations;
300 
302  bool sameSet;
303 
304  TraversalInfoType traversalInfo;
305 
313  void InsertNeighbor(const size_t queryIndex,
314  const size_t neighbor,
315  const double distance);
316 
320  double Score(const size_t queryIndex,
321  TreeType& referenceNode,
322  const double distance,
323  const double bestDistance);
324 
328  double Score(TreeType& queryNode,
329  TreeType& referenceNode,
330  const double distance,
331  const double bestDistance);
332 
333  static_assert(tree::TreeTraits<TreeType>::UniqueNumDescendants, "TreeType "
334  "must provide a unique number of descendants points.");
335 }; // class RASearchRules
336 
337 } // namespace neighbor
338 } // namespace mlpack
339 
340 // Include implementation.
341 #include "ra_search_rules_impl.hpp"
342 
343 #endif // MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP
RASearchRules(const arma::mat &referenceSet, const arma::mat &querySet, const size_t k, MetricType &metric, const double tau=5, const double alpha=0.95, const bool naive=false, const bool sampleAtLeaves=false, const bool firstLeafExact=false, const size_t singleSampleLimit=20, const bool sameSet=false)
Construct the RASearchRules object.
size_t MinimumBaseCases() const
Get the minimum number of base cases that must be performed for each query point for an acceptable re...
The TraversalInfo class holds traversal information which is used in dual-tree (and single-tree) trav...
Linear algebra utility functions, generally performed on matrices or vectors.
tree::TraversalInfo< TreeType > TraversalInfoType
const TraversalInfoType & TraversalInfo() const
double Rescore(const size_t queryIndex, TreeType &referenceNode, const double oldScore)
Re-evaluate the score for recursion order.
double Score(const size_t queryIndex, TreeType &referenceNode)
Get the score for recursion order.
See subsection cli_alt_reg_tut (Alternate DET regularization): the usual regularized error \f$R_\alpha(t)\f$ of a node \f$t\f$ is given by
Definition: det.txt:344
The TreeTraits class provides compile-time information on the characteristics of a given tree type...
Definition: tree_traits.hpp:77
TraversalInfoType & TraversalInfo()
double BaseCase(const size_t queryIndex, const size_t referenceIndex)
Get the distance from the query point to the reference point.
void GetResults(arma::Mat< size_t > &neighbors, arma::mat &distances)
Store the list of candidates for each query point in the given matrices.
The RASearchRules class is a template helper class used by RASearch class when performing rank-approx...