prioritized_replay.hpp
/**
 * @file prioritized_replay.hpp
 *
 * This file is an implementation of prioritized experience replay.
 */
#ifndef MLPACK_METHODS_RL_PRIORITIZED_REPLAY_HPP
#define MLPACK_METHODS_RL_PRIORITIZED_REPLAY_HPP

#include <mlpack/prereqs.hpp>
#include "sumtree.hpp"

namespace mlpack {
namespace rl {

/**
 * Implementation of prioritized experience replay. Transitions are sampled
 * with probability proportional to their priorities, and the induced bias is
 * corrected with annealed importance-sampling weights (Schaul et al.,
 * "Prioritized Experience Replay", arXiv:1511.05952).
 *
 * @tparam EnvironmentType Desired task.
 */
template <typename EnvironmentType>
class PrioritizedReplay
{
 public:
  //! Convenient typedef for action.
  using ActionType = typename EnvironmentType::Action;

  //! Convenient typedef for state.
  using StateType = typename EnvironmentType::State;

  //! A single stored experience.
  struct Transition
  {
    StateType state;
    ActionType action;
    double reward;
    StateType nextState;
    bool isEnd;
  };

  /**
   * Default constructor.
   */
  PrioritizedReplay() :
      batchSize(0),
      capacity(0),
      position(0),
      full(false),
      alpha(0),
      maxPriority(0),
      initialBeta(0),
      beta(0),
      replayBetaIters(0),
      nSteps(0)
  { /* Nothing to do here. */ }

  /**
   * Construct an instance of prioritized experience replay class.
   *
   * @param batchSize Number of examples returned at each sample.
   * @param capacity Total memory size in terms of number of examples.
   * @param alpha How much prioritization is used.
   * @param nSteps Number of steps to look into the future.
   * @param dimension The dimension of an encoded state.
   */
  PrioritizedReplay(const size_t batchSize,
                    const size_t capacity,
                    const double alpha,
                    const size_t nSteps = 1,
                    const size_t dimension = StateType::dimension) :
      batchSize(batchSize),
      capacity(capacity),
      position(0),
      full(false),
      alpha(alpha),
      maxPriority(1.0),
      initialBeta(0.6),
      replayBetaIters(10000),
      nSteps(nSteps),
      states(dimension, capacity),
      actions(capacity),
      rewards(capacity),
      nextStates(dimension, capacity),
      isTerminal(capacity)
  {
    // The sum tree is sized to the next power of two at or above capacity.
    size_t size = 1;
    while (size < capacity)
    {
      size *= 2;
    }

    beta = initialBeta;
    idxSum = SumTree<double>(size);
  }

  /**
   * Store the given experience and set the priorities for the given
   * experience.
   *
   * @param state Given state.
   * @param action Given action.
   * @param reward Given reward.
   * @param nextState Given next state.
   * @param isEnd Whether next state is terminal state.
   * @param discount The discount parameter.
   */
  void Store(StateType state,
             ActionType action,
             double reward,
             StateType nextState,
             bool isEnd,
             const double& discount)
  {
    nStepBuffer.push_back({state, action, reward, nextState, isEnd});

    // Single step transition is not ready.
    if (nStepBuffer.size() < nSteps)
      return;

    // To keep the queue size fixed to nSteps.
    if (nStepBuffer.size() > nSteps)
      nStepBuffer.pop_front();

    // Before moving ahead, let's confirm that our fixed-size buffer works.
    assert(nStepBuffer.size() == nSteps);

    // Make an n-step transition.
    GetNStepInfo(reward, nextState, isEnd, discount);

    state = nStepBuffer.front().state;
    action = nStepBuffer.front().action;
    states.col(position) = state.Encode();
    actions[position] = action;
    rewards(position) = reward;
    nextStates.col(position) = nextState.Encode();
    isTerminal(position) = isEnd;

    // New transitions enter with the current maximum priority (scaled by
    // alpha), so they are sampled at least once before being re-prioritized.
    idxSum.Set(position, maxPriority * alpha);

    position++;
    if (position == capacity)
    {
      full = true;
      position = 0;
    }
  }

  /**
   * Get the reward, next state and terminal boolean for the nth step.
   *
   * @param reward Given reward.
   * @param nextState Given next state.
   * @param isEnd Whether next state is terminal state.
   * @param discount The discount parameter.
   */
  void GetNStepInfo(double& reward,
                    StateType& nextState,
                    bool& isEnd,
                    const double& discount)
  {
    reward = nStepBuffer.back().reward;
    nextState = nStepBuffer.back().nextState;
    isEnd = nStepBuffer.back().isEnd;

    // Should start from the second-to-last transition in the buffer.
    for (int i = nStepBuffer.size() - 2; i >= 0; i--)
    {
      bool iE = nStepBuffer[i].isEnd;
      reward = nStepBuffer[i].reward + discount * reward * (1 - iE);
      if (iE)
      {
        nextState = nStepBuffer[i].nextState;
        isEnd = iE;
      }
    }
  }

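  // Note (illustrative): for a window of transitions t, ..., t + n - 1, the
  // backward recursion above yields the truncated n-step return
  //   R = r_t + discount * r_{t+1} + ... + discount^(n-1) * r_{t+n-1},
  // and if the window contains a terminal transition, the return is cut off
  // there and nextState / isEnd are taken from that earliest terminal step.
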
  /**
   * Sample some experience according to their priorities.
   *
   * @return Sampled indices into the memory.
   */
  arma::ucolvec SampleProportional()
  {
    arma::ucolvec idxes(batchSize);
    double totalSum = idxSum.Sum(0, (full ? capacity : position));
    double sumPerRange = totalSum / batchSize;
    for (size_t bt = 0; bt < batchSize; bt++)
    {
      const double mass = arma::randu() * sumPerRange + bt * sumPerRange;
      idxes(bt) = idxSum.FindPrefixSum(mass);
    }
    return idxes;
  }

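  // Note (illustrative): SampleProportional() performs stratified sampling.
  // The total priority mass is split into batchSize equal ranges, one point
  // is drawn uniformly within each range, and FindPrefixSum() maps that point
  // to the index whose cumulative priority covers it, so transitions with
  // larger priority occupy more mass and are selected more often.
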
  /**
   * Sample some experience according to their priorities.
   *
   * @param sampledStates Sampled encoded states.
   * @param sampledActions Sampled actions.
   * @param sampledRewards Sampled rewards.
   * @param sampledNextStates Sampled encoded next states.
   * @param isTerminal Indicates whether the corresponding next state is a
   *        terminal state.
   */
  void Sample(arma::mat& sampledStates,
              std::vector<ActionType>& sampledActions,
              arma::rowvec& sampledRewards,
              arma::mat& sampledNextStates,
              arma::irowvec& isTerminal)
  {
    sampledIndices = SampleProportional();
    BetaAnneal();

    sampledStates = states.cols(sampledIndices);
    for (size_t t = 0; t < sampledIndices.n_rows; t++)
      sampledActions.push_back(actions[sampledIndices[t]]);
    sampledRewards = rewards.elem(sampledIndices).t();
    sampledNextStates = nextStates.cols(sampledIndices);
    isTerminal = this->isTerminal.elem(sampledIndices).t();

    // Calculate the importance-sampling weights of the sampled transitions.
    size_t numSample = full ? capacity : position;
    weights = arma::rowvec(sampledIndices.n_rows);

    for (size_t i = 0; i < sampledIndices.n_rows; ++i)
    {
      double p_sample = idxSum.Get(sampledIndices(i)) / idxSum.Sum();
      weights(i) = pow(numSample * p_sample, -beta);
    }
    weights /= weights.max();
  }

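  // Note (illustrative): the weights computed in Sample() are the standard
  // prioritized-replay importance-sampling corrections
  //   w_i = (N * P(i))^(-beta),
  // with N the number of stored transitions and P(i) the probability of
  // sampling transition i; dividing by the largest weight keeps them in
  // (0, 1] for numerical stability.
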
  /**
   * Update priorities of sampled transitions.
   *
   * @param indices The indices of the transitions to update.
   * @param priorities The new priorities.
   */
  void UpdatePriorities(arma::ucolvec& indices, arma::colvec& priorities)
  {
    arma::colvec alphaPri = alpha * priorities;
    maxPriority = std::max(maxPriority, arma::max(priorities));
    idxSum.BatchUpdate(indices, alphaPri);
  }

  /**
   * Get the number of transitions in the memory.
   *
   * @return Actual used memory size.
   */
  const size_t& Size()
  {
    return full ? capacity : position;
  }

  /**
   * Annealing the beta: each call increases beta by
   * (1 - initialBeta) / replayBetaIters, so it reaches 1 after
   * replayBetaIters samples.
   */
  void BetaAnneal()
  {
    beta = beta + (1 - initialBeta) * 1.0 / replayBetaIters;
  }

  /**
   * Update the priorities of transitions and update the gradients.
   *
   * @param target Learned value estimates for the sampled batch.
   * @param sampledActions Sampled actions of the batch.
   * @param nextActionValues Value estimates used to compute the TD error.
   * @param gradients The model's gradients, rescaled in place by the mean
   *        importance-sampling weight.
   */
  void Update(arma::mat target,
              std::vector<ActionType> sampledActions,
              arma::mat nextActionValues,
              arma::mat& gradients)
  {
    arma::colvec tdError(target.n_cols);
    for (size_t i = 0; i < target.n_cols; i++)
    {
      tdError(i) = nextActionValues(sampledActions[i].action, i) -
          target(sampledActions[i].action, i);
    }
    tdError = arma::abs(tdError);
    UpdatePriorities(sampledIndices, tdError);

    // Update the gradient.
    gradients = arma::mean(weights) * gradients;
  }

  //! Get the number of steps for n-step agent.
  const size_t& NSteps() const { return nSteps; }

 private:
  //! Locally-stored number of examples of each sample.
  size_t batchSize;

  //! Locally-stored total memory limit.
  size_t capacity;

  //! Indicates the position at which to store a new transition.
  size_t position;

  //! Locally-stored indicator of whether the memory is full or not.
  bool full;

  //! How much prioritization is used.
  double alpha;

  //! Largest priority seen so far; new transitions are stored with it.
  double maxPriority;

  //! Initial value of the importance-sampling exponent beta.
  double initialBeta;

  //! Current value of beta, annealed toward 1 by BetaAnneal().
  double beta;

  //! Number of iterations over which beta is annealed.
  size_t replayBetaIters;

  //! Sum tree of the (alpha-scaled) transition priorities.
  SumTree<double> idxSum;

  //! Indices of the most recently sampled batch.
  arma::ucolvec sampledIndices;

  //! Importance-sampling weights of the most recently sampled batch.
  arma::rowvec weights;

  //! Number of steps for the n-step agent.
  size_t nSteps;

  //! Buffer of the last nSteps transitions, used to build n-step returns.
  std::deque<Transition> nStepBuffer;

  //! Locally-stored encoded previous states.
  arma::mat states;

  //! Locally-stored previous actions.
  std::vector<ActionType> actions;

  //! Locally-stored previous rewards.
  arma::rowvec rewards;

  //! Locally-stored encoded previous next states.
  arma::mat nextStates;

  //! Locally-stored termination information of previous experience.
  arma::irowvec isTerminal;
};

} // namespace rl
} // namespace mlpack

#endif
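
Below is a minimal usage sketch, not part of prioritized_replay.hpp itself. It assumes mlpack's CartPole environment (and its State/Action interface) purely for illustration; the hard-coded action stands in for whatever policy an agent would actually use.

// Illustrative sketch only; CartPole and the hand-picked action are
// assumptions, not part of this header.
#include <mlpack/methods/reinforcement_learning/replay/prioritized_replay.hpp>
#include <mlpack/methods/reinforcement_learning/environment/cart_pole.hpp>

using namespace mlpack::rl;

void PrioritizedReplaySketch()
{
  CartPole env;

  // Batches of 32 transitions, capacity 10000, alpha = 0.6, 1-step returns.
  PrioritizedReplay<CartPole> replay(32, 10000, 0.6);

  CartPole::State state = env.InitialSample();
  CartPole::Action action;
  action.action = CartPole::Action::backward;  // Stand-in for a policy choice.

  CartPole::State nextState;
  double reward = env.Sample(state, action, nextState);
  bool isEnd = env.IsTerminal(nextState);

  // Store the transition; the discount only matters when nSteps > 1.
  replay.Store(state, action, reward, nextState, isEnd, 0.99);

  // After many such transitions have been stored, draw a prioritized batch.
  arma::mat sampledStates, sampledNextStates;
  std::vector<CartPole::Action> sampledActions;
  arma::rowvec sampledRewards;
  arma::irowvec sampledTerminal;
  replay.Sample(sampledStates, sampledActions, sampledRewards,
      sampledNextStates, sampledTerminal);
}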