11 #ifndef MLPACK_METHODS_ANN_GAN_GAN_HPP    12 #define MLPACK_METHODS_ANN_GAN_GAN_HPP    59   typename InitializationRuleType,
    61   typename PolicyType = StandardGAN
    85       InitializationRuleType& initializeRule,
    87       const size_t noiseDim,
    88       const size_t batchSize,
    89       const size_t generatorUpdateStep,
    90       const size_t preTrainSize,
    91       const double multiplier,
    92       const double clippingParameter = 0.01,
    93       const double lambda = 10.0);
   123   template<
typename OptimizerType, 
typename... CallbackTypes>
   124   double Train(arma::mat trainData,
   125                OptimizerType& Optimizer,
   126                CallbackTypes&&... callbacks);
   137   template<
typename Policy = PolicyType>
   138   typename std::enable_if<std::is_same<Policy, StandardGAN>::value ||
   139                           std::is_same<Policy, DCGAN>::value, 
double>::type
   140   Evaluate(
const arma::mat& parameters,
   142            const size_t batchSize);
   152   template<
typename Policy = PolicyType>
   153   typename std::enable_if<std::is_same<Policy, WGAN>::value,
   155   Evaluate(
const arma::mat& parameters,
   157            const size_t batchSize);
   167   template<
typename Policy = PolicyType>
   168   typename std::enable_if<std::is_same<Policy, WGANGP>::value,
   170   Evaluate(
const arma::mat& parameters,
   172            const size_t batchSize);
   184   template<
typename GradType, 
typename Policy = PolicyType>
   185   typename std::enable_if<std::is_same<Policy, StandardGAN>::value ||
   186                           std::is_same<Policy, DCGAN>::value, 
double>::type
   190                        const size_t batchSize);
   202   template<
typename GradType, 
typename Policy = PolicyType>
   203   typename std::enable_if<std::is_same<Policy, WGAN>::value,
   208                        const size_t batchSize);
   220   template<
typename GradType, 
typename Policy = PolicyType>
   221   typename std::enable_if<std::is_same<Policy, WGANGP>::value,
   226                        const size_t batchSize);
   238   template<
typename Policy = PolicyType>
   239   typename std::enable_if<std::is_same<Policy, StandardGAN>::value ||
   240                           std::is_same<Policy, DCGAN>::value, 
void>::type
   241   Gradient(
const arma::mat& parameters,
   244            const size_t batchSize);
   256   template<
typename Policy = PolicyType>
   257   typename std::enable_if<std::is_same<Policy, WGAN>::value, 
void>::type
   258   Gradient(
const arma::mat& parameters,
   261            const size_t batchSize);
   273   template<
typename Policy = PolicyType>
   274   typename std::enable_if<std::is_same<Policy, WGANGP>::value,
   276   Gradient(
const arma::mat& parameters,
   279            const size_t batchSize);
   292   void Forward(
const arma::mat& input);
   300   void Predict(arma::mat input, arma::mat& output);
   320   const arma::mat& 
Responses()
 const { 
return responses; }
   330   template<
typename Archive>
   331   void serialize(Archive& ar, 
const uint32_t );
   338   void ResetDeterministic();
   341   arma::mat predictors;
   349   InitializationRuleType initializeRule;
   361   size_t generatorUpdateStep;
   367   double clippingParameter;
   377   arma::mat currentInput;
   379   arma::mat currentTarget;
   389   arma::mat gradientDiscriminator;
   391   arma::mat noiseGradientDiscriminator;
   393   arma::mat normGradientDiscriminator;
   397   arma::mat gradientGenerator;
   410 #include "gan_impl.hpp"   411 #include "wgan_impl.hpp"   412 #include "wgangp_impl.hpp" std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type EvaluateWithGradient(const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize)
EvaluateWithGradient function for the Standard GAN and DCGAN. 
 
const Model & Discriminator() const
Return the discriminator of the GAN. 
 
size_t NumFunctions() const
Return the number of separable functions (the number of predictor points). 
 
Model & Generator()
Modify the generator of the GAN. 
 
void Forward(const arma::mat &input)
This function does a forward pass through the GAN network. 
 
Linear algebra utility functions, generally performed on matrices or vectors. 
 
void ResetData(arma::mat trainData)
Initialize the generator, discriminator and weights of the model for training. 
 
WeightSizeVisitor returns the number of weights of the given module. 
 
void Shuffle()
Shuffle the order of function visitation. 
 
GAN(Model generator, Model discriminator, InitializationRuleType &initializeRule, Noise &noiseFunction, const size_t noiseDim, const size_t batchSize, const size_t generatorUpdateStep, const size_t preTrainSize, const double multiplier, const double clippingParameter=0.01, const double lambda=10.0)
Constructor for GAN class. 
 
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, void >::type Gradient(const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize)
Gradient function for the Standard GAN and DCGAN. 
 
arma::mat & Parameters()
Modify the parameters of the network. 
 
ResetVisitor executes the Reset() function. 
 
OutputParameterVisitor exposes the output parameter of the given module. 
 
const arma::mat & Responses() const
Get the matrix of responses to the input data points. 
 
double Train(arma::mat trainData, OptimizerType &Optimizer, CallbackTypes &&... callbacks)
Train the GAN on the given training data using the given optimizer and any supplied callbacks. 
 
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type Evaluate(const arma::mat &parameters, const size_t i, const size_t batchSize)
Evaluate function for the Standard GAN and DCGAN. 
 
arma::mat & Predictors()
Modify the matrix of data points (predictors). 
 
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
 
const Model & Generator() const
Return the generator of the GAN. 
 
const arma::mat & Parameters() const
Return the parameters of the network. 
 
DeltaVisitor exposes the delta parameter of the given module. 
 
The implementation of the standard GAN module. 
 
void Predict(arma::mat input, arma::mat &output)
This function predicts the output of the network on the given input. 
 
arma::mat & Responses()
Modify the matrix of responses to the input data points. 
 
Model & Discriminator()
Modify the discriminator of the GAN. 
 
void serialize(Archive &ar, const uint32_t)
Serialize the model. 
 
const arma::mat & Predictors() const
Get the matrix of data points (predictors).