12 #ifndef MLPACK_METHODS_RL_DUELING_DQN_HPP 13 #define MLPACK_METHODS_RL_DUELING_DQN_HPP 67 concat->
Add(valueNetwork);
68 concat->
Add(advantageNetwork);
70 completeNetwork.Add(featureNetwork);
71 completeNetwork.Add(concat);
89 const bool isNoisy =
false,
90 InitType init = InitType(),
91 OutputLayerType outputLayer = OutputLayerType()):
92 completeNetwork(outputLayer, init),
104 noisyLayerIndex.push_back(valueNetwork->Model().size());
111 noisyLayerIndex.push_back(valueNetwork->Model().size());
117 valueNetwork->Add(
new Linear<>(h1, h2));
119 valueNetwork->Add(
new Linear<>(h2, 1));
123 advantageNetwork->
Add(
new Linear<>(h2, outputDim));
127 concat->
Add(valueNetwork);
128 concat->Add(advantageNetwork);
131 completeNetwork.Add(featureNetwork);
132 completeNetwork.Add(concat);
133 this->ResetParameters();
145 AdvantageNetworkType& advantageNetwork,
146 ValueNetworkType& valueNetwork,
147 const bool isNoisy =
false):
148 featureNetwork(featureNetwork),
149 advantageNetwork(advantageNetwork),
150 valueNetwork(valueNetwork),
154 concat->
Add(valueNetwork);
155 concat->Add(advantageNetwork);
157 completeNetwork.Add(featureNetwork);
158 completeNetwork.Add(concat);
159 this->ResetParameters();
169 *valueNetwork = *model.valueNetwork;
170 *advantageNetwork = *model.advantageNetwork;
171 *featureNetwork = *model.featureNetwork;
172 isNoisy = model.isNoisy;
173 noisyLayerIndex = model.noisyLayerIndex;
187 void Predict(
const arma::mat state, arma::mat& actionValue)
189 arma::mat advantage, value, networkOutput;
190 completeNetwork.Predict(state, networkOutput);
191 value = networkOutput.row(0);
192 advantage = networkOutput.rows(1, networkOutput.n_rows - 1);
193 actionValue = advantage.each_row() +
194 (value - arma::mean(advantage));
203 void Forward(
const arma::mat state, arma::mat& actionValue)
205 arma::mat advantage, value, networkOutput;
206 completeNetwork.Forward(state, networkOutput);
207 value = networkOutput.row(0);
208 advantage = networkOutput.rows(1, networkOutput.n_rows - 1);
209 actionValue = advantage.each_row() +
210 (value - arma::mean(advantage));
211 this->actionValues = actionValue;
221 void Backward(
const arma::mat state, arma::mat& target, arma::mat& gradient)
224 lossFunction.Backward(this->actionValues, target, gradLoss);
226 arma::mat gradValue = arma::sum(gradLoss);
227 arma::mat gradAdvantage = gradLoss.each_row() - arma::mean(gradLoss);
229 arma::mat grad = arma::join_cols(gradValue, gradAdvantage);
230 completeNetwork.Backward(state, grad, gradient);
238 completeNetwork.ResetParameters();
246 for (
size_t i = 0; i < noisyLayerIndex.size(); i++)
248 boost::get<NoisyLinear<>*>
249 (valueNetwork->Model()[noisyLayerIndex[i]])->ResetNoise();
250 boost::get<NoisyLinear<>*>
251 (advantageNetwork->Model()[noisyLayerIndex[i]])->ResetNoise();
256 const arma::mat&
Parameters()
const {
return completeNetwork.Parameters(); }
258 arma::mat&
Parameters() {
return completeNetwork.Parameters(); }
262 CompleteNetworkType completeNetwork;
268 FeatureNetworkType* featureNetwork;
271 AdvantageNetworkType* advantageNetwork;
274 ValueNetworkType* valueNetwork;
280 std::vector<size_t> noisyLayerIndex;
283 arma::mat actionValues;
Artificial Neural Network.
DuelingDQN(const int inputDim, const int h1, const int h2, const int outputDim, const bool isNoisy=false, InitType init=InitType(), OutputLayerType outputLayer=OutputLayerType())
Construct an instance of DuelingDQN class.
Linear algebra utility functions, generally performed on matrices or vectors.
void Backward(const arma::mat state, arma::mat &target, arma::mat &gradient)
Perform the backward pass of the state in real batch mode.
void Forward(const arma::mat state, arma::mat &actionValue)
Perform the forward pass of the states in real batch mode.
The core includes that mlpack expects; standard C++ includes and Armadillo.
Implementation of the Linear layer class.
void Predict(const arma::mat state, arma::mat &actionValue)
Predict the responses to a given set of predictors.
Implementation of the Dueling Deep Q-Learning network.
The empty loss does nothing, letting the user calculate the loss outside the model.
Implementation of the base layer.
DuelingDQN()
Default constructor.
Implementation of the Concat class.
Implementation of the NoisyLinear layer class.
DuelingDQN(const DuelingDQN &)
Copy constructor.
void ResetNoise()
Resets noise of the network, if the network is of type noisy.
DuelingDQN(FeatureNetworkType &featureNetwork, AdvantageNetworkType &advantageNetwork, ValueNetworkType &valueNetwork, const bool isNoisy=false)
Construct an instance of DuelingDQN class from a pre-constructed network.
The mean squared error performance function measures the network's performance according to the mean ...
void ResetParameters()
Resets the parameters of the network.
arma::mat & Parameters()
Modify the Parameters.
Implementation of a standard feed forward network.
const arma::mat & Parameters() const
Return the Parameters.
Implementation of the Sequential class.
This class is used to initialize the weight matrix with a Gaussian distribution.