orthogonal_init.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_INIT_RULES_ORTHOGONAL_INIT_HPP
13 #define MLPACK_METHODS_ANN_INIT_RULES_ORTHOGONAL_INIT_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 
17 namespace mlpack {
18 namespace ann {
19 
25 {
26  public:
32  OrthogonalInitialization(const double gain = 1.0) : gain(gain) { }
33 
42  template<typename eT>
43  void Initialize(arma::Mat<eT>& W, const size_t rows, const size_t cols)
44  {
45  arma::Mat<eT> V;
46  arma::Col<eT> s;
47 
48  arma::svd_econ(W, s, V, arma::randu<arma::Mat<eT> >(rows, cols));
49  W *= gain;
50  }
51 
58  template<typename eT>
59  void Initialize(arma::Mat<eT>& W)
60  {
61  arma::Mat<eT> V;
62  arma::Col<eT> s;
63 
64  arma::svd_econ(W, s, V, arma::randu<arma::Mat<eT> >(W.n_rows, W.n_cols));
65  W *= gain;
66  }
67 
77  template<typename eT>
78  void Initialize(arma::Cube<eT>& W,
79  const size_t rows,
80  const size_t cols,
81  const size_t slices)
82  {
83  if (W.is_empty())
84  W.set_size(rows, cols, slices);
85 
86  for (size_t i = 0; i < slices; ++i)
87  Initialize(W.slice(i), rows, cols);
88  }
89 
96  template<typename eT>
97  void Initialize(arma::Cube<eT>& W)
98  {
99  if (W.is_empty())
100  Log::Fatal << "Cannot initialize an empty cube." << std::endl;
101 
102  for (size_t i = 0; i < W.n_slices; ++i)
103  Initialize(W.slice(i));
104  }
105 
106  private:
//! Scale factor multiplied into every matrix produced by Initialize().
108  double gain;
109 }; // class OrthogonalInitialization
110 
111 
112 } // namespace ann
113 } // namespace mlpack
114 
115 #endif
void Initialize(arma::Cube< eT > &W, const size_t rows, const size_t cols, const size_t slices)
Initialize the elements of the specified 3rd-order weight tensor with the orthogonal matrix initialization method.
void Initialize(arma::Mat< eT > &W, const size_t rows, const size_t cols)
Initialize the elements of the specified weight matrix with the orthogonal matrix initialization method.
void Initialize(arma::Mat< eT > &W)
Initialize the elements of the specified weight matrix with the orthogonal matrix initialization method.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
This class initializes a weight matrix using the orthogonal matrix initialization method.
void Initialize(arma::Cube< eT > &W)
Initialize the elements of the specified 3rd-order weight tensor with the orthogonal matrix initialization method.
OrthogonalInitialization(const double gain=1.0)
Initialize the orthogonal matrix initialization rule with the given gain.