q_learning.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_RL_Q_LEARNING_HPP
14 #define MLPACK_METHODS_RL_Q_LEARNING_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 #include <ensmallen.hpp>
18 
19 #include "replay/random_replay.hpp"
21 #include "training_config.hpp"
22 
23 namespace mlpack {
24 namespace rl {
25 
52 template <
53  typename EnvironmentType,
54  typename NetworkType,
55  typename UpdaterType,
56  typename PolicyType,
57  typename ReplayType = RandomReplay<EnvironmentType>
58 >
59 class QLearning
60 {
61  public:
63  using StateType = typename EnvironmentType::State;
64 
66  using ActionType = typename EnvironmentType::Action;
67 
81  QLearning(TrainingConfig& config,
82  NetworkType& network,
83  PolicyType& policy,
84  ReplayType& replayMethod,
85  UpdaterType updater = UpdaterType(),
86  EnvironmentType environment = EnvironmentType());
87 
91  ~QLearning();
92 
96  void TrainAgent();
97 
101  void TrainCategoricalAgent();
102 
106  void SelectAction();
107 
112  double Episode();
113 
115  size_t& TotalSteps() { return totalSteps; }
117  const size_t& TotalSteps() const { return totalSteps; }
118 
120  StateType& State() { return state; }
122  const StateType& State() const { return state; }
123 
125  const ActionType& Action() const { return action; }
126 
128  EnvironmentType& Environment() { return environment; }
130  const EnvironmentType& Environment() const { return environment; }
131 
133  bool& Deterministic() { return deterministic; }
135  const bool& Deterministic() const { return deterministic; }
136 
138  const NetworkType& Network() const { return learningNetwork; }
140  NetworkType& Network() { return learningNetwork; }
141 
142  private:
148  arma::Col<size_t> BestAction(const arma::mat& actionValues);
149 
151  TrainingConfig& config;
152 
154  NetworkType& learningNetwork;
155 
157  NetworkType targetNetwork;
158 
160  PolicyType& policy;
161 
163  ReplayType& replayMethod;
164 
166  UpdaterType updater;
167  #if ENS_VERSION_MAJOR >= 2
168  typename UpdaterType::template Policy<arma::mat, arma::mat>* updatePolicy;
169  #endif
170 
172  EnvironmentType environment;
173 
175  size_t totalSteps;
176 
178  StateType state;
179 
181  ActionType action;
182 
184  bool deterministic;
185 };
186 
187 } // namespace rl
188 } // namespace mlpack
189 
190 // Include implementation
191 #include "q_learning_impl.hpp"
192 #endif
NetworkType & Network()
Modify the learning network.
Definition: q_learning.hpp:140
Linear algebra utility functions, generally performed on matrices or vectors.
void SelectAction()
Select an action, given an agent.
typename EnvironmentType::Action ActionType
Convenient typedef for action.
Definition: q_learning.hpp:66
void TrainCategoricalAgent()
Trains the DQN agent of categorical type.
void TrainAgent()
Trains the DQN agent (non-categorical).
bool & Deterministic()
Modify the training mode / test mode indicator.
Definition: q_learning.hpp:133
The core includes that mlpack expects; standard C++ includes and Armadillo.
double Episode()
Execute an episode.
size_t & TotalSteps()
Modify total steps from beginning.
Definition: q_learning.hpp:115
EnvironmentType & Environment()
Modify the environment in which the agent is.
Definition: q_learning.hpp:128
typename EnvironmentType::State StateType
Convenient typedef for state.
Definition: q_learning.hpp:63
const EnvironmentType & Environment() const
Get the environment in which the agent is.
Definition: q_learning.hpp:130
const NetworkType & Network() const
Return the learning network.
Definition: q_learning.hpp:138
const bool & Deterministic() const
Get the indicator of training mode / test mode.
Definition: q_learning.hpp:135
const size_t & TotalSteps() const
Get total steps from beginning.
Definition: q_learning.hpp:117
Implementation of various Q-Learning algorithms, such as DQN, double DQN.
Definition: q_learning.hpp:59
StateType & State()
Modify the state of the agent.
Definition: q_learning.hpp:120
const ActionType & Action() const
Get the action of the agent.
Definition: q_learning.hpp:125
~QLearning()
Clean up allocated memory.
QLearning(TrainingConfig &config, NetworkType &network, PolicyType &policy, ReplayType &replayMethod, UpdaterType updater=UpdaterType(), EnvironmentType environment=EnvironmentType())
Create the QLearning object with given settings.
const StateType & State() const
Get the state of the agent.
Definition: q_learning.hpp:122