/**
 * @file one_step_sarsa_worker.hpp
 *
 * This file is the definition of the OneStepSarsaWorker class, which
 * implements an episode for the async one step SARSA algorithm.
 */
#ifndef MLPACK_METHODS_RL_WORKER_ONE_STEP_SARSA_WORKER_HPP
#define MLPACK_METHODS_RL_WORKER_ONE_STEP_SARSA_WORKER_HPP

#include <ensmallen.hpp>

namespace mlpack {
namespace rl {

/**
 * One step SARSA worker.
 *
 * @tparam EnvironmentType The type of the reinforcement learning task.
 * @tparam NetworkType The type of the network model.
 * @tparam UpdaterType The type of the optimizer.
 * @tparam PolicyType The type of the behavior policy.
 */
template <
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename PolicyType
>
class OneStepSarsaWorker
{
 public:
  using StateType = typename EnvironmentType::State;
  using ActionType = typename EnvironmentType::Action;
  using TransitionType = std::tuple<StateType, ActionType, double,
      StateType, ActionType>;
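  // A transition is the SARSA quintuple (s, a, r, s', a'): one step SARSA
  // bootstraps from the action actually taken in the next state, rather
  // than from the greedy action as Q-learning does.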
  /**
   * Construct one step SARSA worker with the given parameters and
   * environment.
   *
   * @param updater The optimizer.
   * @param environment The reinforcement learning task.
   * @param config Hyper-parameters.
   * @param deterministic Whether the episode is deterministic or not.
   */
  OneStepSarsaWorker(
      const UpdaterType& updater,
      const EnvironmentType& environment,
      const TrainingConfig& config,
      bool deterministic) :
      updater(updater),
      #if ENS_VERSION_MAJOR >= 2
      updatePolicy(NULL),
      #endif
      environment(environment),
      config(config),
      deterministic(deterministic),
      pending(config.UpdateInterval())
  { Reset(); }
  /**
   * Copy another OneStepSarsaWorker.
   *
   * @param other Worker to copy.
   */
  OneStepSarsaWorker(const OneStepSarsaWorker& other) :
      updater(other.updater),
      #if ENS_VERSION_MAJOR >= 2
      updatePolicy(NULL),
      #endif
      environment(other.environment),
      config(other.config),
      deterministic(other.deterministic),
      steps(other.steps),
      episodeReturn(other.episodeReturn),
      pending(other.pending),
      pendingIndex(other.pendingIndex),
      network(other.network),
      state(other.state),
      action(other.action)
  {
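    // The ensmallen 2 update policy holds a reference to the updater and
    // per-parameter state, so it is rebuilt here instead of being copied.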
    #if ENS_VERSION_MAJOR >= 2
    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
                                     network.Parameters().n_rows,
                                     network.Parameters().n_cols);
    #endif
  }
  /**
   * Take ownership of another OneStepSarsaWorker.
   *
   * @param other Worker to take ownership of.
   */
  OneStepSarsaWorker(OneStepSarsaWorker&& other) :
      updater(std::move(other.updater)),
      #if ENS_VERSION_MAJOR >= 2
      updatePolicy(NULL),
      #endif
      environment(std::move(other.environment)),
      config(std::move(other.config)),
      deterministic(std::move(other.deterministic)),
      steps(std::move(other.steps)),
      episodeReturn(std::move(other.episodeReturn)),
      pending(std::move(other.pending)),
      pendingIndex(std::move(other.pendingIndex)),
      network(std::move(other.network)),
      state(std::move(other.state)),
      action(std::move(other.action))
  {
    #if ENS_VERSION_MAJOR >= 2
    other.updatePolicy = NULL;
    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
                                     network.Parameters().n_rows,
                                     network.Parameters().n_cols);
    #endif
  }
  /**
   * Copy another OneStepSarsaWorker.
   *
   * @param other Worker to copy.
   */
  OneStepSarsaWorker& operator=(const OneStepSarsaWorker& other)
  {
    if (&other == this)
      return *this;

    #if ENS_VERSION_MAJOR >= 2
    delete updatePolicy;
    #endif

    updater = other.updater;
    environment = other.environment;
    config = other.config;
    deterministic = other.deterministic;
    steps = other.steps;
    episodeReturn = other.episodeReturn;
    pending = other.pending;
    pendingIndex = other.pendingIndex;
    network = other.network;
    state = other.state;
    action = other.action;

    #if ENS_VERSION_MAJOR >= 2
    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
                                     network.Parameters().n_rows,
                                     network.Parameters().n_cols);
    #endif

    return *this;
  }
  /**
   * Take ownership of another OneStepSarsaWorker.
   *
   * @param other Worker to take ownership of.
   */
  OneStepSarsaWorker& operator=(OneStepSarsaWorker&& other)
  {
    if (&other == this)
      return *this;

    #if ENS_VERSION_MAJOR >= 2
    delete updatePolicy;
    #endif

    updater = std::move(other.updater);
    environment = std::move(other.environment);
    config = std::move(other.config);
    deterministic = std::move(other.deterministic);
    steps = std::move(other.steps);
    episodeReturn = std::move(other.episodeReturn);
    pending = std::move(other.pending);
    pendingIndex = std::move(other.pendingIndex);
    network = std::move(other.network);
    state = std::move(other.state);
    action = std::move(other.action);

    #if ENS_VERSION_MAJOR >= 2
    other.updatePolicy = NULL;
    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
                                     network.Parameters().n_rows,
                                     network.Parameters().n_cols);
    #endif

    return *this;
  }
  /**
   * Clean memory.
   */
  ~OneStepSarsaWorker()
  {
    #if ENS_VERSION_MAJOR >= 2
    delete updatePolicy;
    #endif
  }

  /**
   * Initialize the worker.
   *
   * @param learningNetwork The shared network.
   */
  void Initialize(NetworkType& learningNetwork)
  {
    #if ENS_VERSION_MAJOR == 1
    updater.Initialize(learningNetwork.Parameters().n_rows,
                       learningNetwork.Parameters().n_cols);
    #else
    delete updatePolicy;
    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
                                     learningNetwork.Parameters().n_rows,
                                     learningNetwork.Parameters().n_cols);
    #endif
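    // Build the local network: each worker keeps its own copy of the shared
    // network, so forward and backward passes need no synchronization.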
    network = learningNetwork;
  }
  /**
   * The agent will execute one step.
   *
   * @param learningNetwork The shared learning network.
   * @param targetNetwork The shared target network.
   * @param totalSteps The shared counter for total steps.
   * @param policy The shared behavior policy.
   * @param totalReward This will be the episode return if the episode ends
   *     after this step.
   * @return Indicate whether the current episode ends after this step.
   */
  bool Step(NetworkType& learningNetwork,
            NetworkType& targetNetwork,
            size_t& totalSteps,
            PolicyType& policy,
            double& totalReward)
  {
    // Interact with the environment.
    if (action.action == ActionType::size)
    {
      // An invalid action means we are at the beginning of an episode.
      arma::colvec actionValue;
      network.Predict(state.Encode(), actionValue);
      action = policy.Sample(actionValue, deterministic);
    }
    StateType nextState;
    double reward = environment.Sample(state, action, nextState);
    bool terminal = environment.IsTerminal(nextState);
    arma::colvec actionValue;
    network.Predict(nextState.Encode(), actionValue);
    ActionType nextAction = policy.Sample(actionValue, deterministic);

    episodeReturn += reward;
    steps++;

    terminal = terminal || steps >= config.StepLimit();
    if (deterministic)
    {
      if (terminal)
      {
        totalReward = episodeReturn;
        Reset();
        // Sync with the latest learning network.
        network = learningNetwork;
        return true;
      }
      state = nextState;
      action = nextAction;
      return false;
    }

    #pragma omp atomic
    totalSteps++;

    pending[pendingIndex++] =
        std::make_tuple(state, action, reward, nextState, nextAction);
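    // Once the buffer holds UpdateInterval() transitions, or the episode
    // ends, all buffered transitions are folded into a single accumulated
    // gradient and applied to the shared learning network.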
    if (terminal || pendingIndex == config.UpdateInterval())
    {
      // Initialize the gradient storage.
      arma::mat totalGradients(learningNetwork.Parameters().n_rows,
          learningNetwork.Parameters().n_cols, arma::fill::zeros);
      for (size_t i = 0; i < pending.size(); ++i)
      {
        TransitionType& transition = pending[i];

        // Compute the target state-action value.
        arma::colvec actionValue;
        #pragma omp critical
        {
          targetNetwork.Predict(
              std::get<3>(transition).Encode(), actionValue);
        }
        double targetActionValue = 0;
        if (!(terminal && i == pending.size() - 1))
          targetActionValue = actionValue[std::get<4>(transition).action];
        targetActionValue = std::get<2>(transition) +
            config.Discount() * targetActionValue;
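        // targetActionValue is now the one step SARSA target
        // r + gamma * Q_target(s', a'); the bootstrap term is dropped on the
        // final transition of a terminal episode.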
        // Compute the training target for the current state.
        arma::mat input = std::get<0>(transition).Encode();
        network.Forward(input, actionValue);
        actionValue[std::get<1>(transition).action] = targetActionValue;

        // Compute gradients.
        arma::mat gradients;
        network.Backward(input, actionValue, gradients);

        // Accumulate gradients.
        totalGradients += gradients;
      }

      // Clamp the accumulated gradients to stabilize training.
      totalGradients.transform(
          [&](double gradient)
          {
            return std::min(std::max(gradient, -config.GradientLimit()),
                config.GradientLimit());
          });
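      // Apply the accumulated gradient to the shared learning network. As in
      // asynchronous one-step methods (Mnih et al., 2016), the update is
      // applied without a lock, Hogwild!-style.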
      #if ENS_VERSION_MAJOR == 1
      updater.Update(learningNetwork.Parameters(), config.StepSize(),
          totalGradients);
      #else
      updatePolicy->Update(learningNetwork.Parameters(), config.StepSize(),
          totalGradients);
      #endif

      // Sync the local network with the shared network.
      network = learningNetwork;

      pendingIndex = 0;
    }
    // Periodically copy the learning network to the shared target network.
    if (totalSteps % config.TargetNetworkSyncInterval() == 0)
    {
      #pragma omp critical
      { targetNetwork = learningNetwork; }
    }

    policy.Anneal();

    if (terminal)
    {
      totalReward = episodeReturn;
      Reset();
      // Sync with the latest learning network.
      network = learningNetwork;
      return true;
    }
    state = nextState;
    action = nextAction;
    return false;
  }
 private:
  /**
   * Reset the worker for a new episode.
   */
  void Reset()
  {
    steps = 0;
    episodeReturn = 0;
    pendingIndex = 0;
    state = environment.InitialSample();
    using actions = typename EnvironmentType::Action::actions;
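    // Assigning ActionType::size marks "no action chosen yet"; Step() checks
    // this sentinel to sample the first action of a new episode.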
    action.action = static_cast<actions>(ActionType::size);
  }
  //! Locally-stored optimizer.
  UpdaterType updater;
  #if ENS_VERSION_MAJOR >= 2
  typename UpdaterType::template Policy<arma::mat, arma::mat>* updatePolicy;
  #endif

  //! Locally-stored task.
  EnvironmentType environment;

  //! Locally-stored hyper-parameters.
  TrainingConfig config;

  //! Whether the episode is deterministic or not.
  bool deterministic;

  //! Total steps in the current episode.
  size_t steps;

  //! Total reward in the current episode.
  double episodeReturn;

  //! Buffered transitions.
  std::vector<TransitionType> pending;

  //! Current position in the buffer of pending transitions.
  size_t pendingIndex;

  //! Local copy of the network.
  NetworkType network;

  //! Current state of the agent.
  StateType state;

  //! Current action of the agent.
  ActionType action;
};

} // namespace rl
} // namespace mlpack

#endif
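// A minimal usage sketch (not part of the original header). The concrete
// types below are assumptions for illustration: a CartPole environment, an
// FFN value network, ens::VanillaUpdate as the optimizer, and GreedyPolicy
// as the behavior policy. In practice the async learning agent constructs
// and drives these workers internally; the flow looks roughly like this:
//
//   using Worker = mlpack::rl::OneStepSarsaWorker<
//       mlpack::rl::CartPole,
//       mlpack::ann::FFN<mlpack::ann::MeanSquaredError<>>,
//       ens::VanillaUpdate,
//       mlpack::rl::GreedyPolicy<mlpack::rl::CartPole>>;
//
//   mlpack::rl::TrainingConfig config;
//   Worker worker(ens::VanillaUpdate(), mlpack::rl::CartPole(), config,
//       /* deterministic */ false);
//   worker.Initialize(learningNetwork);
//
//   size_t totalSteps = 0;
//   double episodeReturn = 0;
//   // Step() returns true when the episode ends; episodeReturn then holds
//   // the total reward of the finished episode.
//   while (!worker.Step(learningNetwork, targetNetwork, totalSteps, policy,
//       episodeReturn)) { }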