#ifndef MLPACK_METHODS_RL_WORKER_ONE_STEP_Q_LEARNING_WORKER_HPP
#define MLPACK_METHODS_RL_WORKER_ONE_STEP_Q_LEARNING_WORKER_HPP

#include <ensmallen.hpp>
#include "../training_config.hpp"

namespace mlpack {
namespace rl {

//! One-step Q-learning worker for asynchronous reinforcement learning agents.
template <
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename PolicyType
>
class OneStepQLearningWorker
{
 public:
  using StateType = typename EnvironmentType::State;
  using ActionType = typename EnvironmentType::Action;
  using TransitionType = std::tuple<StateType, ActionType, double, StateType>;
  //! Construct a one-step Q-learning worker with the given parameters and
  //! environment.
  OneStepQLearningWorker(
      const UpdaterType& updater,
      const EnvironmentType& environment,
      const TrainingConfig& config,
      bool deterministic) :
      updater(updater),
      #if ENS_VERSION_MAJOR >= 2
      updatePolicy(NULL),
      #endif
      environment(environment),
      config(config),
      deterministic(deterministic),
      pending(config.UpdateInterval())
  { Reset(); }
  //! Copy another OneStepQLearningWorker.
  OneStepQLearningWorker(const OneStepQLearningWorker& other) :
      updater(other.updater),
      #if ENS_VERSION_MAJOR >= 2
      updatePolicy(NULL),
      #endif
      environment(other.environment),
      config(other.config),
      deterministic(other.deterministic),
      steps(other.steps),
      episodeReturn(other.episodeReturn),
      pending(other.pending),
      pendingIndex(other.pendingIndex),
      network(other.network),
      state(other.state)
  {
    #if ENS_VERSION_MAJOR >= 2
    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
                                     network.Parameters().n_rows,
                                     network.Parameters().n_cols);
    #endif
  }
  //! Take ownership of another OneStepQLearningWorker.
  OneStepQLearningWorker(OneStepQLearningWorker&& other) :
      updater(std::move(other.updater)),
      #if ENS_VERSION_MAJOR >= 2
      updatePolicy(NULL),
      #endif
      environment(std::move(other.environment)),
      config(std::move(other.config)),
      deterministic(std::move(other.deterministic)),
      steps(std::move(other.steps)),
      episodeReturn(std::move(other.episodeReturn)),
      pending(std::move(other.pending)),
      pendingIndex(std::move(other.pendingIndex)),
      network(std::move(other.network)),
      state(std::move(other.state))
  {
    #if ENS_VERSION_MAJOR >= 2
    other.updatePolicy = NULL;

    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
                                     network.Parameters().n_rows,
                                     network.Parameters().n_cols);
    #endif
  }
  //! Copy another OneStepQLearningWorker.
  OneStepQLearningWorker& operator=(const OneStepQLearningWorker& other)
  {
    if (&other == this)
      return *this;

    #if ENS_VERSION_MAJOR >= 2
    delete updatePolicy;
    #endif

    updater = other.updater;
    environment = other.environment;
    config = other.config;
    deterministic = other.deterministic;
    steps = other.steps;
    episodeReturn = other.episodeReturn;
    pending = other.pending;
    pendingIndex = other.pendingIndex;
    network = other.network;
    state = other.state;

    #if ENS_VERSION_MAJOR >= 2
    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
                                     network.Parameters().n_rows,
                                     network.Parameters().n_cols);
    #endif

    return *this;
  }
  //! Take ownership of another OneStepQLearningWorker.
  OneStepQLearningWorker& operator=(OneStepQLearningWorker&& other)
  {
    if (&other == this)
      return *this;

    #if ENS_VERSION_MAJOR >= 2
    delete updatePolicy;
    #endif

    updater = std::move(other.updater);
    environment = std::move(other.environment);
    config = std::move(other.config);
    deterministic = std::move(other.deterministic);
    steps = std::move(other.steps);
    episodeReturn = std::move(other.episodeReturn);
    pending = std::move(other.pending);
    pendingIndex = std::move(other.pendingIndex);
    network = std::move(other.network);
    state = std::move(other.state);

    #if ENS_VERSION_MAJOR >= 2
    other.updatePolicy = NULL;

    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
                                     network.Parameters().n_rows,
                                     network.Parameters().n_cols);
    #endif

    return *this;
  }
  //! Clean memory.
  ~OneStepQLearningWorker()
  {
    #if ENS_VERSION_MAJOR >= 2
    delete updatePolicy;
    #endif
  }

  //! Initialize the worker.
  void Initialize(NetworkType& learningNetwork)
  {
    #if ENS_VERSION_MAJOR == 1
    updater.Initialize(learningNetwork.Parameters().n_rows,
        learningNetwork.Parameters().n_cols);
    #else
    delete updatePolicy;
    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
                                     learningNetwork.Parameters().n_rows,
                                     learningNetwork.Parameters().n_cols);
    #endif

    // Build the local copy of the shared network.
    network = learningNetwork;
  }
  //! Execute one step of interaction; returns true if the episode ends after
  //! this step (totalReward then holds the episode return).
  bool Step(NetworkType& learningNetwork,
            NetworkType& targetNetwork,
            size_t& totalSteps,
            PolicyType& policy,
            double& totalReward)
  {
    // Interact with the environment.
    arma::colvec actionValue;
    network.Predict(state.Encode(), actionValue);
    ActionType action = policy.Sample(actionValue, deterministic);
    StateType nextState;
    double reward = environment.Sample(state, action, nextState);
    bool terminal = environment.IsTerminal(nextState);

    episodeReturn += reward;
    steps++;
    terminal = terminal || steps >= config.StepLimit();

    if (deterministic)
    {
      if (terminal)
      {
        totalReward = episodeReturn;
        Reset();
        // Sync with the latest shared network.
        network = learningNetwork;
        return true;
      }
      state = nextState;
      return false;
    }

    #pragma omp atomic
    totalSteps++;

    pending[pendingIndex] = std::make_tuple(state, action, reward, nextState);
    pendingIndex++;

    if (terminal || pendingIndex >= config.UpdateInterval())
    {
      // Accumulate gradients over the pending transitions.
      arma::mat totalGradients(learningNetwork.Parameters().n_rows,
          learningNetwork.Parameters().n_cols, arma::fill::zeros);
      for (size_t i = 0; i < pending.size(); ++i)
      {
        TransitionType& transition = pending[i];

        // One-step target: reward plus the discounted value of the best
        // action of the target network in the next state; the bootstrap term
        // is dropped for the terminal transition.
        arma::colvec actionValue;
        targetNetwork.Predict(
            std::get<3>(transition).Encode(), actionValue);
        double targetActionValue = actionValue.max();
        if (terminal && i == pending.size() - 1)
          targetActionValue = 0;
        targetActionValue = std::get<2>(transition) +
            config.Discount() * targetActionValue;

        // Build the training target for the visited state and action.
        arma::mat input = std::get<0>(transition).Encode();
        network.Forward(input, actionValue);
        actionValue[std::get<1>(transition).action] = targetActionValue;

        // Compute and accumulate the gradient.
        arma::mat gradients;
        network.Backward(input, actionValue, gradients);
        totalGradients += gradients;
      }

      // Clamp the accumulated gradients.
      totalGradients.transform(
          [&](double gradient)
          { return std::min(std::max(gradient, -config.GradientLimit()),
              config.GradientLimit()); });

      // Asynchronously update the shared learning network.
      #if ENS_VERSION_MAJOR == 1
      updater.Update(learningNetwork.Parameters(), config.StepSize(),
          totalGradients);
      #else
      updatePolicy->Update(learningNetwork.Parameters(), config.StepSize(),
          totalGradients);
      #endif

      // Sync the local network with the shared network.
      network = learningNetwork;
      pendingIndex = 0;
    }

    // Periodically sync the shared target network.
    if (totalSteps % config.TargetNetworkSyncInterval() == 0)
    {
      #pragma omp critical
      { targetNetwork = learningNetwork; }
    }

    if (terminal)
    {
      totalReward = episodeReturn;
      Reset();
      // Sync with the latest shared network.
      network = learningNetwork;
      return true;
    }
    state = nextState;
    return false;
  }
 private:
  //! Reset the worker for a new episode.
  void Reset()
  {
    steps = 0;
    episodeReturn = 0;
    pendingIndex = 0;
    state = environment.InitialSample();
  }
  //! Locally-stored optimizer and (for ensmallen >= 2) its update policy.
  UpdaterType updater;
  #if ENS_VERSION_MAJOR >= 2
  typename UpdaterType::template Policy<arma::mat, arma::mat>* updatePolicy;
  #endif

  EnvironmentType environment;          //!< Locally-stored task.
  TrainingConfig config;                //!< Hyper-parameters.
  bool deterministic;                   //!< Whether the worker is deterministic.
  size_t steps;                         //!< Steps in the current episode.
  double episodeReturn;                 //!< Total reward in the current episode.
  std::vector<TransitionType> pending;  //!< Buffer for the delayed update.
  size_t pendingIndex;                  //!< Current position in the buffer.
  NetworkType network;                  //!< Local copy of the shared network.
  StateType state;                      //!< Current state of the agent.
};

} // namespace rl
} // namespace mlpack

#endif
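The per-transition target computed inside Step() reduces to the standard one-step Q-learning bootstrap. The helper below is a minimal sketch, not part of the header; OneStepTarget is a hypothetical name introduced here for illustration, and it assumes the target network's outputs for the next state are already available in an Armadillo column vector.

#include <armadillo>

// Sketch of the target used in Step(): reward + discount * max_a Q_target(s', a),
// with the bootstrap term dropped on the terminal transition.
double OneStepTarget(const arma::colvec& nextActionValues,
                     const double reward,
                     const double discount,
                     const bool terminal)
{
  const double bootstrap = terminal ? 0.0 : nextActionValues.max();
  return reward + discount * bootstrap;
}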
Cross-references:

void Initialize(NetworkType& learningNetwork)
    Initialize the worker.
~OneStepQLearningWorker()
    Clean memory.
namespace arma
    Linear algebra utility functions, generally performed on matrices or vectors.
std::tuple<StateType, ActionType, double, StateType> TransitionType
    The type of a stored transition (state, action, reward, next state).
size_t StepLimit() const
    Get the maximum number of steps per episode.
size_t TargetNetworkSyncInterval() const
    Get the interval for syncing the target network.
OneStepQLearningWorker& operator=(OneStepQLearningWorker&& other)
    Take ownership of another OneStepQLearningWorker.
OneStepQLearningWorker(const OneStepQLearningWorker& other)
    Copy another OneStepQLearningWorker.
OneStepQLearningWorker(const UpdaterType& updater, const EnvironmentType& environment, const TrainingConfig& config, bool deterministic)
    Construct a one-step Q-learning worker with the given parameters and environment.
class OneStepQLearningWorker
    Forward declaration of OneStepQLearningWorker.
size_t UpdateInterval() const
    Get the update interval.
double Discount() const
    Get the discount rate for future reward.
typename EnvironmentType::Action ActionType
OneStepQLearningWorker(OneStepQLearningWorker&& other)
    Take ownership of another OneStepQLearningWorker.
bool Step(NetworkType& learningNetwork, NetworkType& targetNetwork, size_t& totalSteps, PolicyType& policy, double& totalReward)
    The agent will execute one step; a usage sketch follows this list.
OneStepQLearningWorker& operator=(const OneStepQLearningWorker& other)
    Copy another OneStepQLearningWorker.
double GradientLimit() const
    Get the limit on the magnitude of the update gradient.
double StepSize() const
    Get the step size of the optimizer.
typename EnvironmentType::State StateType
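The usage sketch referenced above shows the call sequence a driver might follow: construct the worker, initialize it against the shared network, then call Step() in a loop. MyEnvironment, MyNetwork, and MyPolicy are hypothetical placeholders standing in for a concrete environment, value network, and behavior policy; the actual asynchronous learner in mlpack wires these pieces together across threads, so treat this only as an illustration of the API, not as the library's own driver.

// Hypothetical single-threaded driver; placeholder types are assumptions.
ens::VanillaUpdate updater;
MyEnvironment environment;
TrainingConfig config;

OneStepQLearningWorker<MyEnvironment, MyNetwork, ens::VanillaUpdate, MyPolicy>
    worker(updater, environment, config, /* deterministic */ false);

MyNetwork learningNetwork, targetNetwork;
MyPolicy policy;
size_t totalSteps = 0;
double episodeReturn = 0.0;

worker.Initialize(learningNetwork);
while (totalSteps < 10000)
{
  // Step() returns true when the current episode has ended; episodeReturn is
  // only meaningful on that step.
  if (worker.Step(learningNetwork, targetNetwork, totalSteps, policy,
      episodeReturn))
  {
    // Episode finished; episodeReturn holds its total reward.
  }
}

Passing deterministic = true makes the worker act greedily for evaluation only: it skips the pending-transition buffer and never updates the shared networks.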