ns_model.hpp
Go to the documentation of this file.
1 
16 #ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_NS_MODEL_HPP
17 #define MLPACK_METHODS_NEIGHBOR_SEARCH_NS_MODEL_HPP
18 
24 #include "neighbor_search.hpp"
25 
26 namespace mlpack {
27 namespace neighbor {
28 
36 {
37  public:
41 
44  virtual NSWrapperBase* Clone() const = 0;
45 
47  virtual ~NSWrapperBase() { }
48 
50  virtual const arma::mat& Dataset() const = 0;
51 
53  virtual NeighborSearchMode SearchMode() const = 0;
55  virtual NeighborSearchMode& SearchMode() = 0;
56 
58  virtual double Epsilon() const = 0;
60  virtual double& Epsilon() = 0;
61 
63  virtual void Train(util::Timers& timers,
64  arma::mat&& referenceSet,
65  const size_t leafSize,
66  const double tau,
67  const double rho) = 0;
68 
71  virtual void Search(util::Timers& timers,
72  arma::mat&& querySet,
73  const size_t k,
74  arma::Mat<size_t>& neighbors,
75  arma::mat& distances,
76  const size_t leafSize,
77  const double rho) = 0;
78 
81  virtual void Search(util::Timers& timers,
82  const size_t k,
83  arma::Mat<size_t>& neighbors,
84  arma::mat& distances) = 0;
85 };
86 
90 template<typename SortPolicy,
91  template<typename TreeMetricType,
92  typename TreeStatType,
93  typename TreeMatType> class TreeType,
94  template<typename RuleType> class DualTreeTraversalType =
97  arma::mat>::template DualTreeTraverser,
98  template<typename RuleType> class SingleTreeTraversalType =
100  NeighborSearchStat<SortPolicy>,
101  arma::mat>::template SingleTreeTraverser>
102 class NSWrapper : public NSWrapperBase
103 {
104  public:
107  NSWrapper(const NeighborSearchMode searchMode,
108  const double epsilon) :
109  ns(searchMode, epsilon)
110  {
111  // Nothing else to do.
112  }
113 
115  virtual ~NSWrapper() { }
116 
119  virtual NSWrapper* Clone() const { return new NSWrapper(*this); }
120 
122  const arma::mat& Dataset() const { return ns.ReferenceSet(); }
123 
125  NeighborSearchMode SearchMode() const { return ns.SearchMode(); }
127  NeighborSearchMode& SearchMode() { return ns.SearchMode(); }
128 
130  double Epsilon() const { return ns.Epsilon(); }
132  double& Epsilon() { return ns.Epsilon(); }
133 
136  virtual void Train(util::Timers& timers,
137  arma::mat&& referenceSet,
138  const size_t /* leafSize */,
139  const double /* tau */,
140  const double /* rho */);
141 
144  virtual void Search(util::Timers& timers,
145  arma::mat&& querySet,
146  const size_t k,
147  arma::Mat<size_t>& neighbors,
148  arma::mat& distances,
149  const size_t /* leafSize */,
150  const double /* rho */);
151 
154  virtual void Search(util::Timers& timers,
155  const size_t k,
156  arma::Mat<size_t>& neighbors,
157  arma::mat& distances);
158 
160  template<typename Archive>
161  void serialize(Archive& ar, const uint32_t /* version */)
162  {
163  ar(CEREAL_NVP(ns));
164  }
165 
166  protected:
167  // Convenience typedef for the neighbor search type held by this class.
168  typedef NeighborSearch<SortPolicy,
170  arma::mat,
171  TreeType,
172  DualTreeTraversalType,
173  SingleTreeTraversalType> NSType;
174 
176  NSType ns;
177 };
178 
184 template<typename SortPolicy,
185  template<typename TreeMetricType,
186  typename TreeStatType,
187  typename TreeMatType> class TreeType,
188  template<typename RuleType> class DualTreeTraversalType =
189  TreeType<metric::EuclideanDistance,
190  NeighborSearchStat<SortPolicy>,
191  arma::mat>::template DualTreeTraverser,
192  template<typename RuleType> class SingleTreeTraversalType =
193  TreeType<metric::EuclideanDistance,
194  NeighborSearchStat<SortPolicy>,
195  arma::mat>::template SingleTreeTraverser>
196 class LeafSizeNSWrapper :
197  public NSWrapper<SortPolicy,
198  TreeType,
199  DualTreeTraversalType,
200  SingleTreeTraversalType>
201 {
202  public:
206  const double epsilon) :
207  NSWrapper<SortPolicy,
208  TreeType,
209  DualTreeTraversalType,
210  SingleTreeTraversalType>(searchMode, epsilon)
211  {
212  // Nothing to do.
213  }
214 
216  virtual ~LeafSizeNSWrapper() { }
217 
219  virtual LeafSizeNSWrapper* Clone() const
220  {
221  return new LeafSizeNSWrapper(*this);
222  }
223 
226  virtual void Train(util::Timers& timers,
227  arma::mat&& referenceSet,
228  const size_t leafSize,
229  const double /* tau */,
230  const double /* rho */);
231 
234  virtual void Search(util::Timers& timers,
235  arma::mat&& querySet,
236  const size_t k,
237  arma::Mat<size_t>& neighbors,
238  arma::mat& distances,
239  const size_t leafSize,
240  const double /* rho */);
241 
243  template<typename Archive>
244  void serialize(Archive& ar, const uint32_t /* version */)
245  {
246  ar(CEREAL_NVP(ns));
247  }
248 
249  protected:
250  using NSWrapper<SortPolicy,
251  TreeType,
252  DualTreeTraversalType,
253  SingleTreeTraversalType>::ns;
254 };
255 
260 template<typename SortPolicy>
262  public NSWrapper<
263  SortPolicy,
264  tree::SPTree,
265  tree::SPTree<metric::EuclideanDistance,
266  NeighborSearchStat<SortPolicy>,
267  arma::mat>::template DefeatistDualTreeTraverser,
268  tree::SPTree<metric::EuclideanDistance,
269  NeighborSearchStat<SortPolicy>,
270  arma::mat>::template DefeatistSingleTreeTraverser>
271 {
272  public:
275  const double epsilon) :
276  NSWrapper<
277  SortPolicy,
278  tree::SPTree,
279  tree::SPTree<metric::EuclideanDistance,
280  NeighborSearchStat<SortPolicy>,
281  arma::mat>::template DefeatistDualTreeTraverser,
282  tree::SPTree<metric::EuclideanDistance,
283  NeighborSearchStat<SortPolicy>,
284  arma::mat>::template DefeatistSingleTreeTraverser>(
285  searchMode, epsilon)
286  {
287  // Nothing to do.
288  }
289 
291  virtual ~SpillNSWrapper() { }
292 
294  virtual SpillNSWrapper* Clone() const { return new SpillNSWrapper(*this); }
295 
297  virtual void Train(util::Timers& timers,
298  arma::mat&& referenceSet,
299  const size_t leafSize,
300  const double tau,
301  const double rho);
302 
305  virtual void Search(util::Timers& timers,
306  arma::mat&& querySet,
307  const size_t k,
308  arma::Mat<size_t>& neighbors,
309  arma::mat& distances,
310  const size_t leafSize,
311  const double rho);
312 
314  template<typename Archive>
315  void serialize(Archive& ar, const uint32_t /* version */)
316  {
317  ar(CEREAL_NVP(ns));
318  }
319 
320  protected:
321  using NSWrapper<
322  SortPolicy,
323  tree::SPTree,
324  tree::SPTree<metric::EuclideanDistance,
325  NeighborSearchStat<SortPolicy>,
326  arma::mat>::template DefeatistDualTreeTraverser,
327  tree::SPTree<metric::EuclideanDistance,
328  NeighborSearchStat<SortPolicy>,
329  arma::mat>::template DefeatistSingleTreeTraverser>::ns;
330 };
331 
//! NSModel hides the concrete NeighborSearch instantiation behind an
//! NSWrapperBase pointer (nSearch), so the tree type can be selected at
//! runtime through TreeTypes, and it provides serialize() for saving/loading.
//! NOTE(review): this listing is a Doxygen rendering with fused source line
//! numbers and dropped lines (e.g. the `enum TreeTypes` header, most of its
//! enumerators, and source lines 434-436); consult the real ns_model.hpp
//! before editing.
342 template<typename SortPolicy>
343 class NSModel
344 {
345  public:
  //! Enum type to identify each accepted tree type.
  //! NOTE(review): the `enum TreeTypes` declarator (source line 347) and the
  //! enumerators on source lines 349-362 are missing from this listing; only
  //! OCTREE is visible here, and KD_TREE is referenced by the constructor's
  //! default argument below.
348  {
363  OCTREE
364  };
365 
366  private:
  //! Tree type used by the wrapped NeighborSearch object.
368  TreeTypes treeType;
369 
  //! Whether a random basis is used (exposed through RandomBasis()).
371  bool randomBasis;
  //! Presumably the random basis matrix used when randomBasis is true --
  //! TODO confirm in ns_model_impl.hpp.
373  arma::mat q;
374 
  //! Tree parameters matching the leafSize/tau/rho arguments of the wrappers'
  //! Train()/Search(); presumably forwarded there -- confirm in
  //! ns_model_impl.hpp.  Only the spill tree wrapper uses tau and rho.
375  size_t leafSize;
376  double tau;
377  double rho;
378 
  //! Polymorphic pointer to the wrapped NeighborSearch object.
  //! NOTE(review): ownership (deletion in ~NSModel, Clone() on copy) is
  //! implied by the declared copy/move members but is implemented in
  //! ns_model_impl.hpp -- confirm there.
383  NSWrapperBase* nSearch;
384 
385  public:
  //! Initialize the model with the given tree type and random-basis setting;
  //! defaults to TreeTypes::KD_TREE without a random basis.
394  NSModel(TreeTypes treeType = TreeTypes::KD_TREE, bool randomBasis = false);
395 
  //! Copy the given NSModel.
401  NSModel(const NSModel& other);
402 
  //! Take ownership of the given NSModel's contents.
408  NSModel(NSModel&& other);
409 
  //! Copy the given NSModel into this one.
415  NSModel& operator=(const NSModel& other);
416 
  //! Move the given NSModel into this one.
422  NSModel& operator=(NSModel&& other);
423 
  //! Clean up the model.
425  ~NSModel();
426 
  //! Serialize the neighbor search model.
428  template<typename Archive>
429  void serialize(Archive& ar, const uint32_t /* version */);
430 
  //! Expose the dataset (the reference set of the wrapped model).
432  const arma::mat& Dataset() const;
433 
  // NOTE(review): source lines 434-436 are missing from this listing.
437 
  //! Expose LeafSize.
439  size_t LeafSize() const { return leafSize; }
440  size_t& LeafSize() { return leafSize; }
441 
  //! Expose Tau.
443  double Tau() const { return tau; }
444  double& Tau() { return tau; }
445 
  //! Expose Rho.
447  double Rho() const { return rho; }
448  double& Rho() { return rho; }
449 
  //! Expose Epsilon, the approximation parameter.
451  double Epsilon() const;
452  double& Epsilon();
453 
  //! Expose treeType.
455  TreeTypes TreeType() const { return treeType; }
456  TreeTypes& TreeType() { return treeType; }
457 
  //! Expose randomBasis.
459  bool RandomBasis() const { return randomBasis; }
460  bool& RandomBasis() { return randomBasis; }
461 
  //! Initialize the wrapped NeighborSearch model with the given search mode
  //! and approximation parameter epsilon.
463  void InitializeModel(const NeighborSearchMode searchMode,
464  const double epsilon);
465 
  //! Build the model on the given reference set (taken by rvalue reference)
  //! with the given search mode and epsilon (default 0).
467  void BuildModel(util::Timers& timers,
468  arma::mat&& referenceSet,
469  const NeighborSearchMode searchMode,
470  const double epsilon = 0);
471 
  //! Perform bichromatic neighbor search with the given query set, returning
  //! k neighbors per query point in `neighbors` and `distances`.
473  void Search(util::Timers& timers,
474  arma::mat&& querySet,
475  const size_t k,
476  arma::Mat<size_t>& neighbors,
477  arma::mat& distances);
478 
  //! Perform monochromatic neighbor search (no separate query set), returning
  //! k neighbors per point in `neighbors` and `distances`.
480  void Search(util::Timers& timers,
481  const size_t k,
482  arma::Mat<size_t>& neighbors,
483  arma::mat& distances);
484 
  //! Return a string naming the current tree type (presumably derived from
  //! treeType -- confirm in ns_model_impl.hpp).
486  std::string TreeName() const;
487 };
488 
489 } // namespace neighbor
490 } // namespace mlpack
491 
492 // Include implementation.
493 #include "ns_model_impl.hpp"
494 
495 #endif
virtual ~NSWrapper()
Delete the NSWrapper object.
Definition: ns_model.hpp:115
double Epsilon() const
Get epsilon, the approximation parameter.
Definition: ns_model.hpp:130
Linear algebra utility functions, generally performed on matrices or vectors.
bool RandomBasis() const
Expose randomBasis.
Definition: ns_model.hpp:459
TreeTypes
Enum type to identify each accepted tree type.
Definition: ns_model.hpp:347
NSWrapper is a wrapper class for most NeighborSearch types.
Definition: ns_model.hpp:102
void serialize(Archive &ar, const uint32_t)
Serialize the NeighborSearch model.
Definition: ns_model.hpp:315
NSWrapper(const NeighborSearchMode searchMode, const double epsilon)
Construct the NSWrapper object, initializing the internally-held NeighborSearch object.
Definition: ns_model.hpp:107
NeighborSearchMode & SearchMode()
Modify the search mode.
Definition: ns_model.hpp:127
Extra data for each node in the tree.
NSType ns
The instantiated NeighborSearch object that we are wrapping.
Definition: ns_model.hpp:176
LeafSizeNSWrapper wraps any NeighborSearch types that take a leaf size for tree construction.
virtual ~LeafSizeNSWrapper()
Delete the LeafSizeNSWrapper.
Definition: ns_model.hpp:216
virtual NSWrapperBase * Clone() const =0
Create a new NSWrapperBase that is the same as this one.
virtual LeafSizeNSWrapper * Clone() const
Return a copy of the LeafSizeNSWrapper.
Definition: ns_model.hpp:219
The NeighborSearch class is a template class for performing distance-based neighbor searches.
TreeTypes TreeType() const
Expose treeType.
Definition: ns_model.hpp:455
NeighborSearch< SortPolicy, metric::EuclideanDistance, arma::mat, TreeType, DualTreeTraversalType, SingleTreeTraversalType > NSType
Definition: ns_model.hpp:173
NSWrapperBase is a base wrapper class for holding all NeighborSearch types supported by NSModel.
Definition: ns_model.hpp:35
NSWrapperBase()
Create the NSWrapperBase object.
Definition: ns_model.hpp:40
size_t LeafSize() const
Expose LeafSize.
Definition: ns_model.hpp:439
SpillTree< MetricType, StatisticType, MatType, AxisOrthogonalHyperplane, MidpointSpaceSplit > SPTree
The hybrid spill tree.
Definition: typedef.hpp:62
virtual void Train(util::Timers &timers, arma::mat &&referenceSet, const size_t leafSize, const double tau, const double rho)=0
Train the NeighborSearch model with the given parameters.
virtual NSWrapper * Clone() const
Create a copy of this NSWrapper object.
Definition: ns_model.hpp:119
TreeTypes & TreeType()
Modify treeType.
Definition: ns_model.hpp:456
double & Epsilon()
Modify epsilon, the approximation parameter.
Definition: ns_model.hpp:132
The NSModel class provides an easy way to serialize a model and abstracts away the different types of trees.
Definition: ns_model.hpp:343
double Tau() const
Expose Tau.
Definition: ns_model.hpp:443
SpillNSWrapper(const NeighborSearchMode searchMode, const double epsilon)
Construct the SpillNSWrapper.
Definition: ns_model.hpp:274
virtual NeighborSearchMode SearchMode() const =0
Get the search mode.
const arma::mat & Dataset() const
Get a reference to the reference set.
Definition: ns_model.hpp:122
virtual const arma::mat & Dataset() const =0
Return a reference to the dataset.
virtual SpillNSWrapper * Clone() const
Return a copy of the SpillNSWrapper.
Definition: ns_model.hpp:294
virtual double Epsilon() const =0
Get the approximation parameter epsilon.
The SpillNSWrapper class wraps the NeighborSearch class when the spill tree is used.
Definition: ns_model.hpp:261
virtual void Search(util::Timers &timers, arma::mat &&querySet, const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances, const size_t leafSize, const double rho)=0
Perform bichromatic neighbor search (i.e., search for the neighbors of points in a separate query set).
LMetric< 2, true > EuclideanDistance
The Euclidean (L2) distance.
Definition: lmetric.hpp:112
double Rho() const
Expose Rho.
Definition: ns_model.hpp:447
virtual ~SpillNSWrapper()
Destruct the SpillNSWrapper.
Definition: ns_model.hpp:291
NeighborSearchMode
NeighborSearchMode represents the different neighbor search modes available.
void serialize(Archive &ar, const uint32_t)
Serialize the NeighborSearch model.
Definition: ns_model.hpp:244
LeafSizeNSWrapper(const NeighborSearchMode searchMode, const double epsilon)
Construct the LeafSizeNSWrapper by delegating to the NSWrapper constructor.
Definition: ns_model.hpp:205
virtual ~NSWrapperBase()
Destruct the NSWrapperBase (nothing to do).
Definition: ns_model.hpp:47
NeighborSearchMode SearchMode() const
Get the search mode.
Definition: ns_model.hpp:125
void serialize(Archive &ar, const uint32_t)
Serialize the NeighborSearch model.
Definition: ns_model.hpp:161