lmnn_function.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_LMNN_FUNCTION_HPP
14 #define MLPACK_METHODS_LMNN_FUNCTION_HPP
15 
16 #include <mlpack/prereqs.hpp>
18 
19 #include "constraints.hpp"
20 
21 namespace mlpack {
22 namespace lmnn {
23 
45 template<typename MetricType = metric::SquaredEuclideanDistance>
47 {
48  public:
59  LMNNFunction(const arma::mat& dataset,
60  const arma::Row<size_t>& labels,
61  size_t k,
62  double regularization,
63  size_t range,
64  MetricType metric = MetricType());
65 
66 
70  void Shuffle();
71 
79  double Evaluate(const arma::mat& transformation);
80 
93  double Evaluate(const arma::mat& transformation,
94  const size_t begin,
95  const size_t batchSize = 1);
96 
106  template<typename GradType>
107  void Gradient(const arma::mat& transformation, GradType& gradient);
108 
124  template<typename GradType>
125  void Gradient(const arma::mat& transformation,
126  const size_t begin,
127  GradType& gradient,
128  const size_t batchSize = 1);
129 
140  template<typename GradType>
141  double EvaluateWithGradient(const arma::mat& transformation,
142  GradType& gradient);
143 
159  template<typename GradType>
160  double EvaluateWithGradient(const arma::mat& transformation,
161  const size_t begin,
162  GradType& gradient,
163  const size_t batchSize = 1);
164 
166  const arma::mat& GetInitialPoint() const { return initialPoint; }
167 
172  size_t NumFunctions() const { return dataset.n_cols; }
173 
175  const arma::mat& Dataset() const { return dataset; }
176 
178  const double& Regularization() const { return regularization; }
180  double& Regularization() { return regularization; }
181 
183  const size_t& K() const { return k; }
185  size_t& K() { return k; }
186 
188  const size_t& Range() const { return range; }
190  size_t& Range() { return range; }
191 
192  private:
194  arma::mat dataset;
196  arma::Row<size_t> labels;
198  arma::mat initialPoint;
200  arma::mat transformedDataset;
202  arma::Mat<size_t> targetNeighbors;
204  arma::Mat<size_t> impostors;
206  arma::mat distance;
208  size_t k;
210  MetricType metric;
212  double regularization;
214  size_t iteration;
216  size_t range;
218  Constraints<MetricType> constraint;
220  arma::mat pCij;
222  arma::vec norm;
224  arma::cube evalOld;
226  arma::mat maxImpNorm;
228  arma::mat transformationOld;
230  std::vector<arma::mat> oldTransformationMatrices;
232  std::vector<size_t> oldTransformationCounts;
234  arma::vec lastTransformationIndices;
236  arma::uvec points;
238  bool impBounds;
244  inline void Precalculate();
246  inline void UpdateCache(const arma::mat& transformation,
247  const size_t begin,
248  const size_t batchSize);
250  inline void TransDiff(std::map<size_t, double>& transformationDiffs,
251  const arma::mat& transformation,
252  const size_t begin,
253  const size_t batchSize);
254 };
255 
256 } // namespace lmnn
257 } // namespace mlpack
258 
259 #include "lmnn_function_impl.hpp"
260 
261 #endif
size_t & Range()
Modify the value of range.
const size_t & K() const
Access the value of k.
Linear algebra utility functions, generally performed on matrices or vectors.
const double & Regularization() const
Access the regularization value.
The Large Margin Nearest Neighbors function.
void Shuffle()
Shuffle the points in the dataset.
The core includes that mlpack expects; standard C++ includes and Armadillo.
double EvaluateWithGradient(const arma::mat &transformation, GradType &gradient)
Evaluate the LMNN objective function together with its gradient for the given transformation matrix...
const arma::mat & Dataset() const
Return the dataset passed into the constructor.
double & Regularization()
Modify the regularization value.
Interface for generating distance-based constraints on a given dataset, provided corresponding true labels...
Definition: constraints.hpp:31
size_t NumFunctions() const
Get the number of functions the objective function can be decomposed into.
void Gradient(const arma::mat &transformation, GradType &gradient)
Evaluate the gradient of the LMNN function for the given transformation matrix.
LMNNFunction(const arma::mat &dataset, const arma::Row< size_t > &labels, size_t k, double regularization, size_t range, MetricType metric=MetricType())
Constructor for LMNNFunction class.
size_t & K()
Modify the value of k.
const arma::mat & GetInitialPoint() const
Return the initial point for the optimization.
double Evaluate(const arma::mat &transformation)
Evaluate the LMNN function for the given transformation matrix.
const size_t & Range() const
Access the value of range.