12 #ifndef MLPACK_METHODS_RL_SIMPLE_DQN_HPP 13 #define MLPACK_METHODS_RL_SIMPLE_DQN_HPP 60 const bool isNoisy =
false,
61 InitType init = InitType(),
62 OutputLayerType outputLayer = OutputLayerType()):
63 network(outputLayer, init),
66 network.Add(
new Linear<>(inputDim, h1));
70 noisyLayerIndex.push_back(network.Model().size());
73 noisyLayerIndex.push_back(network.Model().size());
80 network.Add(
new Linear<>(h2, outputDim));
90 SimpleDQN(NetworkType& network,
const bool isNoisy =
false):
106 void Predict(
const arma::mat state, arma::mat& actionValue)
108 network.Predict(state, actionValue);
117 void Forward(
const arma::mat state, arma::mat& target)
119 network.Forward(state, target);
127 network.ResetParameters();
135 for (
size_t i = 0; i < noisyLayerIndex.size(); i++)
137 boost::get<NoisyLinear<>*>
138 (network.Model()[noisyLayerIndex[i]])->ResetNoise();
143 const arma::mat&
Parameters()
const {
return network.Parameters(); }
154 void Backward(
const arma::mat state, arma::mat& target, arma::mat& gradient)
156 network.Backward(state, target, gradient);
167 std::vector<size_t> noisyLayerIndex;
Artificial Neural Network.
void ResetParameters()
Resets the parameters of the network.
void ResetNoise()
Resets the noise of the network, if the network is of the noisy type.
SimpleDQN(const int inputDim, const int h1, const int h2, const int outputDim, const bool isNoisy=false, InitType init=InitType(), OutputLayerType outputLayer=OutputLayerType())
Construct an instance of SimpleDQN class.
Linear algebra utility functions, generally performed on matrices or vectors.
SimpleDQN()
Default constructor.
arma::mat & Parameters()
Modify the Parameters.
The core includes that mlpack expects: standard C++ includes and Armadillo.
Implementation of the Linear layer class.
void Forward(const arma::mat state, arma::mat &target)
Perform the forward pass of the states in real batch mode.
const arma::mat & Parameters() const
Return the Parameters.
Implementation of the base layer.
SimpleDQN(NetworkType &network, const bool isNoisy=false)
Construct an instance of SimpleDQN class from a pre-constructed network.
void Predict(const arma::mat state, arma::mat &actionValue)
Predict the responses to a given set of predictors.
Implementation of the NoisyLinear layer class.
void Backward(const arma::mat state, arma::mat &target, arma::mat &gradient)
Perform the backward pass of the states in real batch mode.
The mean squared error performance function measures the network's performance according to the mean of squared errors.
Implementation of a standard feed forward network.
This class is used to initialize the weight matrix with a Gaussian distribution.