gan.hpp
Go to the documentation of this file.
1 
11 #ifndef MLPACK_METHODS_ANN_GAN_GAN_HPP
12 #define MLPACK_METHODS_ANN_GAN_GAN_HPP
13 
14 #include <mlpack/core.hpp>
15 
23 
24 
25 namespace mlpack {
26 namespace ann {
27 
57 template<
58  typename Model,
59  typename InitializationRuleType,
60  typename Noise,
61  typename PolicyType = StandardGAN
62 >
63 class GAN
64 {
65  public:
83  GAN(Model generator,
84  Model discriminator,
85  InitializationRuleType& initializeRule,
86  Noise& noiseFunction,
87  const size_t noiseDim,
88  const size_t batchSize,
89  const size_t generatorUpdateStep,
90  const size_t preTrainSize,
91  const double multiplier,
92  const double clippingParameter = 0.01,
93  const double lambda = 10.0);
94 
96  GAN(const GAN&);
97 
99  GAN(GAN&&);
100 
107  void ResetData(arma::mat trainData);
108 
109  // Reset function.
110  void Reset();
111 
123  template<typename OptimizerType, typename... CallbackTypes>
124  double Train(arma::mat trainData,
125  OptimizerType& Optimizer,
126  CallbackTypes&&... callbacks);
127 
137  template<typename Policy = PolicyType>
138  typename std::enable_if<std::is_same<Policy, StandardGAN>::value ||
139  std::is_same<Policy, DCGAN>::value, double>::type
140  Evaluate(const arma::mat& parameters,
141  const size_t i,
142  const size_t batchSize);
143 
152  template<typename Policy = PolicyType>
153  typename std::enable_if<std::is_same<Policy, WGAN>::value,
154  double>::type
155  Evaluate(const arma::mat& parameters,
156  const size_t i,
157  const size_t batchSize);
158 
167  template<typename Policy = PolicyType>
168  typename std::enable_if<std::is_same<Policy, WGANGP>::value,
169  double>::type
170  Evaluate(const arma::mat& parameters,
171  const size_t i,
172  const size_t batchSize);
173 
184  template<typename GradType, typename Policy = PolicyType>
185  typename std::enable_if<std::is_same<Policy, StandardGAN>::value ||
186  std::is_same<Policy, DCGAN>::value, double>::type
187  EvaluateWithGradient(const arma::mat& parameters,
188  const size_t i,
189  GradType& gradient,
190  const size_t batchSize);
191 
202  template<typename GradType, typename Policy = PolicyType>
203  typename std::enable_if<std::is_same<Policy, WGAN>::value,
204  double>::type
205  EvaluateWithGradient(const arma::mat& parameters,
206  const size_t i,
207  GradType& gradient,
208  const size_t batchSize);
209 
220  template<typename GradType, typename Policy = PolicyType>
221  typename std::enable_if<std::is_same<Policy, WGANGP>::value,
222  double>::type
223  EvaluateWithGradient(const arma::mat& parameters,
224  const size_t i,
225  GradType& gradient,
226  const size_t batchSize);
227 
238  template<typename Policy = PolicyType>
239  typename std::enable_if<std::is_same<Policy, StandardGAN>::value ||
240  std::is_same<Policy, DCGAN>::value, void>::type
241  Gradient(const arma::mat& parameters,
242  const size_t i,
243  arma::mat& gradient,
244  const size_t batchSize);
245 
256  template<typename Policy = PolicyType>
257  typename std::enable_if<std::is_same<Policy, WGAN>::value, void>::type
258  Gradient(const arma::mat& parameters,
259  const size_t i,
260  arma::mat& gradient,
261  const size_t batchSize);
262 
273  template<typename Policy = PolicyType>
274  typename std::enable_if<std::is_same<Policy, WGANGP>::value,
275  void>::type
276  Gradient(const arma::mat& parameters,
277  const size_t i,
278  arma::mat& gradient,
279  const size_t batchSize);
280 
285  void Shuffle();
286 
292  void Forward(const arma::mat& input);
293 
300  void Predict(arma::mat input, arma::mat& output);
301 
303  const arma::mat& Parameters() const { return parameter; }
305  arma::mat& Parameters() { return parameter; }
306 
308  const Model& Generator() const { return generator; }
310  Model& Generator() { return generator; }
312  const Model& Discriminator() const { return discriminator; }
314  Model& Discriminator() { return discriminator; }
315 
317  size_t NumFunctions() const { return numFunctions; }
318 
320  const arma::mat& Responses() const { return responses; }
322  arma::mat& Responses() { return responses; }
323 
325  const arma::mat& Predictors() const { return predictors; }
327  arma::mat& Predictors() { return predictors; }
328 
330  template<typename Archive>
331  void serialize(Archive& ar, const uint32_t /* version */);
332 
333  private:
338  void ResetDeterministic();
339 
341  arma::mat predictors;
343  arma::mat parameter;
345  Model generator;
347  Model discriminator;
349  InitializationRuleType initializeRule;
351  Noise noiseFunction;
353  size_t noiseDim;
355  size_t numFunctions;
357  size_t batchSize;
359  size_t currentBatch;
361  size_t generatorUpdateStep;
363  size_t preTrainSize;
365  double multiplier;
367  double clippingParameter;
369  double lambda;
371  bool reset;
373  DeltaVisitor deltaVisitor;
375  arma::mat responses;
377  arma::mat currentInput;
379  arma::mat currentTarget;
381  OutputParameterVisitor outputParameterVisitor;
383  WeightSizeVisitor weightSizeVisitor;
385  ResetVisitor resetVisitor;
387  arma::mat gradient;
389  arma::mat gradientDiscriminator;
391  arma::mat noiseGradientDiscriminator;
393  arma::mat normGradientDiscriminator;
395  arma::mat noise;
397  arma::mat gradientGenerator;
399  bool deterministic;
401  size_t genWeights;
403  size_t discWeights;
404 };
405 
406 } // namespace ann
407 } // namespace mlpack
408 
409 // Include implementation.
410 #include "gan_impl.hpp"
411 #include "wgan_impl.hpp"
412 #include "wgangp_impl.hpp"
413 
414 
415 #endif
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type EvaluateWithGradient(const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize)
EvaluateWithGradient function for the Standard GAN and DCGAN.
const Model & Discriminator() const
Return the discriminator of the GAN.
Definition: gan.hpp:312
size_t NumFunctions() const
Return the number of separable functions (the number of predictor points).
Definition: gan.hpp:317
Model & Generator()
Modify the generator of the GAN.
Definition: gan.hpp:310
void Forward(const arma::mat &input)
This function does a forward pass through the GAN network.
Linear algebra utility functions, generally performed on matrices or vectors.
void ResetData(arma::mat trainData)
Initialize the generator, discriminator and weights of the model for training.
WeightSizeVisitor returns the number of weights of the given module.
void Shuffle()
Shuffle the order of function visitation.
GAN(Model generator, Model discriminator, InitializationRuleType &initializeRule, Noise &noiseFunction, const size_t noiseDim, const size_t batchSize, const size_t generatorUpdateStep, const size_t preTrainSize, const double multiplier, const double clippingParameter=0.01, const double lambda=10.0)
Constructor for GAN class.
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, void >::type Gradient(const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize)
Gradient function for the Standard GAN and DCGAN.
arma::mat & Parameters()
Modify the parameters of the network.
Definition: gan.hpp:305
ResetVisitor executes the Reset() function.
OutputParameterVisitor exposes the output parameter of the given module.
const arma::mat & Responses() const
Get the matrix of responses to the input data points.
Definition: gan.hpp:320
double Train(arma::mat trainData, OptimizerType &Optimizer, CallbackTypes &&... callbacks)
Train the GAN on the given data using the given optimizer and callbacks.
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type Evaluate(const arma::mat &parameters, const size_t i, const size_t batchSize)
Evaluate function for the Standard GAN and DCGAN.
arma::mat & Predictors()
Modify the matrix of data points (predictors).
Definition: gan.hpp:327
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
const Model & Generator() const
Return the generator of the GAN.
Definition: gan.hpp:308
const arma::mat & Parameters() const
Return the parameters of the network.
Definition: gan.hpp:303
DeltaVisitor exposes the delta parameter of the given module.
The implementation of the GAN module, with policy-specific overloads for the Standard GAN, DCGAN, WGAN and WGAN-GP.
Definition: gan.hpp:63
void Predict(arma::mat input, arma::mat &output)
This function predicts the output of the network on the given input.
arma::mat & Responses()
Modify the matrix of responses to the input data points.
Definition: gan.hpp:322
Model & Discriminator()
Modify the discriminator of the GAN.
Definition: gan.hpp:314
void serialize(Archive &ar, const uint32_t)
Serialize the model.
const arma::mat & Predictors() const
Get the matrix of data points (predictors).
Definition: gan.hpp:325