random_replay.hpp
/**
 * @file random_replay.hpp
 *
 * Implementation of random experience replay.
 *
 * mlpack is free software; you may redistribute it and/or modify it under the
 * terms of the 3-clause BSD license.  You should have received a copy of the
 * 3-clause BSD license along with mlpack.  If not, see
 * http://www.mlpack.org/license.txt for more information.
 */
#ifndef MLPACK_METHODS_RL_REPLAY_RANDOM_REPLAY_HPP
#define MLPACK_METHODS_RL_REPLAY_RANDOM_REPLAY_HPP

#include <mlpack/prereqs.hpp>
#include <cassert>
#include <deque>

namespace mlpack {
namespace rl {
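
/**
 * Implementation of random experience replay.
 *
 * At each time step, the interaction between the agent and the environment is
 * saved into a memory buffer. At training time, a batch of previous
 * experiences is drawn uniformly at random from that buffer, which breaks the
 * temporal correlation between consecutive samples (see, e.g., Mnih et al.,
 * "Playing Atari with Deep Reinforcement Learning", 2013). The memory acts as
 * a fixed-capacity circular buffer: once full, the oldest transitions are
 * overwritten first.
 *
 * @tparam EnvironmentType Desired task.
 */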
template <typename EnvironmentType>
class RandomReplay
{
 public:
  //! Convenient typedef for action.
  using ActionType = typename EnvironmentType::Action;

  //! Convenient typedef for state.
  using StateType = typename EnvironmentType::State;

  //! A single transition: (state, action, reward, nextState, isEnd).
  struct Transition
  {
    //! The starting state.
    StateType state;
    //! The action taken from that state.
    ActionType action;
    //! The reward received.
    double reward;
    //! The resulting state.
    StateType nextState;
    //! Whether nextState is a terminal state.
    bool isEnd;
  };
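
  //! Default constructor.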
  RandomReplay() :
      batchSize(0),
      capacity(0),
      position(0),
      full(false),
      nSteps(0)
  { /* Nothing to do here. */ }
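
  /**
   * Construct an instance of random experience replay class.
   *
   * @param batchSize Number of examples returned at each sample.
   * @param capacity Total memory size in terms of number of examples.
   * @param nSteps Number of steps to look into the future.
   * @param dimension The dimension of an encoded state.
   */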
  RandomReplay(const size_t batchSize,
               const size_t capacity,
               const size_t nSteps = 1,
               const size_t dimension = StateType::dimension) :
      batchSize(batchSize),
      capacity(capacity),
      position(0),
      full(false),
      nSteps(nSteps),
      states(dimension, capacity),
      actions(capacity),
      rewards(capacity),
      nextStates(dimension, capacity),
      isTerminal(capacity)
  { /* Nothing to do here. */ }
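
  /**
   * Store the given experience.
   *
   * @param state Given state.
   * @param action Given action.
   * @param reward Given reward.
   * @param nextState Given next state.
   * @param isEnd Whether next state is terminal state.
   * @param discount The discount parameter.
   */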
  void Store(StateType state,
             ActionType action,
             double reward,
             StateType nextState,
             bool isEnd,
             const double& discount)
  {
    nStepBuffer.push_back({state, action, reward, nextState, isEnd});

    // The n-step transition is not ready yet.
    if (nStepBuffer.size() < nSteps)
      return;

    // Keep the queue size fixed to nSteps.
    if (nStepBuffer.size() > nSteps)
      nStepBuffer.pop_front();

    // Before moving ahead, confirm that the fixed-size buffer works.
    assert(nStepBuffer.size() == nSteps);

    // Make an n-step transition.
    GetNStepInfo(reward, nextState, isEnd, discount);

    state = nStepBuffer.front().state;
    action = nStepBuffer.front().action;

    states.col(position) = state.Encode();
    actions[position] = action;
    rewards(position) = reward;
    nextStates.col(position) = nextState.Encode();
    isTerminal(position) = isEnd;
    position++;
    if (position == capacity)
    {
      full = true;
      position = 0;
    }
  }
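
  /**
   * Get the reward, next state, and terminal flag for the n-step transition
   * currently held in the buffer. Walking the buffer backwards, it
   * accumulates the truncated n-step return
   * reward = r_t + discount * r_{t+1} + ... + discount^{n-1} * r_{t+n-1},
   * cutting the accumulation off at the first terminal transition.
   *
   * @param reward Gets the n-step discounted reward.
   * @param nextState Gets the next state of the n-step transition.
   * @param isEnd Gets whether the n-step transition ends the episode.
   * @param discount The discount parameter.
   */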
  void GetNStepInfo(double& reward,
                    StateType& nextState,
                    bool& isEnd,
                    const double& discount)
  {
    reward = nStepBuffer.back().reward;
    nextState = nStepBuffer.back().nextState;
    isEnd = nStepBuffer.back().isEnd;

    // Start from the second-to-last transition in the buffer.
    for (int i = nStepBuffer.size() - 2; i >= 0; i--)
    {
      bool iE = nStepBuffer[i].isEnd;
      reward = nStepBuffer[i].reward + discount * reward * (1 - iE);
      if (iE)
      {
        nextState = nStepBuffer[i].nextState;
        isEnd = iE;
      }
    }
  }
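
  /**
   * Sample some experiences uniformly at random from the buffer.
   *
   * @param sampledStates Sampled encoded states.
   * @param sampledActions Sampled actions.
   * @param sampledRewards Sampled rewards.
   * @param sampledNextStates Sampled encoded next states.
   * @param isTerminal Indicate whether corresponding next state is terminal
   *        state.
   */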
  void Sample(arma::mat& sampledStates,
              std::vector<ActionType>& sampledActions,
              arma::rowvec& sampledRewards,
              arma::mat& sampledNextStates,
              arma::irowvec& isTerminal)
  {
    size_t upperBound = full ? capacity : position;
    arma::uvec sampledIndices = arma::randi<arma::uvec>(
        batchSize, arma::distr_param(0, upperBound - 1));

    sampledStates = states.cols(sampledIndices);
    for (size_t t = 0; t < sampledIndices.n_rows; t++)
      sampledActions.push_back(actions[sampledIndices[t]]);
    sampledRewards = rewards.elem(sampledIndices).t();
    sampledNextStates = nextStates.cols(sampledIndices);
    isTerminal = this->isTerminal.elem(sampledIndices).t();
  }
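
  /**
   * Get the number of transitions in the memory.
   *
   * @return Actual used memory size.
   */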
  const size_t& Size()
  {
    return full ? capacity : position;
  }
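
  /**
   * Update the priorities of transitions and update the gradients. This is a
   * no-op for random replay; it exists only to match the replay interface.
   */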
  void Update(arma::mat /* target */,
              std::vector<ActionType> /* sampledActions */,
              arma::mat /* nextActionValues */,
              arma::mat& /* gradients */)
  {
    /* Do nothing for random replay. */
  }

  //! Get the number of steps for n-step agent.
  const size_t& NSteps() const { return nSteps; }

 private:
  //! Locally-stored number of examples of each sample.
  size_t batchSize;

  //! Locally-stored total memory limit.
  size_t capacity;

  //! Indicate the position to store new transition.
  size_t position;

  //! Locally-stored indicator that whole memory is used.
  bool full;

  //! Locally-stored number of steps to look into the future.
  size_t nSteps;

  //! Locally-stored buffer containing the nSteps most recent transitions.
  std::deque<Transition> nStepBuffer;

  //! Locally-stored encoded previous states.
  arma::mat states;

  //! Locally-stored previous actions.
  std::vector<ActionType> actions;

  //! Locally-stored previous rewards.
  arma::rowvec rewards;

  //! Locally-stored encoded previous next states.
  arma::mat nextStates;

  //! Locally-stored termination information of previous experience.
  arma::irowvec isTerminal;
};

} // namespace rl
} // namespace mlpack

#endif
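
The header above is self-contained; what follows is a rough usage sketch, not part of the file. `MyEnv` and `Example` are hypothetical names: `MyEnv` stands in for any environment type that provides the nested `State` (with a static `dimension` and an `Encode()` method returning a column vector) and `Action` types this class template expects.

#include <mlpack/methods/reinforcement_learning/replay/random_replay.hpp>

using namespace mlpack::rl;

// Hypothetical minimal environment satisfying the interface RandomReplay
// expects from its EnvironmentType template parameter.
struct MyEnv
{
  struct State
  {
    //! Dimension of the encoded state, used to size the replay matrices.
    static constexpr size_t dimension = 4;

    arma::colvec data = arma::zeros<arma::colvec>(dimension);

    //! Return the column-vector encoding of this state.
    const arma::colvec& Encode() const { return data; }
  };

  struct Action
  {
    int action = 0;
  };
};

void Example()
{
  // A buffer holding up to 10000 transitions, sampled 32 at a time.
  RandomReplay<MyEnv> replay(32, 10000);

  MyEnv::State state, nextState;
  MyEnv::Action action;

  // Store one interaction. The discount only affects the accumulated n-step
  // return, so it is irrelevant here where nSteps defaults to 1.
  replay.Store(state, action, /* reward */ 1.0, nextState,
      /* isEnd */ false, /* discount */ 0.99);

  // Draw a random training batch (sampled with replacement).
  arma::mat states, nextStates;
  std::vector<MyEnv::Action> actions;
  arma::rowvec rewards;
  arma::irowvec terminals;
  replay.Sample(states, actions, rewards, nextStates, terminals);
}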