naive_bayes_classifier.hpp
Go to the documentation of this file.
1 
15 #ifndef MLPACK_METHODS_NAIVE_BAYES_NAIVE_BAYES_CLASSIFIER_HPP
16 #define MLPACK_METHODS_NAIVE_BAYES_NAIVE_BAYES_CLASSIFIER_HPP
17 
18 #include <mlpack/prereqs.hpp>
19 
20 namespace mlpack {
21 namespace naive_bayes {
22 
57 template<typename ModelMatType = arma::mat>
59 {
60  public:
61  // Convenience typedef.
62  typedef typename ModelMatType::elem_type ElemType;
63 
83  template<typename MatType>
84  NaiveBayesClassifier(const MatType& data,
85  const arma::Row<size_t>& labels,
86  const size_t numClasses,
87  const bool incrementalVariance = false,
88  const double epsilon = 1e-10);
89 
96  NaiveBayesClassifier(const size_t dimensionality = 0,
97  const size_t numClasses = 0,
98  const double epsilon = 1e-10);
99 
117  template<typename MatType>
118  void Train(const MatType& data,
119  const arma::Row<size_t>& labels,
120  const size_t numClasses,
121  const bool incremental = true);
122 
131  template<typename VecType>
132  void Train(const VecType& point, const size_t label);
133 
140  template<typename VecType>
141  size_t Classify(const VecType& point) const;
142 
153  template<typename VecType, typename ProbabilitiesVecType>
154  void Classify(const VecType& point,
155  size_t& prediction,
156  ProbabilitiesVecType& probabilities) const;
157 
172  template<typename MatType>
173  void Classify(const MatType& data,
174  arma::Row<size_t>& predictions) const;
175 
197  template<typename MatType, typename ProbabilitiesMatType>
198  void Classify(const MatType& data,
199  arma::Row<size_t>& predictions,
200  ProbabilitiesMatType& probabilities) const;
201 
203  const ModelMatType& Means() const { return means; }
205  ModelMatType& Means() { return means; }
206 
208  const ModelMatType& Variances() const { return variances; }
210  ModelMatType& Variances() { return variances; }
211 
213  const ModelMatType& Probabilities() const { return probabilities; }
215  ModelMatType& Probabilities() { return probabilities; }
216 
218  template<typename Archive>
219  void serialize(Archive& ar, const uint32_t /* version */);
220 
221  private:
223  ModelMatType means;
225  ModelMatType variances;
227  ModelMatType probabilities;
229  size_t trainingPoints;
231  double epsilon;
232 
241  template<typename MatType>
242  void LogLikelihood(const MatType& data,
243  ModelMatType& logLikelihoods) const;
244 };
245 
246 } // namespace naive_bayes
247 } // namespace mlpack
248 
249 // Include implementation.
250 #include "naive_bayes_classifier_impl.hpp"
251 
252 #endif
ModelMatType & Probabilities()
Modify the prior probabilities for each class.
ModelMatType & Variances()
Modify the sample variances for each class.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
const ModelMatType & Variances() const
Get the sample variances for each class.
The simple Naive Bayes classifier.
size_t Classify(const VecType &point) const
Classify the given point, using the trained NaiveBayesClassifier model.
const ModelMatType & Probabilities() const
Get the prior probabilities for each class.
NaiveBayesClassifier(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const bool incrementalVariance=false, const double epsilon=1e-10)
Initializes the classifier as per the input and then trains it by calculating the sample means and variances.
void serialize(Archive &ar, const uint32_t)
Serialize the classifier.
const ModelMatType & Means() const
Get the sample means for each class.
void Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const bool incremental=true)
Train the Naive Bayes classifier on the given dataset.
ModelMatType & Means()
Modify the sample means for each class.