15 #ifndef MLPACK_METHODS_NAIVE_BAYES_NAIVE_BAYES_CLASSIFIER_HPP 16 #define MLPACK_METHODS_NAIVE_BAYES_NAIVE_BAYES_CLASSIFIER_HPP 21 namespace naive_bayes {
57 template<
typename ModelMatType = arma::mat>
//! Element type of the model matrix type (ModelMatType::elem_type).
62 typedef typename ModelMatType::elem_type
ElemType;
83 template<
typename MatType>
85 const arma::Row<size_t>& labels,
86 const size_t numClasses,
87 const bool incrementalVariance =
false,
88 const double epsilon = 1e-10);
97 const size_t numClasses = 0,
98 const double epsilon = 1e-10);
/**
 * Train the Naive Bayes classifier on the given dataset.
 *
 * @tparam MatType Type of the data matrix.
 * @param data The dataset to train on.
 * @param labels One size_t label per training point, given as an arma::Row.
 * @param numClasses Number of classes (presumably the number of distinct
 *     label values -- confirm against the implementation).
 * @param incremental Defaults to true.  NOTE(review): presumably selects an
 *     incremental training algorithm -- confirm against
 *     naive_bayes_classifier_impl.hpp.
 */
117 template<
typename MatType>
118 void Train(
const MatType& data,
119 const arma::Row<size_t>& labels,
120 const size_t numClasses,
121 const bool incremental =
true);
/**
 * Overload of Train() that takes a single point and its label instead of a
 * whole dataset.
 *
 * NOTE(review): presumably updates the existing model with this one point --
 * confirm against naive_bayes_classifier_impl.hpp.
 *
 * @tparam VecType Type of the point.
 * @param point The point to train on.
 * @param label The label associated with that point.
 */
131 template<
typename VecType>
132 void Train(
const VecType& point,
const size_t label);
/**
 * Classify the given point, using the trained NaiveBayesClassifier model.
 *
 * This is a const member function; it does not modify the classifier.
 *
 * @tparam VecType Type of the point.
 * @param point The point to classify.
 * @return A size_t for the point (presumably the predicted class label --
 *     confirm against the implementation).
 */
140 template<
typename VecType>
141 size_t Classify(
const VecType& point)
const;
153 template<
typename VecType,
typename ProbabilitiesVecType>
156 ProbabilitiesVecType& probabilities)
const;
172 template<
typename MatType>
174 arma::Row<size_t>& predictions)
const;
197 template<
typename MatType,
typename ProbabilitiesMatType>
199 arma::Row<size_t>& predictions,
200 ProbabilitiesMatType& probabilities)
const;
//! Get the sample means for each class (read-only reference to `means`).
203 const ModelMatType&
Means()
const {
return means; }
//! Modify the sample means for each class (mutable reference to `means`).
205 ModelMatType&
Means() {
return means; }
//! Get the sample variances for each class (read-only reference to
//! `variances`).
208 const ModelMatType&
Variances()
const {
return variances; }
/**
 * Serialize the classifier.
 *
 * @tparam Archive Type of the archive.
 * @param ar Archive to serialize to or from.
 *
 * The second parameter (const uint32_t) is unnamed in this declaration.
 * NOTE(review): presumably the serialization version -- confirm.
 */
218 template<
typename Archive>
219 void serialize(Archive& ar,
const uint32_t );
//! Sample variances for each class (returned by Variances()).
225 ModelMatType variances;
//! NOTE(review): presumably the prior probabilities for each class, as
//! exposed by Probabilities() -- that accessor's body is not visible in this
//! listing; confirm.
227 ModelMatType probabilities;
//! NOTE(review): presumably the number of training points seen so far --
//! confirm against the implementation.
229 size_t trainingPoints;
/**
 * Const helper taking input data and a non-const ModelMatType reference.
 *
 * NOTE(review): judging by the names, this presumably computes
 * log-likelihoods for `data` and stores them in `logLikelihoods`; the exact
 * layout of the result is not visible in this listing -- confirm against
 * naive_bayes_classifier_impl.hpp.
 *
 * @tparam MatType Type of the data matrix.
 * @param data Input data.
 * @param logLikelihoods Matrix passed by non-const reference (presumably the
 *     output of this function).
 */
241 template<
typename MatType>
242 void LogLikelihood(
const MatType& data,
243 ModelMatType& logLikelihoods)
const;
250 #include "naive_bayes_classifier_impl.hpp"
ModelMatType & Probabilities()
Modify the prior probabilities for each class.
ModelMatType & Variances()
Modify the sample variances for each class.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
const ModelMatType & Variances() const
Get the sample variances for each class.
The simple Naive Bayes classifier.
size_t Classify(const VecType &point) const
Classify the given point, using the trained NaiveBayesClassifier model.
const ModelMatType & Probabilities() const
Get the prior probabilities for each class.
NaiveBayesClassifier(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const bool incrementalVariance=false, const double epsilon=1e-10)
Initializes the classifier as per the input and then trains it by calculating the sample mean and variances.
void serialize(Archive &ar, const uint32_t)
Serialize the classifier.
ModelMatType::elem_type ElemType
const ModelMatType & Means() const
Get the sample means for each class.
void Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const bool incremental=true)
Train the Naive Bayes classifier on the given dataset.
ModelMatType & Means()
Modify the sample means for each class.