async_learning.hpp
Go to the documentation of this file.
1 
14 #ifndef MLPACK_METHODS_RL_ASYNC_LEARNING_HPP
15 #define MLPACK_METHODS_RL_ASYNC_LEARNING_HPP
16 
17 #include <mlpack/prereqs.hpp>
21 #include "training_config.hpp"
22 
23 namespace mlpack {
24 namespace rl {
25 
50 template <
51  typename WorkerType,
52  typename EnvironmentType,
53  typename NetworkType,
54  typename UpdaterType,
55  typename PolicyType
56 >
58 {
59  public:
70  NetworkType network,
71  PolicyType policy,
72  UpdaterType updater = UpdaterType(),
73  EnvironmentType environment = EnvironmentType());
74 
88  template <typename Measure>
89  void Train(Measure& measure);
90 
92  TrainingConfig& Config() { return config; }
94  const TrainingConfig& Config() const { return config; }
95 
97  NetworkType& Network() { return learningNetwork; }
99  const NetworkType& Network() const { return learningNetwork; }
100 
102  PolicyType& Policy() { return policy; }
104  const PolicyType& Policy() const { return policy; }
105 
107  UpdaterType& Updater() { return updater; }
109  const UpdaterType& Updater() const { return updater; }
110 
112  EnvironmentType& Environment() { return environment; }
114  const EnvironmentType& Environment() const { return environment; }
115 
116  private:
118  TrainingConfig config;
119 
121  NetworkType learningNetwork;
122 
124  PolicyType policy;
125 
127  UpdaterType updater;
128 
130  EnvironmentType environment;
131 };
132 
/**
 * Forward declaration of OneStepQLearningWorker, so that AsyncLearning
 * aliases in this header can name the worker type before its definition is
 * visible.
 *
 * @tparam EnvironmentType Type of the environment.
 * @tparam NetworkType Type of the learning network.
 * @tparam UpdaterType Type of the optimizer.
 * @tparam PolicyType Type of the behavior policy.
 */
template <
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename PolicyType
>
class OneStepQLearningWorker;
/**
 * Forward declaration of OneStepSarsaWorker, so that AsyncLearning aliases in
 * this header can name the worker type before its definition is visible.
 *
 * @tparam EnvironmentType Type of the environment.
 * @tparam NetworkType Type of the learning network.
 * @tparam UpdaterType Type of the optimizer.
 * @tparam PolicyType Type of the behavior policy.
 */
template <
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename PolicyType
>
class OneStepSarsaWorker;
/**
 * Forward declaration of NStepQLearningWorker, so that AsyncLearning aliases
 * in this header can name the worker type before its definition is visible.
 *
 * @tparam EnvironmentType Type of the environment.
 * @tparam NetworkType Type of the learning network.
 * @tparam UpdaterType Type of the optimizer.
 * @tparam PolicyType Type of the behavior policy.
 */
template <
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename PolicyType
>
class NStepQLearningWorker;
189 template <
190  typename EnvironmentType,
191  typename NetworkType,
192  typename UpdaterType,
193  typename PolicyType
194 >
196  NetworkType, UpdaterType, PolicyType>, EnvironmentType, NetworkType,
197  UpdaterType, PolicyType>;
198 
207 template <
208  typename EnvironmentType,
209  typename NetworkType,
210  typename UpdaterType,
211  typename PolicyType
212 >
213 using OneStepSarsa = AsyncLearning<OneStepSarsaWorker<EnvironmentType,
214  NetworkType, UpdaterType, PolicyType>, EnvironmentType, NetworkType,
215  UpdaterType, PolicyType>;
216 
225 template <
226  typename EnvironmentType,
227  typename NetworkType,
228  typename UpdaterType,
229  typename PolicyType
230 >
231 using NStepQLearning = AsyncLearning<NStepQLearningWorker<EnvironmentType,
232  NetworkType, UpdaterType, PolicyType>, EnvironmentType, NetworkType,
233  UpdaterType, PolicyType>;
234 
235 } // namespace rl
236 } // namespace mlpack
237 
238 // Include implementation
239 #include "async_learning_impl.hpp"
240 
241 #endif
const NetworkType & Network() const
Get learning network.
Linear algebra utility functions, generally performed on matrices or vectors.
EnvironmentType & Environment()
Modify the environment.
PolicyType & Policy()
Modify behavior policy.
The core includes that mlpack expects; standard C++ includes and Armadillo.
const PolicyType & Policy() const
Get behavior policy.
UpdaterType & Updater()
Modify optimizer.
Forward declaration of OneStepQLearningWorker.
const EnvironmentType & Environment() const
Get the environment.
AsyncLearning(TrainingConfig config, NetworkType network, PolicyType policy, UpdaterType updater=UpdaterType(), EnvironmentType environment=EnvironmentType())
Construct an instance of the given async learning algorithm.
Forward declaration of NStepQLearningWorker.
TrainingConfig & Config()
Modify training config.
Wrapper of various asynchronous learning algorithms, e.g. one-step Q-learning, one-step Sarsa and n-step Q-learning.
const UpdaterType & Updater() const
Get optimizer.
Forward declaration of OneStepSarsaWorker.
void Train(Measure &measure)
Start asynchronous training.
NetworkType & Network()
Modify learning network.
const TrainingConfig & Config() const
Get training config.