neighbor_search_rules.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
14 #define MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
15 
17 
18 #include <queue>
19 
20 namespace mlpack {
21 namespace neighbor {
22 
34 template<typename SortPolicy, typename MetricType, typename TreeType>
36 {
37  public:
50  NeighborSearchRules(const typename TreeType::Mat& referenceSet,
51  const typename TreeType::Mat& querySet,
52  const size_t k,
53  MetricType& metric,
54  const double epsilon = 0,
55  const bool sameSet = false);
56 
64  void GetResults(arma::Mat<size_t>& neighbors, arma::mat& distances);
65 
74  double BaseCase(const size_t queryIndex, const size_t referenceIndex);
75 
84  double Score(const size_t queryIndex, TreeType& referenceNode);
85 
92  size_t GetBestChild(const size_t queryIndex, TreeType& referenceNode);
93 
100  size_t GetBestChild(const TreeType& queryNode, TreeType& referenceNode);
101 
113  double Rescore(const size_t queryIndex,
114  TreeType& referenceNode,
115  const double oldScore) const;
116 
125  double Score(TreeType& queryNode, TreeType& referenceNode);
126 
138  double Rescore(TreeType& queryNode,
139  TreeType& referenceNode,
140  const double oldScore) const;
141 
143  size_t BaseCases() const { return baseCases; }
145  size_t& BaseCases() { return baseCases; }
146 
148  size_t Scores() const { return scores; }
150  size_t& Scores() { return scores; }
151 
154 
156  const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
158  TraversalInfoType& TraversalInfo() { return traversalInfo; }
159 
162  size_t MinimumBaseCases() const { return k; }
163 
164  protected:
166  const typename TreeType::Mat& referenceSet;
167 
169  const typename TreeType::Mat& querySet;
170 
172  typedef std::pair<double, size_t> Candidate;
173 
175  struct CandidateCmp {
176  bool operator()(const Candidate& c1, const Candidate& c2)
177  {
178  return !SortPolicy::IsBetter(c2.first, c1.first);
179  };
180  };
181 
183  typedef std::priority_queue<Candidate, std::vector<Candidate>, CandidateCmp>
185 
187  std::vector<CandidateList> candidates;
188 
190  const size_t k;
191 
193  MetricType& metric;
194 
196  bool sameSet;
197 
199  const double epsilon;
200 
206  double lastBaseCase;
207 
209  size_t baseCases;
211  size_t scores;
212 
215  TraversalInfoType traversalInfo;
216 
220  double CalculateBound(TreeType& queryNode) const;
221 
229  void InsertNeighbor(const size_t queryIndex,
230  const size_t neighbor,
231  const double distance);
232 };
233 
234 } // namespace neighbor
235 } // namespace mlpack
236 
237 // Include implementation.
238 #include "neighbor_search_rules_impl.hpp"
239 
240 #endif // MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
size_t & BaseCases()
Modify the number of base cases that have been performed.
NeighborSearchRules(const typename TreeType::Mat &referenceSet, const typename TreeType::Mat &querySet, const size_t k, MetricType &metric, const double epsilon=0, const bool sameSet=false)
Construct the NeighborSearchRules object.
The TraversalInfo class holds traversal information which is used in dual-tree (and single-tree) traversals.
size_t & Scores()
Modify the number of scores that have been performed.
size_t scores
The number of scores that have been performed.
Linear algebra utility functions, generally performed on matrices or vectors.
const size_t k
Number of neighbors to search for.
std::vector< CandidateList > candidates
Set of candidate neighbors for each point.
const TraversalInfoType & TraversalInfo() const
Get the traversal info.
void InsertNeighbor(const size_t queryIndex, const size_t neighbor, const double distance)
Helper function to insert a point into the list of candidate points.
double Score(const size_t queryIndex, TreeType &referenceNode)
Get the score for recursion order.
size_t lastReferenceIndex
The last reference point BaseCase() was called with.
const double epsilon
Relative error to be considered in approximate search.
const TreeType::Mat & querySet
The query set.
size_t lastQueryIndex
The last query point BaseCase() was called with.
tree::TraversalInfo< TreeType > TraversalInfoType
Convenience typedef.
const TreeType::Mat & referenceSet
The reference set.
void GetResults(arma::Mat< size_t > &neighbors, arma::mat &distances)
Store the list of candidates for each query point in the given matrices.
TraversalInfoType & TraversalInfo()
Modify the traversal info.
TraversalInfoType traversalInfo
Traversal info for the parent combination; this is updated by the traversal before each call to Score().
size_t BaseCases() const
Get the number of base cases that have been performed.
The NeighborSearchRules class is a template helper class used by the NeighborSearch class when performing distance-based neighbor searches.
size_t GetBestChild(const size_t queryIndex, TreeType &referenceNode)
Get the child node with the best score.
std::pair< double, size_t > Candidate
Candidate represents a possible candidate neighbor (distance, index).
size_t Scores() const
Get the number of scores that have been performed.
bool sameSet
Denotes whether or not the reference and query sets are the same.
double CalculateBound(TreeType &queryNode) const
Recalculate the bound for a given query node.
size_t baseCases
The number of base cases that have been performed.
MetricType & metric
The instantiated metric.
std::priority_queue< Candidate, std::vector< Candidate >, CandidateCmp > CandidateList
Use a priority queue to represent the list of candidate neighbors.
size_t MinimumBaseCases() const
Get the minimum number of base cases we need to perform to have acceptable results.
bool operator()(const Candidate &c1, const Candidate &c2)
double Rescore(const size_t queryIndex, TreeType &referenceNode, const double oldScore) const
Re-evaluate the score for recursion order.
Compare two candidates based on the distance.
double lastBaseCase
The last base case result.
double BaseCase(const size_t queryIndex, const size_t referenceIndex)
Get the distance from the query point to the reference point.