The implementation of the standard GAN module.
Public Member Functions
GAN(Model generator, Model discriminator, InitializationRuleType& initializeRule, Noise& noiseFunction, const size_t noiseDim, const size_t batchSize, const size_t generatorUpdateStep, const size_t preTrainSize, const double multiplier, const double clippingParameter = 0.01, const double lambda = 10.0)
    Constructor for GAN class.
GAN(const GAN&)
    Copy constructor.
GAN(GAN&&)
    Move constructor.
const Model& Discriminator() const
    Return the discriminator of the GAN.
Model& Discriminator()
    Modify the discriminator of the GAN.
template<typename Policy = PolicyType>
std::enable_if<std::is_same<Policy, StandardGAN>::value || std::is_same<Policy, DCGAN>::value, double>::type Evaluate(const arma::mat& parameters, const size_t i, const size_t batchSize)
    Evaluate function for the Standard GAN and DCGAN.
template<typename Policy = PolicyType>
std::enable_if<std::is_same<Policy, WGAN>::value, double>::type Evaluate(const arma::mat& parameters, const size_t i, const size_t batchSize)
    Evaluate function for the WGAN.
template<typename Policy = PolicyType>
std::enable_if<std::is_same<Policy, WGANGP>::value, double>::type Evaluate(const arma::mat& parameters, const size_t i, const size_t batchSize)
    Evaluate function for the WGAN-GP.
template<typename GradType, typename Policy = PolicyType>
std::enable_if<std::is_same<Policy, StandardGAN>::value || std::is_same<Policy, DCGAN>::value, double>::type EvaluateWithGradient(const arma::mat& parameters, const size_t i, GradType& gradient, const size_t batchSize)
    EvaluateWithGradient function for the Standard GAN and DCGAN.
template<typename GradType, typename Policy = PolicyType>
std::enable_if<std::is_same<Policy, WGAN>::value, double>::type EvaluateWithGradient(const arma::mat& parameters, const size_t i, GradType& gradient, const size_t batchSize)
    EvaluateWithGradient function for the WGAN.
template<typename GradType, typename Policy = PolicyType>
std::enable_if<std::is_same<Policy, WGANGP>::value, double>::type EvaluateWithGradient(const arma::mat& parameters, const size_t i, GradType& gradient, const size_t batchSize)
    EvaluateWithGradient function for the WGAN-GP.
void Forward(const arma::mat& input)
    This function does a forward pass through the GAN network.
const Model& Generator() const
    Return the generator of the GAN.
Model& Generator()
    Modify the generator of the GAN.
template<typename Policy = PolicyType>
std::enable_if<std::is_same<Policy, StandardGAN>::value || std::is_same<Policy, DCGAN>::value, void>::type Gradient(const arma::mat& parameters, const size_t i, arma::mat& gradient, const size_t batchSize)
    Gradient function for Standard GAN and DCGAN.
template<typename Policy = PolicyType>
std::enable_if<std::is_same<Policy, WGAN>::value, void>::type Gradient(const arma::mat& parameters, const size_t i, arma::mat& gradient, const size_t batchSize)
    Gradient function for WGAN.
template<typename Policy = PolicyType>
std::enable_if<std::is_same<Policy, WGANGP>::value, void>::type Gradient(const arma::mat& parameters, const size_t i, arma::mat& gradient, const size_t batchSize)
    Gradient function for WGAN-GP.
size_t NumFunctions() const
    Return the number of separable functions (the number of predictor points).
const arma::mat& Parameters() const
    Return the parameters of the network.
arma::mat& Parameters()
    Modify the parameters of the network.
void Predict(arma::mat input, arma::mat& output)
    This function predicts the output of the network on the given input.
const arma::mat& Predictors() const
    Get the matrix of data points (predictors).
arma::mat& Predictors()
    Modify the matrix of data points (predictors).
void Reset()
void ResetData(arma::mat trainData)
    Initialize the generator, discriminator and weights of the model for training.
const arma::mat& Responses() const
    Get the matrix of responses to the input data points.
arma::mat& Responses()
    Modify the matrix of responses to the input data points.
template<typename Archive>
void serialize(Archive& ar, const uint32_t)
    Serialize the model.
void Shuffle()
    Shuffle the order of function visitation.
template<typename OptimizerType, typename... CallbackTypes>
double Train(arma::mat trainData, OptimizerType& Optimizer, CallbackTypes&&... callbacks)
    Train function.
The implementation of the standard GAN module.
Generative Adversarial Networks (GANs) are a class of unsupervised machine learning algorithms built from two neural networks that compete in a zero-sum game: a generator, which maps random noise to synthetic samples, and a discriminator, which tries to distinguish those samples from real data. The technique can generate photographs that look at least superficially authentic to human observers and share many characteristics of real photographs. GANs have been applied to text-to-image synthesis, medical drug discovery, high-resolution image generation, neural machine translation, and other tasks.
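The zero-sum game can be written down with the value function of the original minimax GAN formulation, V(D, G) = E_x[log D(x)] + E_z[log(1 − D(G(z)))], which the discriminator tries to maximize and the generator tries to minimize. The C++ sketch below estimates V from one batch of discriminator outputs; the function name GanValue is hypothetical and is not part of this class.

```cpp
#include <cmath>
#include <stdexcept>
#include <vector>

// Monte Carlo estimate of the GAN value function
//   V(D, G) = mean(log D(x)) + mean(log(1 - D(G(z))))
// dReal: discriminator outputs D(x) on real samples, each in (0, 1).
// dFake: discriminator outputs D(G(z)) on generated samples, each in (0, 1).
// The discriminator tries to maximize this value; the generator tries to minimize it.
double GanValue(const std::vector<double>& dReal,
                const std::vector<double>& dFake)
{
  if (dReal.empty() || dFake.empty())
    throw std::invalid_argument("both batches must be non-empty");

  double realTerm = 0.0;
  for (double d : dReal)
    realTerm += std::log(d);

  double fakeTerm = 0.0;
  for (double d : dFake)
    fakeTerm += std::log(1.0 - d);

  return realTerm / dReal.size() + fakeTerm / dFake.size();
}
```

If the discriminator outputs 0.5 for every sample, for example GanValue({0.5, 0.5}, {0.5}), the estimate is 2 log 0.5 = −log 4 ≈ −1.386, the value of the game at its theoretical equilibrium.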
For more information, see the following paper:
GAN(Model generator,
    Model discriminator,
    InitializationRuleType& initializeRule,
    Noise& noiseFunction,
    const size_t noiseDim,
    const size_t batchSize,
    const size_t generatorUpdateStep,
    const size_t preTrainSize,
    const double multiplier,
    const double clippingParameter = 0.01,
    const double lambda = 10.0)
Constructor for GAN class.
generator | Generator network. |
discriminator | Discriminator network. |
initializeRule | Initialization rule to use for initializing parameters. |
noiseFunction | Function to be used for generating noise. |
noiseDim | Dimension of noise vector to be created. |
batchSize | Batch size to be used for training. |
generatorUpdateStep | Number of steps to train Discriminator before updating Generator. |
preTrainSize | Number of pre-training steps for the Discriminator. |
multiplier | Ratio of the Discriminator's learning rate to the Generator's. |
clippingParameter | Weight-clipping range used to enforce the Lipschitz constraint (WGAN). |
lambda | Coefficient of the gradient penalty (WGAN-GP). |
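The interaction of generatorUpdateStep and preTrainSize can be pictured as a training schedule. The sketch below assumes, as an illustration only, that the first preTrainSize steps train the Discriminator alone and that afterwards the Generator is updated on every generatorUpdateStep-th step; the exact schedule used by this class may differ, and UpdateGeneratorAtStep is a hypothetical helper.

```cpp
#include <cstddef>
#include <stdexcept>

// Hypothetical helper (not part of mlpack): decides whether the Generator is
// updated at a given 0-based training step under the ASSUMED schedule:
//   - steps 0 .. preTrainSize - 1 train only the Discriminator;
//   - afterwards, the Generator is updated on every generatorUpdateStep-th
//     step, while the Discriminator keeps training on every step.
bool UpdateGeneratorAtStep(std::size_t step,
                           std::size_t generatorUpdateStep,
                           std::size_t preTrainSize)
{
  if (generatorUpdateStep == 0)
    throw std::invalid_argument("generatorUpdateStep must be positive");

  if (step < preTrainSize)
    return false;  // Still pre-training the Discriminator.

  // Count steps since pre-training ended, starting from 1.
  const std::size_t sincePreTrain = step - preTrainSize + 1;
  return sincePreTrain % generatorUpdateStep == 0;
}
```

With preTrainSize = 2 and generatorUpdateStep = 3, this helper returns true at steps 4, 7, 10, and so on.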
std::enable_if<std::is_same<Policy, StandardGAN>::value || std::is_same<Policy, DCGAN>::value, double>::type Evaluate(const arma::mat& parameters, const size_t i, const size_t batchSize)
Evaluate function for the Standard GAN and DCGAN.
std::enable_if<std::is_same<Policy, WGAN>::value, double>::type Evaluate(const arma::mat& parameters, const size_t i, const size_t batchSize)
Evaluate function for the WGAN.
std::enable_if<std::is_same<Policy, WGANGP>::value, double>::type Evaluate(const arma::mat& parameters, const size_t i, const size_t batchSize)
Evaluate function for the WGAN-GP.
This function returns the objective value of the WGAN-GP on the current batch of inputs.
parameters | The parameters of the network. |
i | Index of the first data point in the current batch. |
batchSize | Number of data points in the current batch. |
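For reference, the WGAN-GP formulation trains the critic (the Discriminator) to minimize mean(D(fake)) − mean(D(real)) + lambda · mean((‖∇D(x̂)‖ − 1)²), where x̂ are random interpolates between real and generated samples and lambda is the constructor's lambda parameter. The sketch below evaluates that expression from precomputed critic scores and gradient norms; it is not this class's Evaluate implementation, and the helper names are hypothetical.

```cpp
#include <numeric>
#include <stdexcept>
#include <vector>

// Mean of a non-empty vector.
double Mean(const std::vector<double>& v)
{
  if (v.empty())
    throw std::invalid_argument("empty batch");
  return std::accumulate(v.begin(), v.end(), 0.0) / v.size();
}

// WGAN-GP critic loss (minimized by the critic):
//   mean(D(fake)) - mean(D(real)) + lambda * mean((||grad D(xHat)|| - 1)^2)
// criticReal / criticFake: critic scores on real and generated samples.
// gradNorms: L2 norms of the critic's input gradient at interpolated points xHat.
double WganGpCriticLoss(const std::vector<double>& criticReal,
                        const std::vector<double>& criticFake,
                        const std::vector<double>& gradNorms,
                        double lambda)
{
  std::vector<double> penalties;
  penalties.reserve(gradNorms.size());
  for (double n : gradNorms)
    penalties.push_back((n - 1.0) * (n - 1.0));

  return Mean(criticFake) - Mean(criticReal) + lambda * Mean(penalties);
}
```

With lambda = 10.0 (the constructor default), a batch whose gradient norms are all exactly 1 contributes no penalty, so the loss reduces to the plain WGAN critic loss.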
std::enable_if<std::is_same<Policy, StandardGAN>::value || std::is_same<Policy, DCGAN>::value, double>::type EvaluateWithGradient(const arma::mat& parameters, const size_t i, GradType& gradient, const size_t batchSize)
EvaluateWithGradient function for the Standard GAN and DCGAN.
This function returns the objective value of the Standard GAN or DCGAN on the current batch of inputs and computes the gradient in the same pass.
parameters | The parameters of the network. |
i | Index of the first data point in the current batch. |
gradient | Variable to store the present gradient. |
batchSize | Number of data points in the current batch. |
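Computing the objective and its gradient in one pass lets intermediate values (here, the logits) be reused. The sketch below shows the idea for a toy logistic Discriminator D(x) = sigmoid(w · x + b) on scalar data, using the standard discriminator loss: binary cross-entropy with real samples labelled 1 and generated samples labelled 0. The toy model and all function names are hypothetical; the real networks are arbitrary mlpack models.

```cpp
#include <algorithm>
#include <array>
#include <cmath>
#include <stdexcept>
#include <vector>

// Numerically stable softplus, log(1 + exp(x)).
double Softplus(double x)
{
  return std::max(x, 0.0) + std::log1p(std::exp(-std::fabs(x)));
}

double Sigmoid(double x) { return 1.0 / (1.0 + std::exp(-x)); }

// Toy logistic Discriminator D(x) = sigmoid(w * x + b) on scalar samples,
// with params = {w, b}. Returns the binary cross-entropy loss
//   -mean(log D(real)) - mean(log(1 - D(fake)))
// and, in the same pass, writes {dLoss/dw, dLoss/db} into grad.
double DiscriminatorLossWithGradient(const std::array<double, 2>& params,
                                     const std::vector<double>& real,
                                     const std::vector<double>& fake,
                                     std::array<double, 2>& grad)
{
  if (real.empty() || fake.empty())
    throw std::invalid_argument("both batches must be non-empty");

  const double w = params[0];
  const double b = params[1];
  double loss = 0.0;
  grad = {0.0, 0.0};

  for (double x : real)
  {
    const double a = w * x + b;                          // logit
    loss += Softplus(-a) / real.size();                  // -log sigmoid(a)
    const double dA = (Sigmoid(a) - 1.0) / real.size();  // d(term)/da
    grad[0] += dA * x;
    grad[1] += dA;
  }
  for (double x : fake)
  {
    const double a = w * x + b;
    loss += Softplus(a) / fake.size();                   // -log(1 - sigmoid(a))
    const double dA = Sigmoid(a) / fake.size();
    grad[0] += dA * x;
    grad[1] += dA;
  }
  return loss;
}
```

A central finite difference on each parameter should agree with the returned gradient to several decimal places, which is a convenient check for any hand-written EvaluateWithGradient.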
std::enable_if<std::is_same<Policy, WGAN>::value, double>::type EvaluateWithGradient(const arma::mat& parameters, const size_t i, GradType& gradient, const size_t batchSize)
EvaluateWithGradient function for the WGAN.
This function returns the objective value of the WGAN on the current batch of inputs and computes the gradient in the same pass.
parameters | The parameters of the network. |
i | Index of the first data point in the current batch. |
gradient | Variable to store the present gradient. |
batchSize | Number of data points in the current batch. |
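In the WGAN formulation, the critic minimizes mean(f(fake)) − mean(f(real)), and its weights are clipped after each update to keep it Lipschitz; the constructor's clippingParameter (default 0.01) sets the clipping range. The sketch below shows both pieces for a toy linear critic f(x) = w · x + b on scalar data. The model and helper names are hypothetical, not this class's internals.

```cpp
#include <algorithm>
#include <array>
#include <stdexcept>
#include <vector>

// Toy linear critic f(x) = w * x + b on scalar samples, with params = {w, b}.
// Returns the WGAN critic loss mean(f(fake)) - mean(f(real)), which the critic
// minimizes, and writes {dLoss/dw, dLoss/db} into grad.
double WganCriticLossWithGradient(const std::array<double, 2>& params,
                                  const std::vector<double>& real,
                                  const std::vector<double>& fake,
                                  std::array<double, 2>& grad)
{
  if (real.empty() || fake.empty())
    throw std::invalid_argument("both batches must be non-empty");

  double meanReal = 0.0;
  double meanFake = 0.0;
  for (double x : real) meanReal += x / real.size();
  for (double x : fake) meanFake += x / fake.size();

  const double w = params[0];
  const double b = params[1];
  // The bias b cancels between the two means, so its gradient is zero.
  grad = {meanFake - meanReal, 0.0};
  return (w * meanFake + b) - (w * meanReal + b);
}

// Weight clipping: after each critic update, clamp every critic weight to
// [-clippingParameter, clippingParameter] so the critic stays K-Lipschitz
// for some constant K.
void ClipWeights(std::array<double, 2>& params, double clippingParameter)
{
  for (double& p : params)
    p = std::min(std::max(p, -clippingParameter), clippingParameter);
}
```

In a training loop, ClipWeights would run right after each optimizer step on the critic, never on the Generator.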
std::enable_if<std::is_same<Policy, WGANGP>::value, double>::type EvaluateWithGradient(const arma::mat& parameters, const size_t i, GradType& gradient, const size_t batchSize)
EvaluateWithGradient function for the WGAN-GP.
This function returns the objective value of the WGAN-GP on the current batch of inputs and computes the gradient in the same pass.
parameters | The parameters of the network. |
i | Index of the first data point in the current batch. |
gradient | Variable to store the present gradient. |
batchSize | Number of data points in the current batch. |
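WGAN-GP replaces weight clipping with a penalty on the critic's input gradient, evaluated at random interpolates x̂ = ε · real + (1 − ε) · fake with ε drawn uniformly from [0, 1]. For a toy linear critic f(x) = w · x, the input gradient is w at every x̂, so the penalty lambda · (‖w‖ − 1)² and its gradient with respect to w have closed forms, as sketched below. The helper names and the linear critic are hypothetical; a real critic needs backpropagation to obtain ∇D(x̂).

```cpp
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Random interpolation used by WGAN-GP: xHat = eps * real + (1 - eps) * fake.
// eps would normally be drawn uniformly from [0, 1] per sample; it is passed in here.
std::vector<double> Interpolate(const std::vector<double>& real,
                                const std::vector<double>& fake,
                                double eps)
{
  if (real.size() != fake.size())
    throw std::invalid_argument("dimension mismatch");
  std::vector<double> xHat(real.size());
  for (std::size_t k = 0; k < real.size(); ++k)
    xHat[k] = eps * real[k] + (1.0 - eps) * fake[k];
  return xHat;
}

// Gradient penalty lambda * (||grad_x f(xHat)|| - 1)^2 for a toy linear critic
// f(x) = w . x, whose input gradient is w at every xHat. Returns the penalty
// and writes its gradient with respect to w into grad. The gradient is
// undefined at w = 0; this sketch returns zeros there.
double LinearCriticPenalty(const std::vector<double>& w, double lambda,
                           std::vector<double>& grad)
{
  double norm = 0.0;
  for (double v : w) norm += v * v;
  norm = std::sqrt(norm);

  grad.assign(w.size(), 0.0);
  if (norm > 0.0)
    for (std::size_t k = 0; k < w.size(); ++k)
      grad[k] = 2.0 * lambda * (norm - 1.0) * w[k] / norm;
  return lambda * (norm - 1.0) * (norm - 1.0);
}
```

With w = {3, 4} and lambda = 10, the norm is 5, so the penalty is 10 · 4² = 160 and its gradient is {48, 64}.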
void Forward(const arma::mat& input)
This function does a forward pass through the GAN network.
input | Sampled noise. |
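A forward pass through the combined network chains the two models: each noise sample z goes through the Generator, and the resulting sample G(z) is scored by the Discriminator. The sketch below illustrates this, assuming (as the Predict documentation suggests) that the final output is the Discriminator's score, with hypothetical one-dimensional stand-ins for the networks rather than mlpack models.

```cpp
#include <cmath>
#include <vector>

// Hypothetical one-dimensional stand-ins for the two networks.
struct LinearGenerator
{
  double a, c;
  double operator()(double z) const { return a * z + c; }  // G(z) = a z + c
};

struct LogisticDiscriminator
{
  double w, b;
  double operator()(double x) const  // D(x) = sigmoid(w x + b)
  {
    return 1.0 / (1.0 + std::exp(-(w * x + b)));
  }
};

// Forward pass through the combined GAN: each noise sample z is mapped to a
// generated sample G(z), which is then scored by the Discriminator, D(G(z)).
std::vector<double> GanForward(const LinearGenerator& g,
                               const LogisticDiscriminator& d,
                               const std::vector<double>& noise)
{
  std::vector<double> scores;
  scores.reserve(noise.size());
  for (double z : noise)
    scores.push_back(d(g(z)));
  return scores;
}
```

For example, GanForward({2.0, 1.0}, {1.0, -3.0}, {1.0, 0.0}) maps z = 1 to G(z) = 3, which the Discriminator scores as sigmoid(0) = 0.5.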
std::enable_if<std::is_same<Policy, StandardGAN>::value || std::is_same<Policy, DCGAN>::value, void>::type Gradient(const arma::mat& parameters, const size_t i, arma::mat& gradient, const size_t batchSize)
Gradient function for Standard GAN and DCGAN.
This function computes the gradient according to which network is being trained, i.e., the Generator or the Discriminator.
parameters | Present parameters of the network. |
i | Index of the first data point (predictor) in the current batch. |
gradient | Variable to store the present gradient. |
batchSize | Number of data points in the current batch. |
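One way to realize a gradient that depends on which network is being trained is to keep both networks' parameters in one flat vector and zero the gradient block of the network that should stay fixed. The sketch below assumes that layout (Generator parameters first, then Discriminator parameters); both the layout and the MaskGradient helper are assumptions made for illustration, not a description of this class's internals.

```cpp
#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <vector>

enum class TrainedNetwork { Generator, Discriminator };

// Hypothetical helper. ASSUMES all GAN parameters live in one flat vector,
// with the first generatorSize entries belonging to the Generator and the
// rest to the Discriminator. Zeroes the gradient entries of the network that
// is NOT being trained, so a plain (momentum-free) gradient step leaves that
// network unchanged.
void MaskGradient(std::vector<double>& gradient, std::size_t generatorSize,
                  TrainedNetwork trained)
{
  if (generatorSize > gradient.size())
    throw std::invalid_argument("generatorSize exceeds gradient size");

  const auto split =
      gradient.begin() + static_cast<std::ptrdiff_t>(generatorSize);
  if (trained == TrainedNetwork::Generator)
    std::fill(split, gradient.end(), 0.0);    // Freeze the Discriminator.
  else
    std::fill(gradient.begin(), split, 0.0);  // Freeze the Generator.
}
```

With optimizers that keep momentum or other per-parameter state, zeroing the gradient alone may not freeze a network, which is one reason real implementations may handle the two networks differently.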
std::enable_if<std::is_same<Policy, WGAN>::value, void>::type Gradient(const arma::mat& parameters, const size_t i, arma::mat& gradient, const size_t batchSize)
Gradient function for WGAN.
This function computes the gradient according to which network is being trained, i.e., the Generator or the Discriminator.
parameters | Present parameters of the network. |
i | Index of the first data point (predictor) in the current batch. |
gradient | Variable to store the present gradient. |
batchSize | Number of data points in the current batch. |
std::enable_if<std::is_same<Policy, WGANGP>::value, void>::type Gradient(const arma::mat& parameters, const size_t i, arma::mat& gradient, const size_t batchSize)
Gradient function for WGAN-GP.
This function computes the gradient according to which network is being trained, i.e., the Generator or the Discriminator.
parameters | Present parameters of the network. |
i | Index of the first data point (predictor) in the current batch. |
gradient | Variable to store the present gradient. |
batchSize | Number of data points in the current batch. |
void Predict(arma::mat input, arma::mat& output)
This function predicts the output of the network on the given input.
input | The input of the Generator network. |
output | Result of the Discriminator network. |
arma::mat& Predictors()
Modify the matrix of data points (predictors).
Definition at line 327 of file gan.hpp.
References GAN< Model, InitializationRuleType, Noise, PolicyType >::serialize().
void Reset()
void ResetData(arma::mat trainData)
Initialize the generator, discriminator and weights of the model for training.
This function does not start the training process.
trainData | Data points drawn from the real distribution. |
void serialize(Archive& ar, const uint32_t)
Serialize the model.
Referenced by GAN< Model, InitializationRuleType, Noise, PolicyType >::Predictors().
void Shuffle()
Shuffle the order of function visitation.
This may be called by the optimizer.
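Shuffling the order of function visitation usually means permuting the indices of the data points so that batches are drawn in a different order each epoch. A minimal sketch of that idea with the C++ standard library follows; ShuffledOrder is a hypothetical helper, and this class may instead reorder the predictor columns directly.

```cpp
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <random>
#include <vector>

// Hypothetical helper: returns a random permutation of the data-point indices
// 0 .. numFunctions - 1. An optimizer can then visit points in the order
// order[i], order[i + 1], ... instead of 0, 1, 2, ...
std::vector<std::size_t> ShuffledOrder(std::size_t numFunctions,
                                       std::mt19937& rng)
{
  std::vector<std::size_t> order(numFunctions);
  std::iota(order.begin(), order.end(), std::size_t{0});  // 0, 1, ..., n - 1
  std::shuffle(order.begin(), order.end(), rng);
  return order;
}
```

Passing the random engine in by reference keeps runs reproducible when the caller seeds it, and lets consecutive calls produce different permutations.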
double Train(arma::mat trainData, OptimizerType& Optimizer, CallbackTypes&&... callbacks)
Train the GAN on trainData using the given optimizer.
OptimizerType | Type of optimizer to use to train the model. |
CallbackTypes | Types of the callback functions. |
trainData | Data points drawn from the real distribution. |
Optimizer | Instantiated optimizer used to train the model. |
callbacks | Callback functions for the ensmallen optimizer OptimizerType. See https://www.ensmallen.org/docs.html#callback-documentation. |
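Train hands the GAN to an ensmallen optimizer, which treats it as a separable objective: it queries NumFunctions(), may call Shuffle(), and then evaluates consecutive batches, passing a start index i and a batchSize. The sketch below is a simplified, hypothetical mini-batch SGD loop over a toy objective with the same method shapes, using plain doubles instead of arma::mat. It is not ensmallen's SGD, whose step rule and batching details may differ.

```cpp
#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical toy separable objective: sum over points k of (w - y[k])^2.
// Its member functions mirror the shape of the GAN's separable-function
// interface (NumFunctions, Shuffle, EvaluateWithGradient with i and batchSize).
struct MeanFit
{
  std::vector<double> y;

  std::size_t NumFunctions() const { return y.size(); }
  void Shuffle() {}  // Visiting order does not matter for this toy objective.

  // Objective and gradient summed over points i .. i + batchSize - 1.
  double EvaluateWithGradient(double w, std::size_t i, double& grad,
                              std::size_t batchSize) const
  {
    double loss = 0.0;
    grad = 0.0;
    for (std::size_t k = i; k < i + batchSize; ++k)
    {
      loss += (w - y[k]) * (w - y[k]);
      grad += 2.0 * (w - y[k]);
    }
    return loss;
  }
};

// Simplified mini-batch SGD: every epoch calls Shuffle(), then walks the
// points in batches starting at index i and takes one step along the
// batch-averaged gradient. Returns the objective summed over the last epoch,
// measured before each step.
template<typename FunctionType>
double MiniBatchSgd(FunctionType& f, double& w, double stepSize,
                    std::size_t batchSize, std::size_t epochs)
{
  if (batchSize == 0)
    throw std::invalid_argument("batchSize must be positive");

  double epochLoss = 0.0;
  for (std::size_t e = 0; e < epochs; ++e)
  {
    f.Shuffle();
    epochLoss = 0.0;
    for (std::size_t i = 0; i < f.NumFunctions(); i += batchSize)
    {
      // The last batch may be smaller than batchSize.
      const std::size_t effective = std::min(batchSize, f.NumFunctions() - i);
      double grad = 0.0;
      epochLoss += f.EvaluateWithGradient(w, i, grad, effective);
      w -= stepSize * grad / effective;
    }
  }
  return epochLoss;
}
```

For example, with y = {1, 2, 3, 4}, a full batch of 4, and stepSize = 0.1, w converges to the mean 2.5; with batches of 2 it instead settles into a small cycle around 2.5, a typical effect of a constant step size.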