lmnn.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_LMNN_LMNN_HPP
13 #define MLPACK_METHODS_LMNN_LMNN_HPP
14 
15 #include <mlpack/prereqs.hpp>
17 #include <ensmallen.hpp>
18 
19 #include "lmnn_function.hpp"
20 
21 namespace mlpack {
22 namespace lmnn {
23 
53 template<typename MetricType = metric::SquaredEuclideanDistance,
54  typename OptimizerType = ens::AMSGrad>
55 class LMNN
56 {
57  public:
68  LMNN(const arma::mat& dataset,
69  const arma::Row<size_t>& labels,
70  const size_t k,
71  const MetricType metric = MetricType());
72 
73 
85  template<typename... CallbackTypes>
86  void LearnDistance(arma::mat& outputMatrix, CallbackTypes&&... callbacks);
87 
88 
90  const arma::mat& Dataset() const { return dataset; }
91 
93  const arma::Row<size_t>& Labels() const { return labels; }
94 
96  const double& Regularization() const { return regularization; }
98  double& Regularization() { return regularization; }
99 
101  const size_t& Range() const { return range; }
103  size_t& Range() { return range; }
104 
106  const size_t& K() const { return k; }
108  size_t K() { return k; }
109 
111  const OptimizerType& Optimizer() const { return optimizer; }
112  OptimizerType& Optimizer() { return optimizer; }
113 
114  private:
116  const arma::mat& dataset;
117 
119  const arma::Row<size_t>& labels;
120 
122  size_t k;
123 
125  double regularization;
126 
128  size_t range;
129 
131  MetricType metric;
132 
134  OptimizerType optimizer;
135 }; // class LMNN
136 
137 } // namespace lmnn
138 } // namespace mlpack
139 
140 // Include the implementation.
141 #include "lmnn_impl.hpp"
142 
143 #endif
const size_t & Range() const
Access the range value.
Definition: lmnn.hpp:101
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
size_t K()
Modify the value of k.
Definition: lmnn.hpp:108
void LearnDistance(arma::mat &outputMatrix, CallbackTypes &&... callbacks)
Perform Large Margin Nearest Neighbors metric learning.
size_t & Range()
Modify the range value.
Definition: lmnn.hpp:103
An implementation of the Large Margin Nearest Neighbors (LMNN) metric learning technique.
Definition: lmnn.hpp:55
LMetric< 2, false > SquaredEuclideanDistance
The squared Euclidean (L2) distance.
Definition: lmetric.hpp:107
OptimizerType & Optimizer()
Modify the optimizer.
Definition: lmnn.hpp:112
const OptimizerType & Optimizer() const
Get the optimizer.
Definition: lmnn.hpp:111
const double & Regularization() const
Access the regularization value.
Definition: lmnn.hpp:96
const arma::Row< size_t > & Labels() const
Get the labels reference.
Definition: lmnn.hpp:93
const size_t & K() const
Access the value of k.
Definition: lmnn.hpp:106
const arma::mat & Dataset() const
Get the dataset reference.
Definition: lmnn.hpp:90
LMNN(const arma::mat &dataset, const arma::Row< size_t > &labels, const size_t k, const MetricType metric=MetricType())
Initialize the LMNN object, passing a dataset (the distance metric is learned using this dataset) and labels.
double & Regularization()
Modify the regularization value.
Definition: lmnn.hpp:98