standard_scaler.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_CORE_DATA_STANDARD_SCALE_HPP
13 #define MLPACK_CORE_DATA_STANDARD_SCALE_HPP
14 
15 #include <mlpack/prereqs.hpp>
#include <stdexcept>
#include <string>
16 
17 namespace mlpack {
18 namespace data {
19 
48 {
49  public:
55  template<typename MatType>
56  void Fit(const MatType& input)
57  {
58  itemMean = arma::mean(input, 1);
59  itemStdDev = arma::stddev(input, 1, 1);
60  // Handle zeros in scale vector.
61  itemStdDev.for_each([](arma::vec::elem_type& val) { val =
62  (val == 0) ? 1 : val; });
63  }
64 
71  template<typename MatType>
72  void Transform(const MatType& input, MatType& output)
73  {
74  if (itemMean.is_empty() || itemStdDev.is_empty())
75  {
76  throw std::runtime_error("Call Fit() before Transform(), please"
77  " refer to the documentation.");
78  }
79  output.copy_size(input);
80  output = (input.each_col() - itemMean).each_col() / itemStdDev;
81  }
82 
89  template<typename MatType>
90  void InverseTransform(const MatType& input, MatType& output)
91  {
92  output.copy_size(input);
93  output = (input.each_col() % itemStdDev).each_col() + itemMean;
94  }
95 
97  const arma::vec& ItemMean() const { return itemMean; }
99  const arma::vec& ItemStdDev() const { return itemStdDev; }
100 
101  template<typename Archive>
102  void serialize(Archive& ar, const uint32_t /* version */)
103  {
104  ar(CEREAL_NVP(itemMean));
105  ar(CEREAL_NVP(itemStdDev));
106  }
107 
108  private:
109  // Vector which holds the mean of each feature; empty until Fit() is called.
110  arma::vec itemMean;
111  // Vector which holds the standard deviation of each feature (Fit() stores 1 in place of 0).
112  arma::vec itemStdDev;
113 }; // class StandardScaler
114 
115 } // namespace data
116 } // namespace mlpack
117 
118 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
const arma::vec & ItemMean() const
Get the vector holding the mean of each feature (one entry per row of the fitted data).
void Fit(const MatType &input)
Function to fit features, to find out the mean and standard deviation of each feature.
const arma::vec & ItemStdDev() const
Get the vector holding the standard deviation of each feature (one entry per row of the fitted data).
A simple Standard Scaler class.
void Transform(const MatType &input, MatType &output)
Function to standardize features: subtract each feature's mean and divide by its standard deviation.
void InverseTransform(const MatType &input, MatType &output)
Function to recover the original data from standardized data (the inverse of Transform()).
void serialize(Archive &ar, const uint32_t)