cosine_search.hpp
#ifndef MLPACK_METHODS_CF_COSINE_SEARCH_HPP
#define MLPACK_METHODS_CF_COSINE_SEARCH_HPP

#include <mlpack/prereqs.hpp>
#include <mlpack/methods/neighbor_search/neighbor_search.hpp>

namespace mlpack {
namespace cf {

// Nearest neighbor search with cosine distance.
class CosineSearch
{
 public:
  // Constructor with reference set. Each column of referenceSet is one point.
  CosineSearch(const arma::mat& referenceSet)
  {
    // Normalize all vectors (columns) to unit length.
    arma::mat normalizedSet = arma::normalise(referenceSet, 2, 0);

    neighborSearch.Train(std::move(normalizedSet));
  }

  // Given a set of query points, find the nearest k neighbors, and return
  // similarities.
  void Search(const arma::mat& query, const size_t k,
              arma::Mat<size_t>& neighbors, arma::mat& similarities)
  {
    // Normalize query vectors (columns) to unit length.
    arma::mat normalizedQuery = arma::normalise(query, 2, 0);

    neighborSearch.Search(normalizedQuery, k, neighbors, similarities);

    // At this point `similarities` holds the Euclidean distances returned by
    // neighborSearch.Search(). For unit vectors a and b,
    //   cos(a, b) = 1 - dis(a, b)^2 / 2,
    // where dis(a, b) is the Euclidean distance. We then rescale the
    // similarity to the range [0, 1]:
    //   similarity = (cos(a, b) + 1) / 2 = 1 - dis(a, b)^2 / 4.
    similarities = 1 - arma::pow(similarities, 2) / 4.0;
  }

 private:
  // Euclidean k-nearest-neighbor search over the normalized reference set.
  neighbor::KNN neighborSearch;
};

} // namespace cf
} // namespace mlpack

#endif
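The distance-to-similarity conversion in Search() can be checked numerically without mlpack or Armadillo. The following is a minimal sketch using only the standard library; `Normalize`, `DirectSimilarity`, and `DistanceSimilarity` are hypothetical helper names introduced here for illustration and are not part of mlpack. It computes the rescaled similarity once directly from the cosine and once from the Euclidean distance between the unit vectors, so the two results can be compared.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper: scale v to unit Euclidean length (assumes v is nonzero).
std::vector<double> Normalize(const std::vector<double>& v)
{
  double norm = 0.0;
  for (double x : v)
    norm += x * x;
  norm = std::sqrt(norm);

  std::vector<double> out(v);
  for (double& x : out)
    x /= norm;
  return out;
}

// Hypothetical helper: rescaled similarity computed directly as (cos + 1) / 2.
double DirectSimilarity(const std::vector<double>& a,
                        const std::vector<double>& b)
{
  assert(a.size() == b.size());
  const std::vector<double> ua = Normalize(a), ub = Normalize(b);

  double dot = 0.0;
  for (std::size_t i = 0; i < ua.size(); ++i)
    dot += ua[i] * ub[i];
  return (dot + 1.0) / 2.0;
}

// Hypothetical helper: the same quantity obtained from the Euclidean distance
// d between the unit vectors, using 1 - d^2 / 4.
double DistanceSimilarity(const std::vector<double>& a,
                          const std::vector<double>& b)
{
  assert(a.size() == b.size());
  const std::vector<double> ua = Normalize(a), ub = Normalize(b);

  double squared = 0.0;
  for (std::size_t i = 0; i < ua.size(); ++i)
  {
    const double diff = ua[i] - ub[i];
    squared += diff * diff;
  }
  const double d = std::sqrt(squared);
  return 1.0 - d * d / 4.0;
}
```

For a = (3, 4) and b = (4, 3) the cosine is 24 / 25 = 0.96, so both functions should return 0.98 up to rounding. Two opposite vectors give 0 and two parallel vectors give 1, the ends of the [0, 1] range.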
Nearest neighbor search with cosine distance.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
The NeighborSearch class is a template class for performing distance-based neighbor searches...
CosineSearch(const arma::mat &referenceSet)
Constructor with reference set.
void Search(const arma::mat &query, const size_t k, arma::Mat< size_t > &neighbors, arma::mat &similarities)
Given a set of query points, find the k nearest neighbors of each point and return their similarities, rescaled to the range [0, 1].
void Search(const MatType &querySet, const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances)
For each point in the query set, compute the nearest neighbors and store the output in the given matr...
void Train(MatType referenceSet)
Set the reference set to a new reference set, and build a tree if necessary.
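To illustrate what CosineSearch::Search() is expected to return for a single query point, here is a brute-force sketch that needs only the standard library. It is not how mlpack performs the search (neighbor::KNN may build a tree over the normalized reference set); `Point`, `Cosine`, and `BruteForceCosineSearch` are hypothetical names used only in this example. Ranking by descending cosine is equivalent to ranking unit vectors by ascending Euclidean distance, so the ordering should match.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

using Point = std::vector<double>;
using Scored = std::pair<std::size_t, double>;  // (reference index, similarity)

// Hypothetical helper: cosine of the angle between a and b (both nonzero).
double Cosine(const Point& a, const Point& b)
{
  assert(a.size() == b.size());
  double dot = 0.0, normA = 0.0, normB = 0.0;
  for (std::size_t i = 0; i < a.size(); ++i)
  {
    dot += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }
  return dot / (std::sqrt(normA) * std::sqrt(normB));
}

// Hypothetical brute-force stand-in for a cosine k-nearest-neighbor search.
// Returns up to k (index, similarity) pairs, most similar first, where
// similarity = (cos + 1) / 2 lies in [0, 1].
std::vector<Scored> BruteForceCosineSearch(const std::vector<Point>& reference,
                                           const Point& query,
                                           std::size_t k)
{
  std::vector<Scored> scored;
  for (std::size_t i = 0; i < reference.size(); ++i)
    scored.emplace_back(i, (Cosine(reference[i], query) + 1.0) / 2.0);

  // Only the first k entries need to be in sorted order.
  k = std::min(k, scored.size());
  std::partial_sort(scored.begin(),
                    scored.begin() + static_cast<std::ptrdiff_t>(k),
                    scored.end(),
                    [](const Scored& l, const Scored& r)
                    { return l.second > r.second; });
  scored.resize(k);
  return scored;
}
```

With reference points (1, 0), (0, 1), (-1, 0) and query (2, 0), asking for k = 2 should yield index 0 with similarity 1 followed by index 1 with similarity 0.5. The query's length does not matter because only the angle enters the cosine.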