lsh_search.hpp
Go to the documentation of this file.
1 
47 #ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_HPP
48 #define MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_HPP
49 
50 #include <mlpack/prereqs.hpp>
51 
54 
55 #include <queue>
56 
57 namespace mlpack {
58 namespace neighbor {
59 
68 template<
69  typename SortPolicy = NearestNeighborSort,
70  typename MatType = arma::mat
71 >
72 class LSHSearch
73 {
74  public:
97  LSHSearch(MatType referenceSet,
98  const arma::cube& projections,
99  const double hashWidth = 0.0,
100  const size_t secondHashSize = 99901,
101  const size_t bucketSize = 500);
102 
125  LSHSearch(MatType referenceSet,
126  const size_t numProj,
127  const size_t numTables,
128  const double hashWidth = 0.0,
129  const size_t secondHashSize = 99901,
130  const size_t bucketSize = 500);
131 
136  LSHSearch();
137 
143  LSHSearch(const LSHSearch& other);
144 
150  LSHSearch(LSHSearch&& other);
151 
157  LSHSearch& operator=(const LSHSearch& other);
158 
164  LSHSearch& operator=(LSHSearch&& other);
165 
191  void Train(MatType referenceSet,
192  const size_t numProj,
193  const size_t numTables,
194  const double hashWidth = 0.0,
195  const size_t secondHashSize = 99901,
196  const size_t bucketSize = 500,
197  const arma::cube& projection = arma::cube());
198 
220  void Search(const MatType& querySet,
221  const size_t k,
222  arma::Mat<size_t>& resultingNeighbors,
223  arma::mat& distances,
224  const size_t numTablesToSearch = 0,
225  const size_t T = 0);
226 
246  void Search(const size_t k,
247  arma::Mat<size_t>& resultingNeighbors,
248  arma::mat& distances,
249  const size_t numTablesToSearch = 0,
250  size_t T = 0);
251 
261  static double ComputeRecall(const arma::Mat<size_t>& foundNeighbors,
262  const arma::Mat<size_t>& realNeighbors);
263 
270  template<typename Archive>
271  void serialize(Archive& ar, const uint32_t version);
272 
274  size_t DistanceEvaluations() const { return distanceEvaluations; }
276  size_t& DistanceEvaluations() { return distanceEvaluations; }
277 
279  const MatType& ReferenceSet() const { return referenceSet; }
280 
282  size_t NumProjections() const { return projections.n_slices; }
283 
285  const arma::mat& Offsets() const { return offsets; }
286 
288  const arma::vec& SecondHashWeights() const { return secondHashWeights; }
289 
291  size_t BucketSize() const { return bucketSize; }
292 
294  const std::vector<arma::Col<size_t>>& SecondHashTable() const
295  { return secondHashTable; }
296 
298  const arma::cube& Projections() { return projections; }
299 
301  void Projections(const arma::cube& projTables)
302  {
303  // Simply call Train() with the given projection tables.
304  Train(referenceSet, numProj, numTables, hashWidth, secondHashSize,
305  bucketSize, projTables);
306  }
307 
308  private:
324  template<typename VecType>
325  void ReturnIndicesFromTable(const VecType& queryPoint,
326  arma::uvec& referenceIndices,
327  size_t numTablesToSearch,
328  const size_t T) const;
329 
343  void BaseCase(const size_t queryIndex,
344  const arma::uvec& referenceIndices,
345  const size_t k,
346  arma::Mat<size_t>& neighbors,
347  arma::mat& distances) const;
348 
363  void BaseCase(const size_t queryIndex,
364  const arma::uvec& referenceIndices,
365  const size_t k,
366  const MatType& querySet,
367  arma::Mat<size_t>& neighbors,
368  arma::mat& distances) const;
369 
384  void GetAdditionalProbingBins(const arma::vec& queryCode,
385  const arma::vec& queryCodeNotFloored,
386  const size_t T,
387  arma::mat& additionalProbingBins) const;
388 
396  double PerturbationScore(const std::vector<bool>& A,
397  const arma::vec& scores) const;
398 
406  bool PerturbationShift(std::vector<bool>& A) const;
407 
416  bool PerturbationExpand(std::vector<bool>& A) const;
417 
425  bool PerturbationValid(const std::vector<bool>& A) const;
426 
428  MatType referenceSet;
429 
431  size_t numProj;
433  size_t numTables;
434 
436  arma::cube projections; // should be [numProj x dims] x numTables slices
437 
439  arma::mat offsets; // should be numProj x numTables
440 
442  double hashWidth;
443 
445  size_t secondHashSize;
446 
448  arma::vec secondHashWeights;
449 
451  size_t bucketSize;
452 
455  std::vector<arma::Col<size_t>> secondHashTable;
456 
459  arma::Col<size_t> bucketContentSize;
460 
463  arma::Col<size_t> bucketRowInHashTable;
464 
466  size_t distanceEvaluations;
467 
469  typedef std::pair<double, size_t> Candidate;
470 
472  struct CandidateCmp {
473  bool operator()(const Candidate& c1, const Candidate& c2)
474  {
475  return !SortPolicy::IsBetter(c2.first, c1.first);
476  };
477  };
478 
480  typedef std::priority_queue<Candidate, std::vector<Candidate>, CandidateCmp>
481  CandidateList;
482 }; // class LSHSearch
483 
484 } // namespace neighbor
485 } // namespace mlpack
486 
487 // Include implementation.
488 #include "lsh_search_impl.hpp"
489 
490 #endif
const std::vector< arma::Col< size_t > > & SecondHashTable() const
Get the second hash table.
Definition: lsh_search.hpp:294
const arma::mat & Offsets() const
Get the offsets 'b' for each of the projections. (One 'b' per column.)
Definition: lsh_search.hpp:285
Linear algebra utility functions, generally performed on matrices or vectors.
size_t DistanceEvaluations() const
Return the number of distance evaluations performed.
Definition: lsh_search.hpp:274
void serialize(Archive &ar, const uint32_t version)
Serialize the LSH model.
const arma::cube & Projections()
Get the projection tables.
Definition: lsh_search.hpp:298
The core includes that mlpack expects; standard C++ includes and Armadillo.
LSHSearch()
Create an untrained LSH model.
The LSHSearch class; this class builds a hash on the reference set and uses this hash to compute the ...
Definition: lsh_search.hpp:72
size_t NumProjections() const
Get the number of projections.
Definition: lsh_search.hpp:282
void Search(const MatType &querySet, const size_t k, arma::Mat< size_t > &resultingNeighbors, arma::mat &distances, const size_t numTablesToSearch=0, const size_t T=0)
Compute the nearest neighbors of the points in the given query set and store the output in the given ...
size_t BucketSize() const
Get the bucket size of the second hash.
Definition: lsh_search.hpp:291
static double ComputeRecall(const arma::Mat< size_t > &foundNeighbors, const arma::Mat< size_t > &realNeighbors)
Compute the recall (% of neighbors found) given the neighbors returned by LSHSearch::Search and a "gr...
void Projections(const arma::cube &projTables)
Change the projection tables (this retrains the LSH model).
Definition: lsh_search.hpp:301
size_t & DistanceEvaluations()
Modify the number of distance evaluations performed.
Definition: lsh_search.hpp:276
const MatType & ReferenceSet() const
Return the reference dataset.
Definition: lsh_search.hpp:279
LSHSearch & operator=(const LSHSearch &other)
Copy the given LSH model.
void Train(MatType referenceSet, const size_t numProj, const size_t numTables, const double hashWidth=0.0, const size_t secondHashSize=99901, const size_t bucketSize=500, const arma::cube &projection=arma::cube())
Train the LSH model on the given dataset.
const arma::vec & SecondHashWeights() const
Get the weights of the second hash.
Definition: lsh_search.hpp:288