#ifndef MLPACK_METHODS_RL_PRIORITIZED_REPLAY_HPP
#define MLPACK_METHODS_RL_PRIORITIZED_REPLAY_HPP

#include <mlpack/core.hpp>

/**
 * Implementation of prioritized experience replay.
 */
template<typename EnvironmentType>
class PrioritizedReplay
{
 public:
  //! Convenient typedef for state.
  typedef typename EnvironmentType::State StateType;

  //! Convenient typedef for action.
  typedef typename EnvironmentType::Action ActionType;

  //! Construct an instance of prioritized experience replay class.
  PrioritizedReplay(const size_t batchSize,
                    const size_t capacity,
                    const double alpha,
                    const size_t nSteps = 1,
                    const size_t dimension = StateType::dimension) :
      /* ... */
      replayBetaIters(10000),
      states(dimension, capacity),
      nextStates(dimension, capacity)
      /* ... */
  {
    // Grow the size of the priority store to a power of two large enough
    // to hold `capacity` entries.
    size_t size = 1;
    while (size < capacity)
      size *= 2;
    /* ... */
  }
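A minimal construction sketch, assuming a hypothetical CartPole-like environment type whose State defines a static dimension; the argument values below are illustrative, not defaults.

  // Sketch only: `CartPole` stands in for any EnvironmentType whose State
  // defines a static `dimension`; all values are illustrative.
  PrioritizedReplay<CartPole> replay(
      /* batchSize */ 32,
      /* capacity  */ 10000,
      /* alpha     */ 0.6,   // Strength of prioritization (0 would be uniform).
      /* nSteps    */ 1);    // Plain 1-step transitions.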
  //! Store the given experience and set its priority.
  void Store(StateType state, ActionType action, double reward,
             StateType nextState, bool isEnd, const double& discount)
  {
    nStepBuffer.push_back({state, action, reward, nextState, isEnd});

    // A single-step transition is not ready yet.
    if (nStepBuffer.size() < nSteps)
      return;

    // Keep the buffer size fixed to nSteps.
    if (nStepBuffer.size() > nSteps)
      nStepBuffer.pop_front();
    assert(nStepBuffer.size() == nSteps);

    // Collapse the buffer into a single n-step transition.
    GetNStepInfo(reward, nextState, isEnd, discount);
    state = nStepBuffer.front().state;
    action = nStepBuffer.front().action;

    states.col(position) = state.Encode();
    actions[position] = action;
    rewards(position) = reward;
    nextStates.col(position) = nextState.Encode();
    isTerminal(position) = isEnd;
    // New transitions enter with the current maximum priority (scaled by alpha).
    idxSum.Set(position, maxPriority * alpha);

    // Advance the ring-buffer cursor, wrapping around once full.
    position++;
    if (position == capacity)
    {
      full = true;
      position = 0;
    }
  }
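The nStepBuffer deque above holds whole transitions. The Transition type itself is not shown in these fragments; an aggregate consistent with the member accesses used above would look roughly like the following sketch (not necessarily the file's exact definition).

  // Sketch of the Transition record implied by the accesses above
  // (.state, .action, .reward, .nextState, .isEnd).
  struct Transition
  {
    StateType state;      // s_t
    ActionType action;    // a_t
    double reward;        // r_t
    StateType nextState;  // s_{t+1}
    bool isEnd;           // Whether s_{t+1} ended the episode.
  };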
  //! Get the reward, next state and terminal flag for the n-step transition.
  void GetNStepInfo(double& reward, StateType& nextState, bool& isEnd,
                    const double& discount)
  {
    reward = nStepBuffer.back().reward;
    nextState = nStepBuffer.back().nextState;
    isEnd = nStepBuffer.back().isEnd;

    // Walk backwards from the second-to-last transition, accumulating the
    // discounted return; a terminal step stops the accumulation.
    for (int i = nStepBuffer.size() - 2; i >= 0; i--)
    {
      bool iE = nStepBuffer[i].isEnd;
      reward = nStepBuffer[i].reward + discount * reward * (1 - iE);
      if (iE)
      {
        nextState = nStepBuffer[i].nextState;
        isEnd = iE;
      }
    }
  }
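As a self-contained illustration of what that backward loop computes, the toy program below applies the same recurrence to made-up rewards; the values and the tolerance are for demonstration only.

#include <cassert>
#include <vector>

// Toy n-step return mirroring GetNStepInfo(): walk backwards and apply
// R = r_i + discount * R * (1 - isEnd_i), so discounting stops at terminal
// steps.
int main()
{
  const double discount = 0.99;
  std::vector<double> rewards = { 1.0, 0.0, 2.0 };      // r_t, r_{t+1}, r_{t+2}
  std::vector<bool> isEnd     = { false, false, false };

  double R = rewards.back();
  for (int i = (int) rewards.size() - 2; i >= 0; --i)
    R = rewards[i] + discount * R * (1 - isEnd[i]);

  // 1.0 + 0.99 * (0.0 + 0.99 * 2.0) = 2.9602 (up to rounding).
  assert(R > 2.96 && R < 2.9605);
  return 0;
}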
  //! Sample indices of experiences proportionally to their priorities.
  arma::ucolvec SampleProportional()
  {
    arma::ucolvec idxes(batchSize);
    double totalSum = idxSum.Sum(0, (full ? capacity : position));
    double sumPerRange = totalSum / batchSize;

    // Stratified sampling: draw one index from each of batchSize equal
    // slices of the total priority mass.
    for (size_t bt = 0; bt < batchSize; bt++)
    {
      const double mass = arma::randu() * sumPerRange + bt * sumPerRange;
      idxes(bt) = idxSum.FindPrefixSum(mass);
    }
    return idxes;
  }
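FindPrefixSum is provided by the sum tree backing idxSum, which is not shown in these fragments. A simplified stand-in using a plain linear prefix-sum scan illustrates the same stratified, priority-proportional scheme (the tree answers the same query in logarithmic time):

#include <armadillo>

// Simplified stand-in for the tree-based sampler: split the total priority
// mass into batchSize equal ranges and draw one index from each range,
// proportionally to priority, via a linear prefix-sum scan.
arma::uvec SampleProportionalNaive(const arma::vec& priorities,
                                   const size_t batchSize)
{
  arma::uvec idxes(batchSize);
  const double totalSum = arma::accu(priorities);
  const double sumPerRange = totalSum / batchSize;
  for (size_t bt = 0; bt < batchSize; ++bt)
  {
    const double mass = arma::randu() * sumPerRange + bt * sumPerRange;
    double prefix = 0.0;
    size_t idx = 0;
    while (idx + 1 < priorities.n_elem && prefix + priorities(idx) < mass)
      prefix += priorities(idx++);
    idxes(bt) = idx;
  }
  return idxes;
}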
  //! Sample a batch of experiences according to their priorities.
  void Sample(arma::mat& sampledStates,
              std::vector<ActionType>& sampledActions,
              arma::rowvec& sampledRewards,
              arma::mat& sampledNextStates,
              arma::irowvec& isTerminal)
  {
    sampledIndices = SampleProportional();
    BetaAnneal();

    sampledStates = states.cols(sampledIndices);
    for (size_t t = 0; t < sampledIndices.n_rows; t++)
      sampledActions.push_back(actions[sampledIndices[t]]);
    sampledRewards = rewards.elem(sampledIndices).t();
    sampledNextStates = nextStates.cols(sampledIndices);
    isTerminal = this->isTerminal.elem(sampledIndices).t();

    // Importance-sampling weights: w_i = (N * P(i))^(-beta), normalized by
    // the largest weight in the batch so no weight exceeds 1.
    size_t numSample = full ? capacity : position;
    weights = arma::rowvec(sampledIndices.n_rows);
    for (size_t i = 0; i < sampledIndices.n_rows; ++i)
    {
      double p_sample = idxSum.Get(sampledIndices(i)) / idxSum.Sum();
      weights(i) = pow(numSample * p_sample, -beta);
    }
    weights /= weights.max();
  }
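These weights are the usual prioritized-replay importance-sampling correction, w_i = (N * P(i))^(-beta), normalized by the batch maximum. A small self-contained check with made-up sampling probabilities:

#include <armadillo>
#include <cassert>
#include <cmath>

int main()
{
  // Made-up sampling probabilities P(i) for a batch of four transitions.
  arma::vec pSample = { 0.4, 0.3, 0.2, 0.1 };
  const size_t numSample = 100;  // Illustrative number of stored transitions.
  const double beta = 0.6;       // Illustrative beta.

  arma::rowvec weights(pSample.n_elem);
  for (size_t i = 0; i < pSample.n_elem; ++i)
    weights(i) = std::pow(numSample * pSample(i), -beta);
  weights /= weights.max();

  // The least likely sample receives the largest (unit) correction.
  assert(weights(3) == 1.0);
  return 0;
}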
  //! Update the priorities of the sampled transitions.
  void UpdatePriorities(arma::ucolvec& indices, arma::colvec& priorities)
  {
    arma::colvec alphaPri = alpha * priorities;
    // Remember the largest priority seen so far; new transitions enter with it.
    maxPriority = std::max(maxPriority, arma::max(priorities));
    idxSum.BatchUpdate(indices, alphaPri);
  }
  //! Get the number of transitions in the memory.
  const size_t& Size() { return full ? capacity : position; }
  //! Anneal beta linearly towards 1.
  void BetaAnneal() { beta = beta + (1 - initialBeta) * 1.0 / replayBetaIters; }
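A quick self-contained check of that schedule: starting from an illustrative initialBeta, applying the increment replayBetaIters times brings beta to (approximately) 1.

#include <cassert>
#include <cmath>

int main()
{
  // Illustrative values; the class stores these as members.
  const double initialBeta = 0.6;
  const size_t replayBetaIters = 10000;
  double beta = initialBeta;

  for (size_t i = 0; i < replayBetaIters; ++i)
    beta = beta + (1 - initialBeta) * 1.0 / replayBetaIters;

  // Linear annealing reaches 1 after replayBetaIters increments.
  assert(std::abs(beta - 1.0) < 1e-9);
  return 0;
}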
  //! Update the priorities of the sampled transitions and scale the gradients.
  void Update(arma::mat target,
              std::vector<ActionType> sampledActions,
              arma::mat nextActionValues,
              arma::mat& gradients)
  {
    // The absolute TD error of each sampled transition becomes its new
    // priority (UpdatePriorities applies the alpha scaling).
    arma::colvec tdError(target.n_cols);
    for (size_t i = 0; i < target.n_cols; i++)
    {
      tdError(i) = nextActionValues(sampledActions[i].action, i) -
          target(sampledActions[i].action, i);
    }
    tdError = arma::abs(tdError);
    UpdatePriorities(sampledIndices, tdError);

    // Correct the gradient estimate with the importance-sampling weights.
    gradients = arma::mean(weights) * gradients;
  }
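Both target and nextActionValues are indexed as (action, sample): one column per sampled transition. A caller-side sketch of how Sample and Update fit together in an agent step, reusing the replay object from the construction sketch above; the randomly filled matrices are placeholders for the outputs of the learning and target networks, which are not part of this class.

  arma::mat sampledStates, sampledNextStates;
  std::vector<PrioritizedReplay<CartPole>::ActionType> sampledActions;
  arma::rowvec sampledRewards;
  arma::irowvec sampledTerminals;
  replay.Sample(sampledStates, sampledActions, sampledRewards,
                sampledNextStates, sampledTerminals);

  const size_t numActions = 2;                    // Illustrative.
  const size_t batchSize = sampledStates.n_cols;  // One column per sample.
  // Placeholders; a real agent would compute these from its networks.
  arma::mat target(numActions, batchSize, arma::fill::randu);
  arma::mat nextActionValues(numActions, batchSize, arma::fill::randu);
  arma::mat gradients(8, batchSize, arma::fill::randu);

  // Re-prioritizes the sampled transitions by |TD error| and rescales the
  // gradients by the mean importance-sampling weight.
  replay.Update(target, sampledActions, nextActionValues, gradients);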
  //! Get the number of steps for n-step agent.
  const size_t& NSteps() const { return nSteps; }
 private:
  //! Number of iterations over which beta is annealed.
  size_t replayBetaIters;

  //! Indices of the most recently sampled batch.
  arma::ucolvec sampledIndices;

  //! Importance-sampling weights of the most recently sampled batch.
  arma::rowvec weights;

  //! Buffer of the last nSteps transitions, used to build n-step returns.
  std::deque<Transition> nStepBuffer;

  //! Locally-stored actions.
  std::vector<ActionType> actions;

  //! Locally-stored rewards.
  arma::rowvec rewards;

  //! Locally-stored encoded next states (one column per transition).
  arma::mat nextStates;

  //! Locally-stored terminal flags.
  arma::irowvec isTerminal;

  // (Other members referenced above, e.g. batchSize, capacity, alpha, beta,
  // initialBeta, maxPriority, position, states, full, nSteps, and idxSum,
  // are omitted from these fragments.)
};

#endif
 
void BetaAnneal()
Anneal the beta linearly towards 1. 
 
void GetNStepInfo(double &reward, StateType &nextState, bool &isEnd, const double &discount)
Get the reward, next state, and terminal flag for the n-step transition. 
 
arma::ucolvec SampleProportional()
Sample some experience according to their priorities. 
 
Implementation of prioritized experience replay. 
 
PrioritizedReplay(const size_t batchSize, const size_t capacity, const double alpha, const size_t nSteps=1, const size_t dimension=StateType::dimension)
Construct an instance of prioritized experience replay class. 
 
const size_t & Size()
Get the number of transitions in the memory. 
 
PrioritizedReplay()
Default constructor. 
 
typename EnvironmentType::Action ActionType
Convenient typedef for action. 
 
const size_t & NSteps() const
Get the number of steps for n-step agent. 
 
void UpdatePriorities(arma::ucolvec &indices, arma::colvec &priorities)
Update priorities of sampled transitions. 
 
typename EnvironmentType::State StateType
Convenient typedef for state. 
 
void Update(arma::mat target, std::vector< ActionType > sampledActions, arma::mat nextActionValues, arma::mat &gradients)
Update the priorities of the sampled transitions and update the gradients. 
 
void Store(StateType state, ActionType action, double reward, StateType nextState, bool isEnd, const double &discount)
Store the given experience and set the priorities for the given experience. 
 
void Sample(arma::mat &sampledStates, std::vector< ActionType > &sampledActions, arma::rowvec &sampledRewards, arma::mat &sampledNextStates, arma::irowvec &isTerminal)
Sample some experience according to their priorities.