14 #ifndef MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP 15 #define MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP 32 template<
typename SortPolicy,
typename MetricType,
typename TreeType>
58 const arma::mat& querySet,
62 const double alpha = 0.95,
63 const bool naive =
false,
64 const bool sampleAtLeaves =
false,
65 const bool firstLeafExact =
false,
66 const size_t singleSampleLimit = 20,
67 const bool sameSet =
false);
76 void GetResults(arma::Mat<size_t>& neighbors, arma::mat& distances);
85 double BaseCase(
const size_t queryIndex,
const size_t referenceIndex);
109 double Score(
const size_t queryIndex, TreeType& referenceNode);
134 double Score(
const size_t queryIndex,
135 TreeType& referenceNode,
136 const double baseCaseResult);
155 double Rescore(
const size_t queryIndex,
156 TreeType& referenceNode,
157 const double oldScore);
177 double Score(TreeType& queryNode, TreeType& referenceNode);
199 double Score(TreeType& queryNode,
200 TreeType& referenceNode,
201 const double baseCaseResult);
225 double Rescore(TreeType& queryNode,
226 TreeType& referenceNode,
227 const double oldScore);
233 if (numSamplesMade.n_elem == 0)
236 return arma::sum(numSamplesMade);
251 const arma::mat& referenceSet;
254 const arma::mat& querySet;
257 typedef std::pair<double, size_t> Candidate;
260 struct CandidateCmp {
261 bool operator()(
const Candidate& c1,
const Candidate& c2)
263 return !SortPolicy::IsBetter(c2.first, c1.first);
268 typedef std::priority_queue<Candidate, std::vector<Candidate>, CandidateCmp>
272 std::vector<CandidateList> candidates;
287 size_t singleSampleLimit;
290 size_t numSamplesReqd;
293 arma::Col<size_t> numSamplesMade;
296 double samplingRatio;
299 size_t numDistComputations;
304 TraversalInfoType traversalInfo;
313 void InsertNeighbor(
const size_t queryIndex,
314 const size_t neighbor,
315 const double distance);
320 double Score(
const size_t queryIndex,
321 TreeType& referenceNode,
322 const double distance,
323 const double bestDistance);
328 double Score(TreeType& queryNode,
329 TreeType& referenceNode,
330 const double distance,
331 const double bestDistance);
334 "must provide a unique number of descendants points.");
341 #include "ra_search_rules_impl.hpp" 343 #endif // MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP RASearchRules(const arma::mat &referenceSet, const arma::mat &querySet, const size_t k, MetricType &metric, const double tau=5, const double alpha=0.95, const bool naive=false, const bool sampleAtLeaves=false, const bool firstLeafExact=false, const size_t singleSampleLimit=20, const bool sameSet=false)
Construct the RASearchRules object.
size_t MinimumBaseCases() const
Get the minimum number of base cases that must be performed for each query point for an acceptable result.
The TraversalInfo class holds traversal information which is used in dual-tree (and single-tree) traversals.
Linear algebra utility functions, generally performed on matrices or vectors.
tree::TraversalInfo< TreeType > TraversalInfoType
const TraversalInfoType & TraversalInfo() const
double Rescore(const size_t queryIndex, TreeType &referenceNode, const double oldScore)
Re-evaluate the score for recursion order.
double Score(const size_t queryIndex, TreeType &referenceNode)
Get the score for recursion order.
size_t NumDistComputations()
size_t NumEffectiveSamples()
See subsection cli_alt_reg_tut (Alternate DET regularization): the usual regularized error \f$R_\alpha(t)\f$ of a node \f$t\f$ is given by ...
The TreeTraits class provides compile-time information on the characteristics of a given tree type...
TraversalInfoType & TraversalInfo()
double BaseCase(const size_t queryIndex, const size_t referenceIndex)
Get the distance from the query point to the reference point.
void GetResults(arma::Mat< size_t > &neighbors, arma::mat &distances)
Store the list of candidates for each query point in the given matrices.
The RASearchRules class is a template helper class used by the RASearch class when performing rank-approximate search.