11 #ifndef MLPACK_METHODS_ANN_GAN_GAN_HPP 12 #define MLPACK_METHODS_ANN_GAN_GAN_HPP 59 typename InitializationRuleType,
61 typename PolicyType = StandardGAN
85 InitializationRuleType& initializeRule,
87 const size_t noiseDim,
88 const size_t batchSize,
89 const size_t generatorUpdateStep,
90 const size_t preTrainSize,
91 const double multiplier,
92 const double clippingParameter = 0.01,
93 const double lambda = 10.0);
/**
 * Train the GAN on the given data with the given optimizer.
 *
 * @param trainData Training data; taken by value, so the caller's matrix is
 *     copied into the call.
 * @param Optimizer Optimizer to use (non-const reference, so it may be
 *     modified by training).
 * @param callbacks Zero or more callbacks (forwarding references);
 *     presumably handed on to the optimizer — TODO confirm in gan_impl.hpp.
 * @return A double; presumably the final objective value — TODO confirm.
 */
123 template<
typename OptimizerType,
typename... CallbackTypes>
124 double Train(arma::mat trainData,
125 OptimizerType& Optimizer,
126 CallbackTypes&&... callbacks);
/**
 * Evaluate function for the Standard GAN and DCGAN.
 *
 * This overload participates in overload resolution (via std::enable_if)
 * only when Policy is StandardGAN or DCGAN; Policy defaults to the class's
 * PolicyType.
 *
 * @param parameters Network parameters to evaluate with (const reference).
 * @param batchSize Batch size to use for the evaluation.
 * @return The evaluated value, as a double.
 */
137 template<
typename Policy = PolicyType>
138 typename std::enable_if<std::is_same<Policy, StandardGAN>::value ||
139 std::is_same<Policy, DCGAN>::value,
double>::type
140 Evaluate(
const arma::mat& parameters,
142 const size_t batchSize);
/**
 * Evaluate overload for the WGAN policy.
 *
 * Enabled (via std::enable_if) only when Policy is WGAN; Policy defaults to
 * the class's PolicyType.
 *
 * @param parameters Network parameters to evaluate with (const reference).
 * @param batchSize Batch size to use for the evaluation.
 */
152 template<
typename Policy = PolicyType>
153 typename std::enable_if<std::is_same<Policy, WGAN>::value,
155 Evaluate(
const arma::mat& parameters,
157 const size_t batchSize);
/**
 * Evaluate overload for the WGANGP policy.
 *
 * Enabled (via std::enable_if) only when Policy is WGANGP; Policy defaults
 * to the class's PolicyType.
 *
 * @param parameters Network parameters to evaluate with (const reference).
 * @param batchSize Batch size to use for the evaluation.
 */
167 template<
typename Policy = PolicyType>
168 typename std::enable_if<std::is_same<Policy, WGANGP>::value,
170 Evaluate(
const arma::mat& parameters,
172 const size_t batchSize);
184 template<
typename GradType,
typename Policy = PolicyType>
185 typename std::enable_if<std::is_same<Policy, StandardGAN>::value ||
186 std::is_same<Policy, DCGAN>::value,
double>::type
190 const size_t batchSize);
202 template<
typename GradType,
typename Policy = PolicyType>
203 typename std::enable_if<std::is_same<Policy, WGAN>::value,
208 const size_t batchSize);
220 template<
typename GradType,
typename Policy = PolicyType>
221 typename std::enable_if<std::is_same<Policy, WGANGP>::value,
226 const size_t batchSize);
/**
 * Gradient function for the Standard GAN and DCGAN.
 *
 * Enabled (via std::enable_if) only when Policy is StandardGAN or DCGAN;
 * Policy defaults to the class's PolicyType. Returns nothing (void).
 *
 * @param parameters Network parameters to compute the gradient at
 *     (const reference).
 * @param batchSize Batch size to use for the gradient computation.
 */
238 template<
typename Policy = PolicyType>
239 typename std::enable_if<std::is_same<Policy, StandardGAN>::value ||
240 std::is_same<Policy, DCGAN>::value,
void>::type
241 Gradient(
const arma::mat& parameters,
244 const size_t batchSize);
/**
 * Gradient overload for the WGAN policy.
 *
 * Enabled (via std::enable_if) only when Policy is WGAN; Policy defaults to
 * the class's PolicyType. Returns nothing (void).
 *
 * @param parameters Network parameters to compute the gradient at
 *     (const reference).
 * @param batchSize Batch size to use for the gradient computation.
 */
256 template<
typename Policy = PolicyType>
257 typename std::enable_if<std::is_same<Policy, WGAN>::value,
void>::type
258 Gradient(
const arma::mat& parameters,
261 const size_t batchSize);
/**
 * Gradient overload for the WGANGP policy.
 *
 * Enabled (via std::enable_if) only when Policy is WGANGP; Policy defaults
 * to the class's PolicyType.
 *
 * @param parameters Network parameters to compute the gradient at
 *     (const reference).
 * @param batchSize Batch size to use for the gradient computation.
 */
273 template<
typename Policy = PolicyType>
274 typename std::enable_if<std::is_same<Policy, WGANGP>::value,
276 Gradient(
const arma::mat& parameters,
279 const size_t batchSize);
/**
 * This function does a forward pass through the GAN network.
 *
 * @param input Input matrix for the forward pass (const reference; not
 *     modified by the call).
 */
292 void Forward(
const arma::mat& input);
/**
 * Predict the output of the network on the given input.
 *
 * @param input Input matrix; taken by value, so the caller's matrix is copied.
 * @param output Matrix that receives the network output (non-const reference).
 */
300 void Predict(arma::mat input, arma::mat& output);
/**
 * Get the matrix of responses to the input data points.
 *
 * @return Const reference to the stored `responses` member (no copy is made).
 */
320 const arma::mat&
Responses()
const {
return responses; }
/**
 * Serialize the model.
 *
 * @param ar Archive object to serialize with.
 * NOTE(review): the unnamed uint32_t parameter is presumably the
 * serialization version — confirm in gan_impl.hpp.
 */
330 template<
typename Archive>
331 void serialize(Archive& ar,
const uint32_t );
// Declared here; the behavior is defined in the implementation file, which is
// not visible in this view.
// NOTE(review): the name suggests a deterministic reset of network state
// (e.g. for reproducible tests) — confirm against gan_impl.hpp.
338 void ResetDeterministic();
//! The matrix of data points (predictors).
341 arma::mat predictors;
//! Weight initialization rule, stored by value (the constructor receives it
//! by reference).
349 InitializationRuleType initializeRule;
//! Stored generatorUpdateStep constructor argument; presumably the number of
//! discriminator steps between generator updates — TODO confirm.
361 size_t generatorUpdateStep;
//! Stored clippingParameter constructor argument (defaults to 0.01 in the
//! constructor declaration); presumably the WGAN weight-clipping bound —
//! TODO confirm.
367 double clippingParameter;
//! Presumably the input batch currently being processed — TODO confirm.
377 arma::mat currentInput;
//! Presumably the target batch currently being processed — TODO confirm.
379 arma::mat currentTarget;
//! Presumably holds the discriminator gradient — TODO confirm.
389 arma::mat gradientDiscriminator;
//! Presumably holds the discriminator gradient computed on generated (noise)
//! samples — TODO confirm.
391 arma::mat noiseGradientDiscriminator;
//! Presumably holds the gradient-norm term used by the WGANGP policy —
//! TODO confirm.
393 arma::mat normGradientDiscriminator;
//! Presumably holds the generator gradient — TODO confirm.
397 arma::mat gradientGenerator;
410 #include "gan_impl.hpp" 411 #include "wgan_impl.hpp" 412 #include "wgangp_impl.hpp" std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type EvaluateWithGradient(const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize)
EvaluateWithGradient function for the Standard GAN and DCGAN.
const Model & Discriminator() const
Return the discriminator of the GAN.
size_t NumFunctions() const
Return the number of separable functions (the number of predictor points).
Model & Generator()
Modify the generator of the GAN.
void Forward(const arma::mat &input)
This function does a forward pass through the GAN network.
Linear algebra utility functions, generally performed on matrices or vectors.
void ResetData(arma::mat trainData)
Initialize the generator, discriminator and weights of the model for training.
WeightSizeVisitor returns the number of weights of the given module.
void Shuffle()
Shuffle the order of function visitation.
GAN(Model generator, Model discriminator, InitializationRuleType &initializeRule, Noise &noiseFunction, const size_t noiseDim, const size_t batchSize, const size_t generatorUpdateStep, const size_t preTrainSize, const double multiplier, const double clippingParameter=0.01, const double lambda=10.0)
Constructor for GAN class.
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, void >::type Gradient(const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize)
Gradient function for the Standard GAN and DCGAN.
arma::mat & Parameters()
Modify the parameters of the network.
ResetVisitor executes the Reset() function.
OutputParameterVisitor exposes the output parameter of the given module.
const arma::mat & Responses() const
Get the matrix of responses to the input data points.
double Train(arma::mat trainData, OptimizerType &Optimizer, CallbackTypes &&... callbacks)
Train the GAN on the given data using the given optimizer and optional callbacks.
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type Evaluate(const arma::mat &parameters, const size_t i, const size_t batchSize)
Evaluate function for the Standard GAN and DCGAN.
arma::mat & Predictors()
Modify the matrix of data points (predictors).
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
const Model & Generator() const
Return the generator of the GAN.
const arma::mat & Parameters() const
Return the parameters of the network.
DeltaVisitor exposes the delta parameter of the given module.
The implementation of the standard GAN module.
void Predict(arma::mat input, arma::mat &output)
This function predicts the output of the network on the given input.
arma::mat & Responses()
Modify the matrix of responses to the input data points.
Model & Discriminator()
Modify the discriminator of the GAN.
void serialize(Archive &ar, const uint32_t)
Serialize the model.
const arma::mat & Predictors() const
Get the matrix of data points (predictors).