sac.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_RL_SAC_HPP
14 #define MLPACK_METHODS_RL_SAC_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 #include <ensmallen.hpp>
18 
19 #include "replay/random_replay.hpp"
23 #include "training_config.hpp"
24 
25 namespace mlpack {
26 namespace rl {
27 
57 template <
58  typename EnvironmentType,
59  typename QNetworkType,
60  typename PolicyNetworkType,
61  typename UpdaterType,
62  typename ReplayType = RandomReplay<EnvironmentType>
63 >
64 class SAC
65 {
66  public:
68  using StateType = typename EnvironmentType::State;
69 
71  using ActionType = typename EnvironmentType::Action;
72 
89  SAC(TrainingConfig& config,
90  QNetworkType& learningQ1Network,
91  PolicyNetworkType& policyNetwork,
92  ReplayType& replayMethod,
93  UpdaterType qNetworkUpdater = UpdaterType(),
94  UpdaterType policyNetworkUpdater = UpdaterType(),
95  EnvironmentType environment = EnvironmentType());
96 
100  ~SAC();
101 
108  void SoftUpdate(double rho);
109 
113  void Update();
114 
118  void SelectAction();
119 
124  double Episode();
125 
127  size_t& TotalSteps() { return totalSteps; }
129  const size_t& TotalSteps() const { return totalSteps; }
130 
132  StateType& State() { return state; }
134  const StateType& State() const { return state; }
135 
137  const ActionType& Action() const { return action; }
138 
140  bool& Deterministic() { return deterministic; }
142  const bool& Deterministic() const { return deterministic; }
143 
144 
145  private:
147  TrainingConfig& config;
148 
150  QNetworkType& learningQ1Network;
151  QNetworkType learningQ2Network;
152 
154  QNetworkType targetQ1Network;
155  QNetworkType targetQ2Network;
156 
158  PolicyNetworkType& policyNetwork;
159 
161  ReplayType& replayMethod;
162 
164  UpdaterType qNetworkUpdater;
165  #if ENS_VERSION_MAJOR >= 2
166  typename UpdaterType::template Policy<arma::mat, arma::mat>*
167  qNetworkUpdatePolicy;
168  #endif
169 
171  UpdaterType policyNetworkUpdater;
172  #if ENS_VERSION_MAJOR >= 2
173  typename UpdaterType::template Policy<arma::mat, arma::mat>*
174  policyNetworkUpdatePolicy;
175  #endif
176 
178  EnvironmentType environment;
179 
181  size_t totalSteps;
182 
184  StateType state;
185 
187  ActionType action;
188 
190  bool deterministic;
191 
193  mlpack::ann::MeanSquaredError<> lossFunction;
194 };
195 
196 } // namespace rl
197 } // namespace mlpack
198 
199 // Include implementation
200 #include "sac_impl.hpp"
201 #endif
~SAC()
Clean memory.
typename EnvironmentType::Action ActionType
Convenient typedef for action.
Definition: sac.hpp:71
void SelectAction()
Select an action, given an agent.
Linear algebra utility functions, generally performed on matrices or vectors.
Implementation of Soft Actor-Critic, a model-free off-policy actor-critic based deep reinforcement learning algorithm.
Definition: sac.hpp:64
The core includes that mlpack expects; standard C++ includes and Armadillo.
double Episode()
Execute an episode.
SAC(TrainingConfig &config, QNetworkType &learningQ1Network, PolicyNetworkType &policyNetwork, ReplayType &replayMethod, UpdaterType qNetworkUpdater=UpdaterType(), UpdaterType policyNetworkUpdater=UpdaterType(), EnvironmentType environment=EnvironmentType())
Create the SAC object with given settings.
void Update()
Update the Q and policy networks.
const StateType & State() const
Get the state of the agent.
Definition: sac.hpp:134
size_t & TotalSteps()
Modify the total number of steps taken since the beginning.
Definition: sac.hpp:127
void SoftUpdate(double rho)
Softly update the target Q network parameters towards the learning Q network parameters.
bool & Deterministic()
Modify the training mode / test mode indicator.
Definition: sac.hpp:140
const size_t & TotalSteps() const
Get the total number of steps taken since the beginning.
Definition: sac.hpp:129
The mean squared error performance function measures the network's performance according to the mean of squared errors.
const ActionType & Action() const
Get the action of the agent.
Definition: sac.hpp:137
StateType & State()
Modify the state of the agent.
Definition: sac.hpp:132
const bool & Deterministic() const
Get the indicator of training mode / test mode.
Definition: sac.hpp:142
typename EnvironmentType::State StateType
Convenient typedef for state.
Definition: sac.hpp:68