#ifndef MLPACK_METHODS_RL_PRIORITIZED_REPLAY_HPP
#define MLPACK_METHODS_RL_PRIORITIZED_REPLAY_HPP

#include <mlpack/prereqs.hpp>
#include "sumtree.hpp"

template <typename EnvironmentType>
class PrioritizedReplay
{
 public:
  //! Convenient typedefs for the environment's state and action.
  using ActionType = typename EnvironmentType::Action;
  using StateType = typename EnvironmentType::State;

  //! A single stored experience: (state, action, reward, nextState, isEnd).
  struct Transition
  {
    StateType state;
    ActionType action;
    double reward;
    StateType nextState;
    bool isEnd;
  };
  PrioritizedReplay(const size_t batchSize,
                    const size_t capacity,
                    const double alpha,
                    const size_t nSteps = 1,
                    const size_t dimension = StateType::dimension) :
      batchSize(batchSize),
      capacity(capacity),
      alpha(alpha),
      maxPriority(1.0),
      initialBeta(0.6),  // Assumed default.
      beta(initialBeta),
      replayBetaIters(10000),
      nSteps(nSteps),
      position(0),
      states(dimension, capacity),
      actions(capacity),
      rewards(capacity),
      nextStates(dimension, capacity),
      isTerminal(capacity),
      full(false)
  {
    // The sum tree holding the priorities needs a power-of-two size at least
    // as large as the capacity.
    size_t size = 1;
    while (size < capacity)
      size *= 2;
    idxSum = SumTree<double>(size);
  }
  void Store(StateType state,
             ActionType action,
             double reward,
             StateType nextState,
             bool isEnd,
             const double& discount)
  {
    nStepBuffer.push_back({state, action, reward, nextState, isEnd});

    // An n-step transition is not ready yet.
    if (nStepBuffer.size() < nSteps)
      return;

    // Keep the buffer at exactly nSteps transitions.
    if (nStepBuffer.size() > nSteps)
      nStepBuffer.pop_front();

    assert(nStepBuffer.size() == nSteps);

    // Collapse the buffered transitions into a single n-step transition.
    GetNStepInfo(reward, nextState, isEnd, discount);

    state = nStepBuffer.front().state;
    action = nStepBuffer.front().action;
    states.col(position) = state.Encode();
    actions[position] = action;
    rewards(position) = reward;
    nextStates.col(position) = nextState.Encode();
    isTerminal(position) = isEnd;

    // New transitions get the largest priority seen so far.
    idxSum.Set(position, maxPriority * alpha);

    position++;
    if (position == capacity)
    {
      full = true;
      position = 0;
    }
  }
  void GetNStepInfo(double& reward,
                    StateType& nextState,
                    bool& isEnd,
                    const double& discount)
  {
    reward = nStepBuffer.back().reward;
    nextState = nStepBuffer.back().nextState;
    isEnd = nStepBuffer.back().isEnd;

    // Walk backwards from the second-to-last buffered transition, discounting
    // the accumulated reward and truncating at any terminal state.
    for (int i = nStepBuffer.size() - 2; i >= 0; i--)
    {
      bool iE = nStepBuffer[i].isEnd;
      reward = nStepBuffer[i].reward + discount * reward * (1 - iE);
      if (iE)
      {
        nextState = nStepBuffer[i].nextState;
        isEnd = iE;
      }
    }
  }
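  // For example, with nSteps = 3, a discount of 0.99, and buffered rewards of
  // 1.0 at every step with no terminal state in between, the backwards pass
  // above yields reward = 1 + 0.99 * 1 + 0.99^2 * 1 = 2.9701, while nextState
  // and isEnd come from the last buffered transition.  If an intermediate
  // transition is terminal, the (1 - iE) factor drops everything after it and
  // nextState / isEnd are taken from that transition instead.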
  arma::ucolvec SampleProportional()
  {
    arma::ucolvec idxes(batchSize);
    double totalSum = idxSum.Sum(0, (full ? capacity : position));
    double sumPerRange = totalSum / batchSize;
    for (size_t bt = 0; bt < batchSize; bt++)
    {
      const double mass = arma::randu() * sumPerRange + bt * sumPerRange;
      idxes(bt) = idxSum.FindPrefixSum(mass);
    }
    return idxes;
  }
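  // Sampling note: the total priority mass is split into batchSize equal
  // slices and one prefix-sum query is drawn uniformly within each slice
  // (stratified sampling).  Indices owning a larger share of the cumulative
  // priority sum are therefore selected more often, while each batch still
  // spreads its samples across the whole memory.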
  void Sample(arma::mat& sampledStates,
              std::vector<ActionType>& sampledActions,
              arma::rowvec& sampledRewards,
              arma::mat& sampledNextStates,
              arma::irowvec& isTerminal)
  {
    sampledIndices = SampleProportional();
    BetaAnneal();

    sampledStates = states.cols(sampledIndices);
    for (size_t t = 0; t < sampledIndices.n_rows; t++)
      sampledActions.push_back(actions[sampledIndices[t]]);
    sampledRewards = rewards.elem(sampledIndices).t();
    sampledNextStates = nextStates.cols(sampledIndices);
    isTerminal = this->isTerminal.elem(sampledIndices).t();

    // Compute the importance-sampling weights of the sampled transitions.
    size_t numSample = full ? capacity : position;
    weights = arma::rowvec(sampledIndices.n_rows);

    for (size_t i = 0; i < sampledIndices.n_rows; ++i)
    {
      double p_sample = idxSum.Get(sampledIndices(i)) / idxSum.Sum();
      weights(i) = pow(numSample * p_sample, -beta);
    }
    weights /= weights.max();
  }
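  // The weights follow the usual prioritized-replay correction
  // w_i = (N * P(i))^(-beta), normalized by the largest weight in the batch
  // so that every weight lies in (0, 1].  For instance, with N = 10000 stored
  // transitions, a sampling probability P(i) = 0.01, and beta = 0.6, the
  // unnormalized weight is (10000 * 0.01)^(-0.6) = 100^(-0.6), roughly 0.063.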
  void UpdatePriorities(arma::ucolvec& indices, arma::colvec& priorities)
  {
    arma::colvec alphaPri = alpha * priorities;
    maxPriority = std::max(maxPriority, arma::max(priorities));
    idxSum.BatchUpdate(indices, alphaPri);
  }
  const size_t& Size()
  {
    return full ? capacity : position;
  }
  void BetaAnneal()
  {
    beta = beta + (1 - initialBeta) * 1.0 / replayBetaIters;
  }
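  // Each call raises beta by (1 - initialBeta) / replayBetaIters, so beta is
  // annealed linearly from its initial value towards 1 over replayBetaIters
  // sampled batches; no upper cap is applied, so it keeps growing if sampling
  // continues beyond that point.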
  void Update(arma::mat target,
              std::vector<ActionType> sampledActions,
              arma::mat nextActionValues,
              arma::mat& gradients)
  {
    // The new priority of each sampled transition is its absolute TD error.
    arma::colvec tdError(target.n_cols);
    for (size_t i = 0; i < target.n_cols; i++)
    {
      tdError(i) = nextActionValues(sampledActions[i].action, i) -
          target(sampledActions[i].action, i);
    }
    tdError = arma::abs(tdError);
    UpdatePriorities(sampledIndices, tdError);

    // Scale the gradients by the mean importance-sampling weight.
    gradients = arma::mean(weights) * gradients;
  }
  const size_t& NSteps() const { return nSteps; }
 private:
  //! Hyperparameters and counters.
  size_t batchSize, capacity, nSteps, position, replayBetaIters;
  double alpha, maxPriority, initialBeta, beta;
  //! Sum tree storing the priorities of the stored transitions.
  SumTree<double> idxSum;
  //! Indices and importance weights of the most recently sampled batch.
  arma::ucolvec sampledIndices;
  arma::rowvec weights;
  //! Buffer used to assemble n-step transitions.
  std::deque<Transition> nStepBuffer;
  //! Stored experience.
  arma::mat states, nextStates;
  std::vector<ActionType> actions;
  arma::rowvec rewards;
  arma::irowvec isTerminal;
  //! Whether the memory has wrapped around at least once.
  bool full;
};
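The priority bookkeeping above is delegated to the companion SumTree<double> structure (Set(), Get(), Sum(), FindPrefixSum(), and BatchUpdate()), which is not shown in this excerpt. As a rough, hedged reference for the query that SampleProportional() relies on, the linear-scan stand-in below returns the index FindPrefixSum() is expected to find; the real sum tree answers the same kind of query in O(log n) rather than O(n), and this helper is illustrative only, not part of the mlpack API.

// Illustrative stand-in for SumTree::FindPrefixSum(): given non-negative
// priorities, return the first index whose inclusive prefix sum exceeds the
// queried mass.
size_t FindPrefixSumLinear(const arma::colvec& priorities, const double mass)
{
  double runningSum = 0.0;
  for (size_t idx = 0; idx < priorities.n_elem; ++idx)
  {
    runningSum += priorities(idx);
    if (runningSum > mass)
      return idx;
  }
  // Guard against floating-point round-off at the top of the range.
  return priorities.n_elem - 1;
}

A summary of the class members follows.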
void BetaAnneal()
Annealing the beta.
void GetNStepInfo(double &reward, StateType &nextState, bool &isEnd, const double &discount)
Get the reward, next state, and terminal flag for the n-step transition.
arma::ucolvec SampleProportional()
Sample transition indices proportionally to their priorities.
Implementation of prioritized experience replay.
PrioritizedReplay(const size_t batchSize, const size_t capacity, const double alpha, const size_t nSteps=1, const size_t dimension=StateType::dimension)
Construct an instance of the prioritized experience replay class.
const size_t & Size()
Get the number of transitions in the memory.
PrioritizedReplay()
Default constructor.
typename EnvironmentType::Action ActionType
Convenient typedef for action.
const size_t & NSteps() const
Get the number of steps for the n-step agent.
void UpdatePriorities(arma::ucolvec &indices, arma::colvec &priorities)
Update priorities of sampled transitions.
typename EnvironmentType::State StateType
Convenient typedef for state.
void Update(arma::mat target, std::vector< ActionType > sampledActions, arma::mat nextActionValues, arma::mat &gradients)
Update the priorities of transitions and update the gradients.
void Store(StateType state, ActionType action, double reward, StateType nextState, bool isEnd, const double &discount)
Store the given experience and set its priority.
void Sample(arma::mat &sampledStates, std::vector< ActionType > &sampledActions, arma::rowvec &sampledRewards, arma::mat &sampledNextStates, arma::irowvec &isTerminal)
Sample some experience according to their priorities.
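Putting the pieces together, here is a hedged end-to-end sketch. It assumes mlpack's CartPole environment, the mlpack::rl namespace, and the include paths used by mlpack's reinforcement learning code; chooseAction() is a hypothetical placeholder for the agent's policy, and the target / nextActionValues / gradients matrices stand in for the learning and target network outputs.

#include <mlpack/methods/reinforcement_learning/replay/prioritized_replay.hpp>
#include <mlpack/methods/reinforcement_learning/environment/cart_pole.hpp>

using namespace mlpack::rl;

// Hypothetical helper: pick an action for a state (policy supplied elsewhere).
CartPole::Action chooseAction(const CartPole::State& state);

int main()
{
  // 10000-transition memory, batches of 32, alpha = 0.6, 3-step returns.
  PrioritizedReplay<CartPole> replayMemory(32, 10000, 0.6, 3);
  CartPole env;

  // Fill the memory while acting in the environment for a few episodes.
  for (size_t episode = 0; episode < 10; ++episode)
  {
    CartPole::State state = env.InitialSample();
    while (!env.IsTerminal(state))
    {
      const CartPole::Action action = chooseAction(state);
      CartPole::State nextState;
      const double reward = env.Sample(state, action, nextState);
      replayMemory.Store(state, action, reward, nextState,
          env.IsTerminal(nextState), /* discount */ 0.99);
      state = nextState;
    }
  }

  // Draw a prioritized batch.
  arma::mat sampledStates, sampledNextStates;
  std::vector<CartPole::Action> sampledActions;
  arma::rowvec sampledRewards;
  arma::irowvec sampledTerminal;
  replayMemory.Sample(sampledStates, sampledActions, sampledRewards,
      sampledNextStates, sampledTerminal);

  // Placeholders for the network outputs: one row per action, one column per
  // sampled transition.  Update() recomputes the priorities from |TD error|
  // and rescales the gradients by the mean importance weight.
  arma::mat target(CartPole::Action::size, sampledStates.n_cols,
      arma::fill::zeros);
  arma::mat nextActionValues(CartPole::Action::size, sampledStates.n_cols,
      arma::fill::zeros);
  arma::mat gradients(sampledStates.n_rows, 1, arma::fill::zeros);
  replayMemory.Update(target, sampledActions, nextActionValues, gradients);
}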