GAN< Model, InitializationRuleType, Noise, PolicyType > Class Template Reference

The implementation of the standard GAN module.

Public Member Functions

GAN (Model generator, Model discriminator, InitializationRuleType &initializeRule, Noise &noiseFunction, const size_t noiseDim, const size_t batchSize, const size_t generatorUpdateStep, const size_t preTrainSize, const double multiplier, const double clippingParameter=0.01, const double lambda=10.0)
Constructor for GAN class.

GAN (const GAN &)
Copy constructor.

GAN (GAN &&)
Move constructor.

const Model & Discriminator () const
Return the discriminator of the GAN.

Model & Discriminator ()
Modify the discriminator of the GAN.

template<typename Policy = PolicyType>
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type Evaluate (const arma::mat &parameters, const size_t i, const size_t batchSize)
Evaluate function for the Standard GAN and DCGAN.

template<typename Policy = PolicyType>
std::enable_if< std::is_same< Policy, WGAN >::value, double >::type Evaluate (const arma::mat &parameters, const size_t i, const size_t batchSize)
Evaluate function for the WGAN.

template<typename Policy = PolicyType>
std::enable_if< std::is_same< Policy, WGANGP >::value, double >::type Evaluate (const arma::mat &parameters, const size_t i, const size_t batchSize)
Evaluate function for the WGAN-GP.

template<typename GradType, typename Policy = PolicyType>
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type EvaluateWithGradient (const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize)
EvaluateWithGradient function for the Standard GAN and DCGAN.

template<typename GradType, typename Policy = PolicyType>
std::enable_if< std::is_same< Policy, WGAN >::value, double >::type EvaluateWithGradient (const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize)
EvaluateWithGradient function for the WGAN.

template<typename GradType, typename Policy = PolicyType>
std::enable_if< std::is_same< Policy, WGANGP >::value, double >::type EvaluateWithGradient (const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize)
EvaluateWithGradient function for the WGAN-GP.

void Forward (const arma::mat &input)
This function does a forward pass through the GAN network.

const Model & Generator () const
Return the generator of the GAN.

Model & Generator ()
Modify the generator of the GAN.

template<typename Policy = PolicyType>
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, void >::type Gradient (const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize)
Gradient function for the Standard GAN and DCGAN.

template<typename Policy = PolicyType>
std::enable_if< std::is_same< Policy, WGAN >::value, void >::type Gradient (const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize)
Gradient function for the WGAN.

template<typename Policy = PolicyType>
std::enable_if< std::is_same< Policy, WGANGP >::value, void >::type Gradient (const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize)
Gradient function for the WGAN-GP.

size_t NumFunctions () const
Return the number of separable functions (the number of predictor points).

const arma::mat & Parameters () const
Return the parameters of the network.

arma::mat & Parameters ()
Modify the parameters of the network.

void Predict (arma::mat input, arma::mat &output)
This function predicts the output of the network on the given input.

const arma::mat & Predictors () const
Get the matrix of data points (predictors).

arma::mat & Predictors ()
Modify the matrix of data points (predictors).

void Reset ()

void ResetData (arma::mat trainData)
Initialize the generator, discriminator and weights of the model for training.

const arma::mat & Responses () const
Get the matrix of responses to the input data points.

arma::mat & Responses ()
Modify the matrix of responses to the input data points.

template<typename Archive>
void serialize (Archive &ar, const uint32_t)
Serialize the model.

void Shuffle ()
Shuffle the order of function visitation.

template<typename OptimizerType, typename... CallbackTypes>
double Train (arma::mat trainData, OptimizerType &Optimizer, CallbackTypes &&... callbacks)
Train function.

Detailed Description


template<typename Model, typename InitializationRuleType, typename Noise, typename PolicyType = StandardGAN>

class mlpack::ann::GAN< Model, InitializationRuleType, Noise, PolicyType >

The implementation of the standard GAN module.

Generative Adversarial Networks (GANs) are a class of machine learning algorithms for unsupervised learning in which two neural networks, a generator and a discriminator, compete with each other in a zero-sum game. The generator learns to produce samples that resemble the training data, while the discriminator learns to distinguish real samples from generated ones. GANs can generate photographs that look at least superficially authentic to human observers, with many realistic characteristics, and have been applied to text-to-image synthesis, drug discovery, high-resolution image generation, and neural machine translation.
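For reference, in the formulation of the Goodfellow et al. paper cited below, this zero-sum game is the minimax problem shown here. D(x) is the discriminator's estimated probability that x comes from the data rather than from the generator, and G(z) maps a noise vector z to a sample. The exact loss each policy (StandardGAN, DCGAN, WGAN, WGANGP) optimizes is not spelled out on this page; the WGAN variants in particular replace this objective with a Wasserstein-distance estimate.

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right]
  + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z))\right)\right]
```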

For more information, see the following paper:

@article{Goodfellow14,
author = {Ian J. Goodfellow and Jean Pouget-Abadie and Mehdi Mirza and
Bing Xu and David Warde-Farley and Sherjil Ozair and Aaron Courville and
Yoshua Bengio},
title = {Generative Adversarial Nets},
year = {2014},
url = {http://arxiv.org/abs/1406.2661},
eprint = {1406.2661},
}
Template Parameters
Model: The class type of the Generator and Discriminator networks.
InitializationRuleType: Type of the weight initialization rule.
Noise: The noise function to use.
PolicyType: The GAN variant to use (StandardGAN, DCGAN, WGAN or WGANGP).

Definition at line 63 of file gan.hpp.

Constructor & Destructor Documentation

◆ GAN() [1/3]

GAN ( Model  generator,
Model  discriminator,
InitializationRuleType &  initializeRule,
Noise &  noiseFunction,
const size_t  noiseDim,
const size_t  batchSize,
const size_t  generatorUpdateStep,
const size_t  preTrainSize,
const double  multiplier,
const double  clippingParameter = 0.01,
const double  lambda = 10.0 
)

Constructor for GAN class.

Parameters
generator: Generator network.
discriminator: Discriminator network.
initializeRule: Initialization rule used to initialize the parameters.
noiseFunction: Function used to generate noise.
noiseDim: Dimension of the noise vector to be created.
batchSize: Batch size used for training.
generatorUpdateStep: Number of steps to train the Discriminator before updating the Generator.
preTrainSize: Number of pre-training steps of the Discriminator.
multiplier: Ratio of the Discriminator's learning rate to the Generator's.
clippingParameter: Weight range for enforcing the Lipschitz constraint (WGAN weight clipping).
lambda: Coefficient of the gradient penalty (WGAN-GP).
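The noiseFunction and noiseDim parameters together determine the Generator's input. The sketch below assumes, for illustration, that the noise function is a nullary callable returning one double per call, and shows how a noiseDim x batchSize block of noise could be filled in column-major order (the layout Armadillo uses, one column per sample). SampleNoise and MakeGaussianNoise are hypothetical helpers, not part of mlpack.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <random>
#include <vector>

// Hypothetical helper (not part of mlpack): fill a noiseDim x batchSize block
// of noise, stored column-major like arma::mat, by calling a nullary noise
// function once per element. Each column is one noise vector for the Generator.
std::vector<double> SampleNoise(std::size_t noiseDim,
                                std::size_t batchSize,
                                const std::function<double()>& noiseFunction)
{
  assert(noiseDim > 0 && batchSize > 0);
  std::vector<double> noise(noiseDim * batchSize);
  for (std::size_t col = 0; col < batchSize; ++col)
    for (std::size_t row = 0; row < noiseDim; ++row)
      noise[col * noiseDim + row] = noiseFunction();  // element (row, col)
  return noise;
}

// Hypothetical example noise function: standard normal samples drawn from a
// seeded engine, so results are reproducible.
std::function<double()> MakeGaussianNoise(unsigned seed)
{
  return [engine = std::mt19937(seed),
          dist = std::normal_distribution<double>(0.0, 1.0)]() mutable {
    return dist(engine);
  };
}
```

For example, SampleNoise(100, 32, MakeGaussianNoise(42)) would produce 32 noise vectors of dimension 100, one per column.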

◆ GAN() [2/3]

GAN ( const GAN< Model, InitializationRuleType, Noise, PolicyType > &  )

Copy constructor.

◆ GAN() [3/3]

GAN ( GAN< Model, InitializationRuleType, Noise, PolicyType > &&  )

Move constructor.

Member Function Documentation

◆ Discriminator() [1/2]

const Model& Discriminator ( ) const
inline

Return the discriminator of the GAN.

Definition at line 312 of file gan.hpp.

◆ Discriminator() [2/2]

Model& Discriminator ( )
inline

Modify the discriminator of the GAN.

Definition at line 314 of file gan.hpp.

◆ Evaluate() [1/3]

std::enable_if<std::is_same<Policy, StandardGAN>::value || std::is_same<Policy, DCGAN>::value, double>::type Evaluate ( const arma::mat &  parameters,
const size_t  i,
const size_t  batchSize 
)

Evaluate function for the Standard GAN and DCGAN.

This function computes the objective of the Standard GAN or DCGAN on the current input.

Parameters
parameters: The parameters of the network.
i: Index of the first input in the current batch.
batchSize: Number of inputs in the current batch.
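As an illustration of the kind of value this overload returns, the following sketch computes the textbook binary cross-entropy loss of a discriminator over one batch, derived from the minimax objective of Goodfellow et al. It works on plain probabilities D(x) and D(G(z)) rather than on networks, and DiscriminatorLoss is a hypothetical helper; whether the StandardGAN and DCGAN policies compute exactly this expression (for example, how they average over the batch) is an assumption.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper: mean binary cross-entropy loss of a discriminator over
// one batch. Real samples should score near 1 and generated ones near 0.
// dReal[k] = D(x_k) and dFake[k] = D(G(z_k)) are probabilities in (0, 1).
double DiscriminatorLoss(const std::vector<double>& dReal,
                         const std::vector<double>& dFake)
{
  assert(!dReal.empty() && dReal.size() == dFake.size());
  double loss = 0.0;
  for (std::size_t k = 0; k < dReal.size(); ++k)
  {
    assert(dReal[k] > 0.0 && dReal[k] < 1.0);
    assert(dFake[k] > 0.0 && dFake[k] < 1.0);
    // -[log D(x) + log(1 - D(G(z)))] for this pair of samples.
    loss -= std::log(dReal[k]) + std::log(1.0 - dFake[k]);
  }
  return loss / dReal.size();  // mean over the batch
}
```

For a discriminator that outputs 0.5 everywhere the loss is 2 ln 2 (about 1.386); a more confident, correct discriminator yields a smaller value.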

◆ Evaluate() [2/3]

std::enable_if<std::is_same<Policy, WGAN>::value, double>::type Evaluate ( const arma::mat &  parameters,
const size_t  i,
const size_t  batchSize 
)

Evaluate function for the WGAN.

This function computes the objective of the WGAN on the current input.

Parameters
parameters: The parameters of the network.
i: Index of the first input in the current batch.
batchSize: Number of inputs in the current batch.
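For contrast with the Standard GAN case, the WGAN critic outputs unbounded scores instead of probabilities, and its loss is a difference of means rather than a cross-entropy. The sketch below shows the standard WGAN critic and generator losses on precomputed scores; CriticLoss and GeneratorLoss are hypothetical names, and the sign convention (both losses are minimized) is chosen for illustration.

```cpp
#include <cassert>
#include <numeric>
#include <vector>

// Mean of a non-empty vector of critic scores.
double Mean(const std::vector<double>& v)
{
  assert(!v.empty());
  return std::accumulate(v.begin(), v.end(), 0.0) / v.size();
}

// Hypothetical WGAN critic loss to minimize: E[D(G(z))] - E[D(x)].
// Critic outputs are unbounded real scores, so no logarithm is taken.
double CriticLoss(const std::vector<double>& dReal,
                  const std::vector<double>& dFake)
{
  return Mean(dFake) - Mean(dReal);
}

// Hypothetical WGAN generator loss to minimize: -E[D(G(z))].
double GeneratorLoss(const std::vector<double>& dFake)
{
  return -Mean(dFake);
}
```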

◆ Evaluate() [3/3]

std::enable_if<std::is_same<Policy, WGANGP>::value, double>::type Evaluate ( const arma::mat &  parameters,
const size_t  i,
const size_t  batchSize 
)

Evaluate function for the WGAN-GP.

This function computes the objective of the WGAN-GP on the current input.

Parameters
parameters: The parameters of the network.
i: Index of the first input in the current batch.
batchSize: Number of inputs in the current batch.
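The WGAN-GP objective adds a gradient penalty, weighted by the constructor's lambda parameter, to the WGAN critic loss. In the WGAN-GP formulation the penalty is evaluated at random interpolates between real and generated samples. The sketch below shows only those two pieces: it takes the critic's input gradient as given (in a real network it comes from backpropagation), and Interpolate and GradientPenalty are hypothetical helpers rather than mlpack functions.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Random interpolate between a real and a generated sample:
// xHat = epsilon * real + (1 - epsilon) * fake, with epsilon in [0, 1].
std::vector<double> Interpolate(const std::vector<double>& real,
                                const std::vector<double>& fake,
                                double epsilon)
{
  assert(real.size() == fake.size());
  assert(epsilon >= 0.0 && epsilon <= 1.0);
  std::vector<double> xHat(real.size());
  for (std::size_t j = 0; j < real.size(); ++j)
    xHat[j] = epsilon * real[j] + (1.0 - epsilon) * fake[j];
  return xHat;
}

// Gradient penalty lambda * (||grad D(xHat)||_2 - 1)^2, given the critic's
// gradient with respect to its input, evaluated at xHat.
double GradientPenalty(const std::vector<double>& gradAtXHat, double lambda)
{
  double squaredNorm = 0.0;
  for (double g : gradAtXHat)
    squaredNorm += g * g;
  const double deviation = std::sqrt(squaredNorm) - 1.0;
  return lambda * deviation * deviation;
}
```

For a linear critic D(x) = w · x the input gradient is w everywhere, so with w = (3, 4) and lambda = 10 the penalty is 10 · (5 − 1)² = 160.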

◆ EvaluateWithGradient() [1/3]

std::enable_if<std::is_same<Policy, StandardGAN>::value || std::is_same<Policy, DCGAN>::value, double>::type EvaluateWithGradient ( const arma::mat &  parameters,
const size_t  i,
GradType &  gradient,
const size_t  batchSize 
)

EvaluateWithGradient function for the Standard GAN and DCGAN.

This function computes the objective of the Standard GAN or DCGAN on the current input, together with its gradient.

Parameters
parameters: The parameters of the network.
i: Index of the first input in the current batch.
gradient: Variable in which the computed gradient is stored.
batchSize: Number of inputs in the current batch.
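EvaluateWithGradient returns the objective and fills in the gradient in a single call, which typically lets the work shared by the two computations be reused. A minimal sketch of the same idea for a binary cross-entropy loss on discriminator logits (pre-sigmoid outputs) follows; it differentiates only with respect to the logits, whereas the real function backpropagates through the networks' parameters. EvaluateWithGradientBCE and LossAndGradient are hypothetical names.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

struct LossAndGradient
{
  double loss;                   // summed binary cross-entropy
  std::vector<double> gradient;  // d(loss)/d(logit) for each sample
};

double Sigmoid(double s) { return 1.0 / (1.0 + std::exp(-s)); }

// Evaluates BCE on discriminator logits and, in the same pass, its gradient.
// labels[k] is 1.0 for a real sample and 0.0 for a generated one.
LossAndGradient EvaluateWithGradientBCE(const std::vector<double>& logits,
                                        const std::vector<double>& labels)
{
  assert(logits.size() == labels.size());
  LossAndGradient out{0.0, std::vector<double>(logits.size())};
  for (std::size_t k = 0; k < logits.size(); ++k)
  {
    const double p = Sigmoid(logits[k]);
    out.loss -= labels[k] * std::log(p) + (1.0 - labels[k]) * std::log(1.0 - p);
    out.gradient[k] = p - labels[k];  // derivative of BCE w.r.t. the logit
  }
  return out;
}
```

The gradient with respect to each logit is simply sigmoid(logit) minus its label. This sketch computes log(p) directly, so very large logits can produce infinite values; a production implementation would use a numerically stable form.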

◆ EvaluateWithGradient() [2/3]

std::enable_if<std::is_same<Policy, WGAN>::value, double>::type EvaluateWithGradient ( const arma::mat &  parameters,
const size_t  i,
GradType &  gradient,
const size_t  batchSize 
)

EvaluateWithGradient function for the WGAN.

This function computes the objective of the WGAN on the current input, together with its gradient.

Parameters
parameters: The parameters of the network.
i: Index of the first input in the current batch.
gradient: Variable in which the computed gradient is stored.
batchSize: Number of inputs in the current batch.

◆ EvaluateWithGradient() [3/3]

std::enable_if<std::is_same<Policy, WGANGP>::value, double>::type EvaluateWithGradient ( const arma::mat &  parameters,
const size_t  i,
GradType &  gradient,
const size_t  batchSize 
)

EvaluateWithGradient function for the WGAN-GP.

This function computes the objective of the WGAN-GP on the current input, together with its gradient.

Parameters
parameters: The parameters of the network.
i: Index of the first input in the current batch.
gradient: Variable in which the computed gradient is stored.
batchSize: Number of inputs in the current batch.

◆ Forward()

void Forward ( const arma::mat &  input)

This function does a forward pass through the GAN network.

Parameters
input: Sampled noise.

◆ Generator() [1/2]

const Model& Generator ( ) const
inline

Return the generator of the GAN.

Definition at line 308 of file gan.hpp.

◆ Generator() [2/2]

Model& Generator ( )
inline

Modify the generator of the GAN.

Definition at line 310 of file gan.hpp.

◆ Gradient() [1/3]

std::enable_if<std::is_same<Policy, StandardGAN>::value || std::is_same<Policy, DCGAN>::value, void>::type Gradient ( const arma::mat &  parameters,
const size_t  i,
arma::mat &  gradient,
const size_t  batchSize 
)

Gradient function for the Standard GAN and DCGAN.

This function computes the gradient for whichever network is currently being trained, i.e., the Generator or the Discriminator.

Parameters
parameters: The present parameters of the network.
i: Index of the first predictor in the current batch.
gradient: Variable in which the computed gradient is stored.
batchSize: Number of inputs in the current batch.
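Which network receives the gradient depends on the generatorUpdateStep and preTrainSize constructor parameters. The sketch below illustrates one plausible schedule under that reading: the Discriminator is updated at every step, no Generator updates happen during the first preTrainSize steps, and afterwards the Generator is also updated on every generatorUpdateStep-th step. UpdatesGenerator is a hypothetical helper, and the exact counting rule mlpack uses is an assumption.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical schedule (not mlpack's code): returns true if the Generator is
// updated at training step `step` (counting from 0). The Discriminator is
// assumed to be updated at every step. No Generator updates happen during the
// first preTrainSize steps; afterwards one happens every generatorUpdateStep
// steps.
bool UpdatesGenerator(std::size_t step,
                      std::size_t preTrainSize,
                      std::size_t generatorUpdateStep)
{
  assert(generatorUpdateStep > 0);
  if (step < preTrainSize)
    return false;  // still pre-training the Discriminator
  return (step - preTrainSize + 1) % generatorUpdateStep == 0;
}
```

With preTrainSize = 2 and generatorUpdateStep = 3, this schedule updates the Generator at steps 4, 7, 10, and so on.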

◆ Gradient() [2/3]

std::enable_if<std::is_same<Policy, WGAN>::value, void>::type Gradient ( const arma::mat &  parameters,
const size_t  i,
arma::mat &  gradient,
const size_t  batchSize 
)

Gradient function for the WGAN.

This function computes the gradient for whichever network is currently being trained, i.e., the Generator or the Discriminator.

Parameters
parameters: The present parameters of the network.
i: Index of the first predictor in the current batch.
gradient: Variable in which the computed gradient is stored.
batchSize: Number of inputs in the current batch.
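For the WGAN policy, the constructor's clippingParameter is documented as the weight range for enforcing the Lipschitz constraint. In the original WGAN algorithm this is done by clamping every critic weight into [-c, c] after each critic update. The sketch below shows that clamping on a flat vector of weights; ClipWeights is a hypothetical helper, and when and where mlpack applies the clipping is not shown on this page.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical helper: weight clipping as in the original WGAN. After a critic
// update, clamp every weight into [-clippingParameter, clippingParameter].
void ClipWeights(std::vector<double>& criticWeights, double clippingParameter)
{
  assert(clippingParameter > 0.0);
  for (double& w : criticWeights)
    w = std::clamp(w, -clippingParameter, clippingParameter);
}
```

With the default clippingParameter of 0.01, a weight of 0.5 would become 0.01, while -0.005 would be left unchanged.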

◆ Gradient() [3/3]

std::enable_if<std::is_same<Policy, WGANGP>::value, void>::type Gradient ( const arma::mat &  parameters,
const size_t  i,
arma::mat &  gradient,
const size_t  batchSize 
)

Gradient function for the WGAN-GP.

This function computes the gradient for whichever network is currently being trained, i.e., the Generator or the Discriminator.

Parameters
parameters: The present parameters of the network.
i: Index of the first predictor in the current batch.
gradient: Variable in which the computed gradient is stored.
batchSize: Number of inputs in the current batch.

◆ NumFunctions()

size_t NumFunctions ( ) const
inline

Return the number of separable functions (the number of predictor points).

Definition at line 317 of file gan.hpp.

◆ Parameters() [1/2]

const arma::mat& Parameters ( ) const
inline

Return the parameters of the network.

Definition at line 303 of file gan.hpp.

◆ Parameters() [2/2]

arma::mat& Parameters ( )
inline

Modify the parameters of the network.

Definition at line 305 of file gan.hpp.

◆ Predict()

void Predict ( arma::mat  input,
arma::mat &  output 
)

This function predicts the output of the network on the given input.

Parameters
input: The input of the Generator network.
output: The result of the Discriminator network.

◆ Predictors() [1/2]

const arma::mat& Predictors ( ) const
inline

Get the matrix of data points (predictors).

Definition at line 325 of file gan.hpp.

◆ Predictors() [2/2]

arma::mat& Predictors ( )
inline

Modify the matrix of data points (predictors).

Definition at line 327 of file gan.hpp.

References GAN< Model, InitializationRuleType, Noise, PolicyType >::serialize().

◆ Reset()

void Reset ( )

◆ ResetData()

void ResetData ( arma::mat  trainData)

Initialize the generator, discriminator and weights of the model for training.

This function does not start the training process.

Parameters
trainData: The data points of the real distribution.

◆ Responses() [1/2]

const arma::mat& Responses ( ) const
inline

Get the matrix of responses to the input data points.

Definition at line 320 of file gan.hpp.

◆ Responses() [2/2]

arma::mat& Responses ( )
inline

Modify the matrix of responses to the input data points.

Definition at line 322 of file gan.hpp.

◆ serialize()

void serialize ( Archive &  ar,
const uint32_t   
)

◆ Shuffle()

void Shuffle ( )

Shuffle the order of function visitation.

This may be called by the optimizer.
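NumFunctions() reports one separable function per predictor point, and Evaluate, EvaluateWithGradient and Gradient receive a starting index i and a batchSize. A minimal sketch of what shuffling the visitation order can look like follows: a random permutation of the column indices, from which each batch takes a contiguous slice. ShuffledOrder and BatchIndices are hypothetical helpers; the class's own Shuffle() may instead reorder the predictor matrix itself.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <random>
#include <vector>

// Hypothetical sketch: a random visitation order over numFunctions separable
// functions (one per predictor column).
std::vector<std::size_t> ShuffledOrder(std::size_t numFunctions, unsigned seed)
{
  std::vector<std::size_t> order(numFunctions);
  std::iota(order.begin(), order.end(), std::size_t{0});  // 0, 1, ..., n - 1
  std::mt19937 engine(seed);
  std::shuffle(order.begin(), order.end(), engine);
  return order;
}

// Column indices visited by the batch that starts at position i.
std::vector<std::size_t> BatchIndices(const std::vector<std::size_t>& order,
                                      std::size_t i, std::size_t batchSize)
{
  assert(i + batchSize <= order.size());
  return std::vector<std::size_t>(order.begin() + i,
                                  order.begin() + i + batchSize);
}
```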

◆ Train()

double Train ( arma::mat  trainData,
OptimizerType &  Optimizer,
CallbackTypes &&...  callbacks 
)

Train function.

Template Parameters
OptimizerType: Type of optimizer used to train the model.
CallbackTypes: Types of the callback functions.
Parameters
trainData: The data points of the real distribution.
Optimizer: Instantiated optimizer used to train the model.
callbacks: Callback functions for the ensmallen optimizer OptimizerType. See https://www.ensmallen.org/docs.html#callback-documentation.
Returns
The final objective of the trained model (NaN or Inf on error).

The documentation for this class was generated from the following file:
  • /home/ryan/src/mlpack.org/_src/mlpack-git/src/mlpack/methods/ann/gan/gan.hpp