12 #ifndef MLPACK_METHODS_RL_REPLAY_RANDOM_REPLAY_HPP 13 #define MLPACK_METHODS_RL_REPLAY_RANDOM_REPLAY_HPP 43 template <
typename EnvironmentType>
79 const size_t capacity,
80 const size_t nSteps = 1,
81 const size_t dimension = StateType::dimension) :
87 states(dimension, capacity),
90 nextStates(dimension, capacity),
109 const double& discount)
114 if (nStepBuffer.size() < nSteps)
118 if (nStepBuffer.size() > nSteps)
119 nStepBuffer.pop_front();
122 assert(nStepBuffer.size() == nSteps);
127 state = nStepBuffer.front().state;
128 action = nStepBuffer.front().action;
130 states.col(position) = state.Encode();
131 actions[position] =
action;
132 rewards(position) =
reward;
133 nextStates.col(position) = nextState.Encode();
134 isTerminal(position) =
isEnd;
136 if (position == capacity)
154 const double& discount)
156 reward = nStepBuffer.back().reward;
157 nextState = nStepBuffer.back().nextState;
158 isEnd = nStepBuffer.back().isEnd;
161 for (
int i = nStepBuffer.size() - 2; i >= 0; i--)
163 bool iE = nStepBuffer[i].isEnd;
164 reward = nStepBuffer[i].reward + discount * reward * (1 - iE);
167 nextState = nStepBuffer[i].nextState;
184 std::vector<ActionType>& sampledActions,
185 arma::rowvec& sampledRewards,
186 arma::mat& sampledNextStates,
187 arma::irowvec& isTerminal)
189 size_t upperBound = full ? capacity : position;
190 arma::uvec sampledIndices = arma::randi<arma::uvec>(
191 batchSize, arma::distr_param(0, upperBound - 1));
193 sampledStates = states.cols(sampledIndices);
194 for (
size_t t = 0; t < sampledIndices.n_rows; t ++)
195 sampledActions.push_back(actions[sampledIndices[t]]);
196 sampledRewards = rewards.elem(sampledIndices).t();
197 sampledNextStates = nextStates.cols(sampledIndices);
198 isTerminal = this->isTerminal.elem(sampledIndices).t();
208 return full ? capacity : position;
220 std::vector<ActionType> ,
228 const size_t&
NSteps()
const {
return nSteps; }
//! Buffer of the most recent transitions used to build n-step returns;
//! Store() pops the front once its size exceeds nSteps.
247   std::deque<Transition> nStepBuffer;
//! Locally-stored previously taken actions, indexed by buffer position.
253   std::vector<ActionType> actions;
//! Locally-stored rewards, one entry per stored transition.
256   arma::rowvec rewards;
//! Locally-stored encoded next states, one column per stored transition.
259   arma::mat nextStates;
//! Locally-stored terminal-state flags for each stored transition.
262   arma::irowvec isTerminal;
void Store(StateType state, ActionType action, double reward, StateType nextState, bool isEnd, const double &discount)
Store the given experience.
void Update(arma::mat, std::vector< ActionType >, arma::mat, arma::mat &)
Update the priorities of transitions and update the gradients.
typename EnvironmentType::Action ActionType
Convenient typedef for action.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
typename EnvironmentType::State StateType
Convenient typedef for state.
RandomReplay(const size_t batchSize, const size_t capacity, const size_t nSteps=1, const size_t dimension=StateType::dimension)
Construct an instance of random experience replay class.
void Sample(arma::mat &sampledStates, std::vector< ActionType > &sampledActions, arma::rowvec &sampledRewards, arma::mat &sampledNextStates, arma::irowvec &isTerminal)
Sample some experiences.
const size_t & Size()
Get the number of transitions in the memory.
void GetNStepInfo(double &reward, StateType &nextState, bool &isEnd, const double &discount)
Get the reward, next state, and terminal boolean for the nth step.
const size_t & NSteps() const
Get the number of steps for n-step agent.
Implementation of random experience replay.