12 #ifndef MLPACK_METHODS_RL_CATEGORICAL_DQN_HPP 13 #define MLPACK_METHODS_RL_CATEGORICAL_DQN_HPP 21 #include "../training_config.hpp" 57 network(), atomSize(0), vMin(0.0), vMax(0.0), isNoisy(false)
77 const bool isNoisy =
false,
78 InitType init = InitType(),
79 OutputLayerType outputLayer = OutputLayerType()):
80 network(outputLayer, init),
81 atomSize(config.AtomSize()),
86 network.Add(
new Linear<>(inputDim, h1));
90 noisyLayerIndex.push_back(network.Model().size());
93 noisyLayerIndex.push_back(network.Model().size());
100 network.Add(
new Linear<>(h2, outputDim * atomSize));
114 const bool isNoisy =
false):
115 network(
std::move(network)),
116 atomSize(config.AtomSize()),
133 void Predict(
const arma::mat state, arma::mat& actionValue)
136 network.Predict(state, q_atoms);
137 activations.copy_size(q_atoms);
138 actionValue.set_size(q_atoms.n_rows / atomSize, q_atoms.n_cols);
139 arma::rowvec support = arma::linspace<arma::rowvec>(vMin, vMax, atomSize);
140 for (
size_t i = 0; i < q_atoms.n_rows; i += atomSize)
142 arma::mat activation = activations.rows(i, i + atomSize - 1);
143 arma::mat input = q_atoms.rows(i, i + atomSize - 1);
144 softMax.Forward(input, activation);
145 activations.rows(i, i + atomSize - 1) = activation;
146 actionValue.row(i/atomSize) = support * activation;
156 void Forward(
const arma::mat state, arma::mat& dist)
159 network.Forward(state, q_atoms);
160 activations.copy_size(q_atoms);
161 for (
size_t i = 0; i < q_atoms.n_rows; i += atomSize)
163 arma::mat activation = activations.rows(i, i + atomSize - 1);
164 arma::mat input = q_atoms.rows(i, i + atomSize - 1);
165 softMax.Forward(input, activation);
166 activations.rows(i, i + atomSize - 1) = activation;
176 network.ResetParameters();
184 for (
size_t i = 0; i < noisyLayerIndex.size(); i++)
186 boost::get<NoisyLinear<>*>
187 (network.Model()[noisyLayerIndex[i]])->ResetNoise();
192 const arma::mat&
Parameters()
const {
return network.Parameters(); }
204 arma::mat& lossGradients,
207 arma::mat activationGradients(arma::size(activations));
208 for (
size_t i = 0; i < activations.n_rows; i += atomSize)
210 arma::mat activationGrad;
211 arma::mat lossGrad = lossGradients.rows(i, i + atomSize - 1);
212 arma::mat activation = activations.rows(i, i + atomSize - 1);
213 softMax.Backward(activation, lossGrad, activationGrad);
214 activationGradients.rows(i, i + atomSize - 1) = activationGrad;
216 network.Backward(state, activationGradients, gradient);
236 std::vector<size_t> noisyLayerIndex;
242 arma::mat activations;
Artificial Neural Network.
void Predict(const arma::mat state, arma::mat &actionValue)
Predict the responses to a given set of predictors.
void ResetNoise()
Resets noise of the network, if the network is of type noisy.
CategoricalDQN(const int inputDim, const int h1, const int h2, const int outputDim, TrainingConfig config, const bool isNoisy=false, InitType init=InitType(), OutputLayerType outputLayer=OutputLayerType())
Construct an instance of CategoricalDQN class.
Linear algebra utility functions, generally performed on matrices or vectors.
CategoricalDQN(NetworkType &network, TrainingConfig config, const bool isNoisy=false)
Construct an instance of CategoricalDQN class from a pre-constructed network.
The core includes that mlpack expects; standard C++ includes and Armadillo.
Implementation of the Linear layer class.
arma::mat & Parameters()
Modify the Parameters.
The empty loss does nothing, letting the user calculate the loss outside the model.
Implementation of the Softmax layer.
void Forward(const arma::mat state, arma::mat &dist)
Perform the forward pass of the states in real batch mode.
const arma::mat & Parameters() const
Return the Parameters.
Implementation of the base layer.
Implementation of the NoisyLinear layer class.
CategoricalDQN()
Default constructor.
Implementation of a standard feed forward network.
Implementation of the Categorical Deep Q-Learning network.
void ResetParameters()
Resets the parameters of the network.
void Backward(const arma::mat state, arma::mat &lossGradients, arma::mat &gradient)
Perform the backward pass of the state in real batch mode.
This class is used to initialize the weight matrix with a Gaussian distribution.