glorot_init.hpp
Go to the documentation of this file.
1 
14 #ifndef MLPACK_METHODS_ANN_INIT_RULES_GLOROT_INIT_HPP
15 #define MLPACK_METHODS_ANN_INIT_RULES_GLOROT_INIT_HPP
16 
17 #include <mlpack/prereqs.hpp>
18 #include "random_init.hpp"
19 #include "gaussian_init.hpp"
20 
21 using namespace mlpack::math;
22 
23 namespace mlpack {
24 namespace ann {
25 
54 template<bool Uniform = true>
56 {
57  public:
62  {
63  // Nothing to do here.
64  }
65 
73  template<typename eT>
74  void Initialize(arma::Mat<eT>& W,
75  const size_t rows,
76  const size_t cols);
77 
83  template<typename eT>
84  void Initialize(arma::Mat<eT>& W);
85 
95  template<typename eT>
96  void Initialize(arma::Cube<eT>& W,
97  const size_t rows,
98  const size_t cols,
99  const size_t slices);
100 
107  template<typename eT>
108  void Initialize(arma::Cube<eT>& W);
109 }; // class GlorotInitializationType
110 
111 template <>
112 template<typename eT>
113 inline void GlorotInitializationType<false>::Initialize(arma::Mat<eT>& W,
114  const size_t rows,
115  const size_t cols)
116 {
117  if (W.is_empty())
118  W.set_size(rows, cols);
119 
120  double var = 2.0 / double(rows + cols);
121  GaussianInitialization normalInit(0.0, var);
122  normalInit.Initialize(W, rows, cols);
123 }
124 
125 template <>
126 template<typename eT>
127 inline void GlorotInitializationType<false>::Initialize(arma::Mat<eT>& W)
128 {
129  if (W.is_empty())
130  Log::Fatal << "Cannot initialize and empty matrix." << std::endl;
131 
132  double var = 2.0 / double(W.n_rows + W.n_cols);
133  GaussianInitialization normalInit(0.0, var);
134  normalInit.Initialize(W);
135 }
136 
137 template <>
138 template<typename eT>
139 inline void GlorotInitializationType<true>::Initialize(arma::Mat<eT>& W,
140  const size_t rows,
141  const size_t cols)
142 {
143  if (W.is_empty())
144  W.set_size(rows, cols);
145 
146  // Limit of distribution.
147  double a = sqrt(6) / sqrt(rows + cols);
148  RandomInitialization randomInit(-a, a);
149  randomInit.Initialize(W, rows, cols);
150 }
151 
152 template <>
153 template<typename eT>
154 inline void GlorotInitializationType<true>::Initialize(arma::Mat<eT>& W)
155 {
156  if (W.is_empty())
157  Log::Fatal << "Cannot initialize an empty matrix." << std::endl;
158 
159  // Limit of distribution.
160  double a = sqrt(6) / sqrt(W.n_rows + W.n_cols);
161  RandomInitialization randomInit(-a, a);
162  randomInit.Initialize(W);
163 }
164 
165 template <bool Uniform>
166 template<typename eT>
167 inline void GlorotInitializationType<Uniform>::Initialize(arma::Cube<eT>& W,
168  const size_t rows,
169  const size_t cols,
170  const size_t slices)
171 {
172  if (W.is_empty())
173  W.set_size(rows, cols, slices);
174 
175  for (size_t i = 0; i < slices; ++i)
176  Initialize(W.slice(i), rows, cols);
177 }
178 
179 template <bool Uniform>
180 template<typename eT>
181 inline void GlorotInitializationType<Uniform>::Initialize(arma::Cube<eT>& W)
182 {
183  if (W.is_empty())
184  Log::Fatal << "Cannot initialize an empty matrix." << std::endl;
185 
186  for (size_t i = 0; i < W.n_slices; ++i)
187  Initialize(W.slice(i));
188 }
189 
190 // Convenience typedefs.
191 
196 
201 // Uses normal distribution
202 } // namespace ann
203 } // namespace mlpack
204 
205 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
void Initialize(arma::Mat< eT > &W, const size_t rows, const size_t cols)
Initialize the elements of the weight matrix using a Gaussian distribution.
This class is used to randomly initialize the weight matrix.
Definition: random_init.hpp:24
The core includes that mlpack expects; standard C++ includes and Armadillo.
void Initialize(arma::Mat< eT > &W, const size_t rows, const size_t cols)
Randomly initialize the elements of the specified weight matrix.
Definition: random_init.hpp:56
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
Miscellaneous math routines.
Definition: ccov.hpp:20
GlorotInitializationType()
Initialize the Glorot initialization object.
Definition: glorot_init.hpp:61
This class is used to initialize the weight matrix with the Glorot Initialization method...
Definition: glorot_init.hpp:55
This class is used to initialize the weight matrix with a Gaussian distribution.