The implementation of the standard GAN module.
Public Member Functions
GAN (Model generator, Model discriminator, InitializationRuleType &initializeRule, Noise &noiseFunction, const size_t noiseDim, const size_t batchSize, const size_t generatorUpdateStep, const size_t preTrainSize, const double multiplier, const double clippingParameter=0.01, const double lambda=10.0)
  Constructor for the GAN class.
GAN (const GAN &)
  Copy constructor.
GAN (GAN &&)
  Move constructor.
const Model & Discriminator () const
  Return the discriminator of the GAN.
Model & Discriminator ()
  Modify the discriminator of the GAN.
template<typename Policy = PolicyType>
std::enable_if<std::is_same<Policy, StandardGAN>::value || std::is_same<Policy, DCGAN>::value, double>::type Evaluate (const arma::mat &parameters, const size_t i, const size_t batchSize)
  Evaluate function for the Standard GAN and DCGAN.
template<typename Policy = PolicyType>
std::enable_if<std::is_same<Policy, WGAN>::value, double>::type Evaluate (const arma::mat &parameters, const size_t i, const size_t batchSize)
  Evaluate function for the WGAN.
template<typename Policy = PolicyType>
std::enable_if<std::is_same<Policy, WGANGP>::value, double>::type Evaluate (const arma::mat &parameters, const size_t i, const size_t batchSize)
  Evaluate function for the WGAN-GP.
template<typename GradType, typename Policy = PolicyType>
std::enable_if<std::is_same<Policy, StandardGAN>::value || std::is_same<Policy, DCGAN>::value, double>::type EvaluateWithGradient (const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize)
  EvaluateWithGradient function for the Standard GAN and DCGAN.
template<typename GradType, typename Policy = PolicyType>
std::enable_if<std::is_same<Policy, WGAN>::value, double>::type EvaluateWithGradient (const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize)
  EvaluateWithGradient function for the WGAN.
template<typename GradType, typename Policy = PolicyType>
std::enable_if<std::is_same<Policy, WGANGP>::value, double>::type EvaluateWithGradient (const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize)
  EvaluateWithGradient function for the WGAN-GP.
void Forward (const arma::mat &input)
  This function does a forward pass through the GAN network.
const Model & Generator () const
  Return the generator of the GAN.
Model & Generator ()
  Modify the generator of the GAN.
template<typename Policy = PolicyType>
std::enable_if<std::is_same<Policy, StandardGAN>::value || std::is_same<Policy, DCGAN>::value, void>::type Gradient (const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize)
  Gradient function for the Standard GAN and DCGAN.
template<typename Policy = PolicyType>
std::enable_if<std::is_same<Policy, WGAN>::value, void>::type Gradient (const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize)
  Gradient function for the WGAN.
template<typename Policy = PolicyType>
std::enable_if<std::is_same<Policy, WGANGP>::value, void>::type Gradient (const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize)
  Gradient function for the WGAN-GP.
size_t NumFunctions () const
  Return the number of separable functions (the number of predictor points).
const arma::mat & Parameters () const
  Return the parameters of the network.
arma::mat & Parameters ()
  Modify the parameters of the network.
void Predict (arma::mat input, arma::mat &output)
  This function predicts the output of the network on the given input.
const arma::mat & Predictors () const
  Get the matrix of data points (predictors).
arma::mat & Predictors ()
  Modify the matrix of data points (predictors).
void Reset ()
void ResetData (arma::mat trainData)
  Initialize the generator, discriminator and weights of the model for training.
const arma::mat & Responses () const
  Get the matrix of responses to the input data points.
arma::mat & Responses ()
  Modify the matrix of responses to the input data points.
template<typename Archive>
void serialize (Archive &ar, const uint32_t)
  Serialize the model.
void Shuffle ()
  Shuffle the order of function visitation.
template<typename OptimizerType, typename... CallbackTypes>
double Train (arma::mat trainData, OptimizerType &Optimizer, CallbackTypes &&... callbacks)
  Train function.
The implementation of the standard GAN module.
Generative Adversarial Networks (GANs) are a class of machine learning models, typically used for unsupervised learning, in which two neural networks, a generator and a discriminator, compete in a zero-sum game: the generator turns random noise into samples, and the discriminator tries to tell those samples apart from real data. GANs can generate photographs that look at least superficially authentic to human observers. They have been applied to text-to-image synthesis, drug discovery, high-resolution image generation, neural machine translation, and other tasks.
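The competition between the two networks can be made concrete with the textbook binary cross-entropy losses. The sketch below is a standalone illustration, not mlpack's code: it assumes the discriminator outputs probabilities D(x) strictly inside (0, 1), the helper names `DiscriminatorLoss` and `GeneratorLoss` are hypothetical, and the generator loss shown is the commonly used non-saturating variant rather than the literal zero-sum objective.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Discriminator loss for one batch (binary cross-entropy):
//   -mean[ log D(x) + log(1 - D(G(z))) ]
// realProbs[k] = D(x_k) for a real sample, fakeProbs[k] = D(G(z_k)) for a
// generated one; both must lie strictly inside (0, 1).
double DiscriminatorLoss(const std::vector<double>& realProbs,
                         const std::vector<double>& fakeProbs)
{
  assert(!realProbs.empty() && realProbs.size() == fakeProbs.size());
  double sum = 0.0;
  for (std::size_t k = 0; k < realProbs.size(); ++k)
    sum += std::log(realProbs[k]) + std::log(1.0 - fakeProbs[k]);
  return -sum / static_cast<double>(realProbs.size());
}

// Non-saturating generator loss for one batch: -mean[ log D(G(z)) ].
// It decreases as the discriminator assigns a higher "real" probability to
// generated samples, so the two losses pull in opposite directions.
double GeneratorLoss(const std::vector<double>& fakeProbs)
{
  assert(!fakeProbs.empty());
  double sum = 0.0;
  for (double p : fakeProbs)
    sum += std::log(p);
  return -sum / static_cast<double>(fakeProbs.size());
}
```

When the discriminator outputs 0.5 for everything (it cannot tell real from fake), the discriminator loss is 2 log 2 and the generator loss is log 2.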
For more information, see the following paper:
GAN(Model generator,
    Model discriminator,
    InitializationRuleType &initializeRule,
    Noise &noiseFunction,
    const size_t noiseDim,
    const size_t batchSize,
    const size_t generatorUpdateStep,
    const size_t preTrainSize,
    const double multiplier,
    const double clippingParameter = 0.01,
    const double lambda = 10.0)
Constructor for GAN class.
| generator | Generator network. |
| discriminator | Discriminator network. |
| initializeRule | Initialization rule to use for initializing parameters. |
| noiseFunction | Function to be used for generating noise. |
| noiseDim | Dimension of noise vector to be created. |
| batchSize | Batch size to be used for training. |
| generatorUpdateStep | Number of steps for which the discriminator is trained before the generator is updated. |
| preTrainSize | Number of pre-training steps for the discriminator. |
| multiplier | Ratio of the discriminator's learning rate to the generator's. |
| clippingParameter | Weight-clipping range used to enforce the Lipschitz constraint. |
| lambda | Coefficient of the gradient penalty term. |
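One plausible reading of `preTrainSize` and `generatorUpdateStep` together is a schedule in which the discriminator is trained on every step, the first `preTrainSize` steps train only the discriminator, and afterwards the generator is also updated once every `generatorUpdateStep` steps. The sketch below encodes that assumption; the helper names are hypothetical, and the exact step accounting inside mlpack's GAN may differ.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper: returns true if the given (0-based) step would also
// update the generator under the assumed schedule described above.
bool UpdatesGenerator(std::size_t step,
                      std::size_t preTrainSize,
                      std::size_t generatorUpdateStep)
{
  assert(generatorUpdateStep > 0);
  if (step < preTrainSize)
    return false;  // Still pre-training the discriminator only.
  return (step - preTrainSize) % generatorUpdateStep == 0;
}

// Counts how many of the first totalSteps steps would update the generator.
std::size_t CountGeneratorUpdates(std::size_t totalSteps,
                                  std::size_t preTrainSize,
                                  std::size_t generatorUpdateStep)
{
  std::size_t count = 0;
  for (std::size_t step = 0; step < totalSteps; ++step)
    if (UpdatesGenerator(step, preTrainSize, generatorUpdateStep))
      ++count;
  return count;
}
```

For example, with `preTrainSize = 10` and `generatorUpdateStep = 5`, the generator would be updated at steps 10, 15, 20 and 25 of the first 30 steps.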
std::enable_if<std::is_same<Policy, StandardGAN>::value || std::is_same<Policy, DCGAN>::value, double>::type Evaluate(const arma::mat &parameters, const size_t i, const size_t batchSize)
Evaluate function for the Standard GAN and DCGAN.
std::enable_if<std::is_same<Policy, WGAN>::value, double>::type Evaluate(const arma::mat &parameters, const size_t i, const size_t batchSize)
Evaluate function for the WGAN.
std::enable_if<std::is_same<Policy, WGANGP>::value, double>::type Evaluate(const arma::mat &parameters, const size_t i, const size_t batchSize)
Evaluate function for the WGAN-GP.
This function returns the loss of the WGAN-GP on the current batch of inputs.
| parameters | The parameters of the network. |
| i | Index of the first input in the current batch. |
| batchSize | Number of inputs in the current batch. |
std::enable_if<std::is_same<Policy, StandardGAN>::value || std::is_same<Policy, DCGAN>::value, double>::type EvaluateWithGradient(const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize)
EvaluateWithGradient function for the Standard GAN and DCGAN.
This function returns the loss of the Standard GAN or DCGAN on the current batch of inputs and stores the corresponding gradient in gradient.
| parameters | The parameters of the network. |
| i | Index of the first input in the current batch. |
| gradient | Output parameter that receives the computed gradient. |
| batchSize | Number of inputs in the current batch. |
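Computing the loss and its gradient in one pass is cheaper than calling an evaluate and a gradient routine separately, because intermediate values such as the sigmoid output are shared. The sketch below shows that pattern for the discriminator's binary cross-entropy, differentiated with respect to the discriminator's pre-sigmoid outputs (logits). It is a standalone illustration over `std::vector`, not mlpack's implementation, and `EvaluateWithLogitGradient` is a hypothetical name.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

double Sigmoid(double s) { return 1.0 / (1.0 + std::exp(-s)); }

// Returns the mean binary cross-entropy of the discriminator over one batch
// and writes d(loss)/d(logit) for each sample into `gradient`.
// logits[k] is the discriminator's pre-sigmoid output for sample k;
// isReal[k] is true for real data and false for generated data.
double EvaluateWithLogitGradient(const std::vector<double>& logits,
                                 const std::vector<bool>& isReal,
                                 std::vector<double>& gradient)
{
  assert(!logits.empty() && logits.size() == isReal.size());
  const double n = static_cast<double>(logits.size());
  gradient.assign(logits.size(), 0.0);
  double loss = 0.0;
  for (std::size_t k = 0; k < logits.size(); ++k)
  {
    const double p = Sigmoid(logits[k]);  // Shared by loss and gradient.
    // Real sample: loss -log(p),     derivative p - 1.
    // Fake sample: loss -log(1 - p), derivative p.
    loss += isReal[k] ? -std::log(p) : -std::log(1.0 - p);
    gradient[k] = (isReal[k] ? p - 1.0 : p) / n;
  }
  return loss / n;
}
```

With both logits at 0 (one real, one fake sample) the loss is log 2 and the per-logit gradient is -0.25 for the real sample and 0.25 for the fake one.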
std::enable_if<std::is_same<Policy, WGAN>::value, double>::type EvaluateWithGradient(const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize)
EvaluateWithGradient function for the WGAN.
This function returns the loss of the WGAN on the current batch of inputs and stores the corresponding gradient in gradient.
| parameters | The parameters of the network. |
| i | Index of the first input in the current batch. |
| gradient | Output parameter that receives the computed gradient. |
| batchSize | Number of inputs in the current batch. |
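In a WGAN the discriminator acts as a critic that outputs unbounded scores, and the original WGAN paper keeps it approximately Lipschitz by clipping every weight into [-c, c] after each update, with c = 0.01 (the same value as the `clippingParameter` default). The sketch below illustrates the critic loss and the clipping step on plain vectors; it is not mlpack's implementation, and `CriticLoss` and `ClipWeights` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Mean of a non-empty batch of critic scores.
double Mean(const std::vector<double>& v)
{
  assert(!v.empty());
  return std::accumulate(v.begin(), v.end(), 0.0) / static_cast<double>(v.size());
}

// WGAN critic loss for one batch: mean(D(G(z))) - mean(D(x)).
// Minimizing it pushes scores of real samples up and of generated ones down.
double CriticLoss(const std::vector<double>& realScores,
                  const std::vector<double>& fakeScores)
{
  return Mean(fakeScores) - Mean(realScores);
}

// After each critic update, clamp every weight into
// [-clippingParameter, clippingParameter].
void ClipWeights(std::vector<double>& weights, double clippingParameter)
{
  for (double& w : weights)
    w = std::max(-clippingParameter, std::min(w, clippingParameter));
}
```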
std::enable_if<std::is_same<Policy, WGANGP>::value, double>::type EvaluateWithGradient(const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize)
EvaluateWithGradient function for the WGAN-GP.
This function returns the loss of the WGAN-GP on the current batch of inputs and stores the corresponding gradient in gradient.
| parameters | The parameters of the network. |
| i | Index of the first input in the current batch. |
| gradient | Output parameter that receives the computed gradient. |
| batchSize | Number of inputs in the current batch. |
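WGAN-GP replaces weight clipping with a gradient penalty: for a point x̂ interpolated between a real sample and a generated one, it adds lambda * (||∇D(x̂)||₂ - 1)² to the critic loss, usually averaged over the batch with a fresh interpolation weight per sample. The sketch below shows the interpolation and the penalty for a single sample; it assumes the critic's gradient with respect to its input has already been computed by backpropagation, and `Interpolate` and `GradientPenalty` are hypothetical names rather than mlpack functions.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// xHat = eps * real + (1 - eps) * fake, with eps typically drawn from U(0, 1).
std::vector<double> Interpolate(const std::vector<double>& real,
                                const std::vector<double>& fake,
                                double eps)
{
  assert(real.size() == fake.size());
  std::vector<double> xHat(real.size());
  for (std::size_t k = 0; k < real.size(); ++k)
    xHat[k] = eps * real[k] + (1.0 - eps) * fake[k];
  return xHat;
}

// Penalty for one interpolated sample, where criticGrad holds the gradient
// of the critic's output with respect to xHat:
//   lambda * (||criticGrad||_2 - 1)^2
double GradientPenalty(const std::vector<double>& criticGrad, double lambda)
{
  double squaredNorm = 0.0;
  for (double g : criticGrad)
    squaredNorm += g * g;
  const double deviation = std::sqrt(squaredNorm) - 1.0;
  return lambda * deviation * deviation;
}
```

With the default lambda = 10.0, a gradient of norm 5 is penalized by 10 * 4² = 160, while a gradient of unit norm is not penalized at all.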
void Forward(const arma::mat &input)
This function does a forward pass through the GAN network.
| input | Sampled noise. |
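The noise passed to the generator is a batch of random vectors of length `noiseDim`. The sketch below samples such a batch, assuming a standard normal noise distribution and a column-major `noiseDim` x `batchSize` layout with one noise vector per column (the layout Armadillo matrices use). `NoiseBatch` and `SampleNoise` are hypothetical; in mlpack the distribution is whatever the `noiseFunction` passed to the constructor produces.

```cpp
#include <cassert>
#include <cstddef>
#include <random>
#include <vector>

// Column-major noise batch: noiseDim rows, batchSize columns;
// entry (r, c) is stored at index c * noiseDim + r.
struct NoiseBatch
{
  std::size_t noiseDim;
  std::size_t batchSize;
  std::vector<double> values;

  double At(std::size_t r, std::size_t c) const
  {
    assert(r < noiseDim && c < batchSize);
    return values[c * noiseDim + r];
  }
};

// Draws every entry independently from N(0, 1).
NoiseBatch SampleNoise(std::size_t noiseDim, std::size_t batchSize,
                       std::mt19937& rng)
{
  std::normal_distribution<double> normal(0.0, 1.0);
  NoiseBatch batch{noiseDim, batchSize,
                   std::vector<double>(noiseDim * batchSize)};
  for (double& v : batch.values)
    v = normal(rng);
  return batch;
}
```

Seeding the generator (`std::mt19937 rng(42);`) makes the sampled noise reproducible across runs.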
std::enable_if<std::is_same<Policy, StandardGAN>::value || std::is_same<Policy, DCGAN>::value, void>::type Gradient(const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize)
Gradient function for the Standard GAN and DCGAN.
This function computes the gradient for whichever network is currently being trained, i.e., the generator or the discriminator.
| parameters | Current parameters of the network. |
| i | Index of the first predictor in the current batch. |
| gradient | Output matrix that receives the computed gradient. |
| batchSize | Number of inputs in the current batch. |
std::enable_if<std::is_same<Policy, WGAN>::value, void>::type Gradient(const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize)
Gradient function for the WGAN.
This function computes the gradient for whichever network is currently being trained, i.e., the generator or the discriminator.
| parameters | Current parameters of the network. |
| i | Index of the first predictor in the current batch. |
| gradient | Output matrix that receives the computed gradient. |
| batchSize | Number of inputs in the current batch. |
std::enable_if<std::is_same<Policy, WGANGP>::value, void>::type Gradient(const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize)
Gradient function for the WGAN-GP.
This function computes the gradient for whichever network is currently being trained, i.e., the generator or the discriminator.
| parameters | Current parameters of the network. |
| i | Index of the first predictor in the current batch. |
| gradient | Output matrix that receives the computed gradient. |
| batchSize | Number of inputs in the current batch. |
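Because a single optimizer updates all parameters at once, one way to train only one network per step is to zero the gradient entries that belong to the other network. The sketch below assumes a hypothetical layout in which one flat parameter vector stores the generator's weights first and the discriminator's weights after them; mlpack's actual parameter layout and masking logic are not described here, and `MaskGradient` is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Which network the current step trains.
enum class Trained { Generator, Discriminator };

// Assumed layout: entries [0, generatorSize) belong to the generator and
// entries [generatorSize, size) to the discriminator. Returns a copy of the
// gradient in which only the trained network's entries are nonzero, so an
// optimizer step leaves the other network unchanged.
std::vector<double> MaskGradient(const std::vector<double>& fullGradient,
                                 std::size_t generatorSize,
                                 Trained trained)
{
  assert(generatorSize <= fullGradient.size());
  std::vector<double> masked(fullGradient.size(), 0.0);
  const bool gen = (trained == Trained::Generator);
  const std::size_t begin = gen ? 0 : generatorSize;
  const std::size_t end = gen ? generatorSize : fullGradient.size();
  for (std::size_t k = begin; k < end; ++k)
    masked[k] = fullGradient[k];
  return masked;
}
```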
void Predict(arma::mat input, arma::mat &output)
This function predicts the output of the network on the given input.
| input | The input of the Generator network. |
| output | Result of the Discriminator network. |
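Prediction chains the two networks: the input is fed to the generator, and the generator's output is scored by the discriminator. The sketch below mimics that data flow with toy stand-ins, a linear generator and a logistic discriminator with fixed, made-up weights. These functions are hypothetical and only illustrate the pipeline, not mlpack's layers.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Toy generator: maps each noise entry z to scale * z + shift.
std::vector<double> ToyGenerator(const std::vector<double>& noise,
                                 double scale, double shift)
{
  std::vector<double> sample(noise.size());
  for (std::size_t k = 0; k < noise.size(); ++k)
    sample[k] = scale * noise[k] + shift;
  return sample;
}

// Toy discriminator: sigmoid(w . x + b), the probability that x is real.
double ToyDiscriminator(const std::vector<double>& sample,
                        const std::vector<double>& weights, double bias)
{
  assert(sample.size() == weights.size());
  double s = bias;
  for (std::size_t k = 0; k < sample.size(); ++k)
    s += weights[k] * sample[k];
  return 1.0 / (1.0 + std::exp(-s));
}

// Pipeline: noise -> generator -> discriminator -> probability.
double ToyPredict(const std::vector<double>& noise)
{
  const std::vector<double> sample = ToyGenerator(noise, 2.0, 1.0);
  return ToyDiscriminator(sample, std::vector<double>(sample.size(), 0.5), -1.0);
}
```

For a zero noise vector of length 2, the toy generator produces (1, 1), the discriminator's logit is 0, and the predicted probability is 0.5.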
Modify the matrix of data points (predictors).
Definition at line 327 of file gan.hpp.
References GAN< Model, InitializationRuleType, Noise, PolicyType >::serialize().
void Reset()
void ResetData(arma::mat trainData)
Initialize the generator, discriminator and weights of the model for training.
This function does not start the training process.
| trainData | The data points of real distribution. |
void serialize(Archive &ar, const uint32_t)
Serialize the model.
Referenced by GAN< Model, InitializationRuleType, Noise, PolicyType >::Predictors().
void Shuffle()
Shuffle the order of function visitation.
This may be called by the optimizer.
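Ensmallen-style optimizers treat each data point (column of the predictors) as one separable function: they shuffle the visitation order, then walk through it in batches, passing the first index `i` and the `batchSize` of each batch to `Evaluate`, `Gradient` or `EvaluateWithGradient`, with a shorter final batch when `NumFunctions()` is not a multiple of the batch size. The sketch below reproduces that bookkeeping with hypothetical helpers; how mlpack's `Shuffle()` reorders the data internally is not shown here.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <random>
#include <utility>
#include <vector>

// A random visitation order over numFunctions data points.
std::vector<std::size_t> ShuffledOrder(std::size_t numFunctions,
                                       std::mt19937& rng)
{
  std::vector<std::size_t> order(numFunctions);
  std::iota(order.begin(), order.end(), 0);  // 0, 1, ..., numFunctions - 1
  std::shuffle(order.begin(), order.end(), rng);
  return order;
}

// Splits one pass over the data into (i, batchSize) pairs. The final batch
// is shorter when numFunctions is not a multiple of batchSize.
std::vector<std::pair<std::size_t, std::size_t>> Batches(
    std::size_t numFunctions, std::size_t batchSize)
{
  assert(batchSize > 0);
  std::vector<std::pair<std::size_t, std::size_t>> batches;
  for (std::size_t i = 0; i < numFunctions; i += batchSize)
    batches.emplace_back(i, std::min(batchSize, numFunctions - i));
  return batches;
}
```

With 10 data points and a batch size of 4, one pass makes three calls with (i, batchSize) equal to (0, 4), (4, 4) and (8, 2).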
double Train(arma::mat trainData, OptimizerType &Optimizer, CallbackTypes &&... callbacks)
Train function.
| OptimizerType | Type of optimizer to use to train the model. |
| CallbackTypes | Types of Callback functions. |
| trainData | The data points of real distribution. |
| Optimizer | Instantiated optimizer used to train the model. |
| callbacks | Callback functions for the ensmallen optimizer OptimizerType. See https://www.ensmallen.org/docs.html#callback-documentation. |