double_pole_cart.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_RL_ENVIRONMENT_DOUBLE_POLE_CART_HPP
14 #define MLPACK_METHODS_RL_ENVIRONMENT_DOUBLE_POLE_CART_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 namespace mlpack {
19 namespace rl {
20 
28 {
29  public:
35  class State
36  {
37  public:
41  State() : data(dimension)
42  { /* Nothing to do here. */ }
43 
49  State(const arma::colvec& data) : data(data)
50  { /* Nothing to do here */ }
51 
53  arma::colvec Data() const { return data; }
55  arma::colvec& Data() { return data; }
56 
58  double Position() const { return data[0]; }
60  double& Position() { return data[0]; }
61 
63  double Velocity() const { return data[1]; }
65  double& Velocity() { return data[1]; }
66 
68  double Angle(const size_t i) const { return data[2 * i]; }
70  double& Angle(const size_t i) { return data[2 * i]; }
71 
73  double AngularVelocity(const size_t i) const { return data[2 * i + 1]; }
75  double& AngularVelocity(const size_t i) { return data[2 * i + 1]; }
76 
78  const arma::colvec& Encode() const { return data; }
79 
81  static constexpr size_t dimension = 6;
82 
83  private:
85  arma::colvec data;
86  };
87 
//! Implementation of action of Double Pole Cart.
class Action
{
 public:
  //! The two discrete actions. Their numeric values matter: Dsdt() evaluates
  //! `action.action ? forceMag : -forceMag`, so `forward` (1) pushes the cart
  //! in the positive direction and `backward` (0) in the negative direction.
  enum actions
  {
    backward,
    forward
  };

  //! To store the action.
  Action::actions action;

  //! Track the size of the action space.
  static const size_t size = 2;
};
105 
123  DoublePoleCart(const size_t maxSteps = 0,
124  const double m1 = 0.1,
125  const double m2 = 0.01,
126  const double l1 = 0.5,
127  const double l2 = 0.05,
128  const double gravity = 9.8,
129  const double massCart = 1.0,
130  const double forceMag = 10.0,
131  const double tau = 0.02,
132  const double thetaThresholdRadians = 36 * 2 * 3.1416 / 360,
133  const double xThreshold = 2.4,
134  const double doneReward = 0.0) :
135  maxSteps(maxSteps),
136  m1(m1),
137  m2(m2),
138  l1(l1),
139  l2(l2),
140  gravity(gravity),
141  massCart(massCart),
142  forceMag(forceMag),
143  tau(tau),
144  thetaThresholdRadians(thetaThresholdRadians),
145  xThreshold(xThreshold),
146  doneReward(doneReward),
147  stepsPerformed(0)
148  { /* Nothing to do here */ }
149 
159  double Sample(const State& state,
160  const Action& action,
161  State& nextState)
162  {
163  // Update the number of steps performed.
164  stepsPerformed++;
165 
166  arma::vec dydx(6, arma::fill::zeros);
167  dydx[0] = state.Velocity();
168  dydx[2] = state.AngularVelocity(1);
169  dydx[4] = state.AngularVelocity(2);
170  Dsdt(state, action, dydx);
171  RK4(state, action, dydx, nextState);
172 
173  // Check if the episode has terminated.
174  bool done = IsTerminal(nextState);
175 
176  // Do not reward agent if it failed.
177  if (done && maxSteps != 0 && stepsPerformed >= maxSteps)
178  return doneReward;
179  else if (done)
180  return 0;
181 
186  return 1.0;
187  }
188 
197  void Dsdt(const State& state,
198  const Action& action,
199  arma::vec& dydx)
200  {
201  double totalForce = action.action ? forceMag : -forceMag;
202  double totalMass = massCart;
203  double omega1 = state.AngularVelocity(1);
204  double omega2 = state.AngularVelocity(2);
205  double sinTheta1 = std::sin(state.Angle(1));
206  double sinTheta2 = std::sin(state.Angle(2));
207  double cosTheta1 = std::cos(state.Angle(1));
208  double cosTheta2 = std::cos(state.Angle(2));
209 
210  // Calculate total effective force.
211  totalForce += m1 * l1 * omega1 * omega1 * sinTheta1 + 0.375 * m1 * gravity *
212  std::sin(2 * state.Angle(1));
213  totalForce += m2 * l2 * omega2 * omega2 * sinTheta1 + 0.375 * m2 * gravity *
214  std::sin(2 * state.Angle(2));
215 
216  // Calculate total effective mass.
217  totalMass += m1 * (0.25 + 0.75 * sinTheta1 * sinTheta1);
218  totalMass += m2 * (0.25 + 0.75 * sinTheta2 * sinTheta2);
219 
220  // Calculate acceleration.
221  double xAcc = totalForce / totalMass;
222  dydx[1] = xAcc;
223 
224  // Calculate angular acceleration.
225  dydx[3] = -0.75 * (xAcc * cosTheta1 + gravity * sinTheta1) / l1;
226  dydx[5] = -0.75 * (xAcc * cosTheta2 + gravity * sinTheta2) / l2;
227  }
228 
238  void RK4(const State& state,
239  const Action& action,
240  arma::vec& dydx,
241  State& nextState)
242  {
243  const double hh = tau * 0.5;
244  const double h6 = tau / 6;
245  arma::vec yt(6);
246  arma::vec dyt(6);
247  arma::vec dym(6);
248 
249  yt = state.Data() + (hh * dydx);
250  Dsdt(State(yt), action, dyt);
251  dyt[0] = yt[1];
252  dyt[2] = yt[3];
253  dyt[4] = yt[5];
254  yt = state.Data() + (hh * dyt);
255 
256  Dsdt(State(yt), action, dym);
257  dym[0] = yt[1];
258  dym[2] = yt[3];
259  dym[4] = yt[5];
260  yt = state.Data() + (tau * dym);
261  dym += dyt;
262 
263  Dsdt(State(yt), action, dyt);
264  dyt[0] = yt[1];
265  dyt[2] = yt[3];
266  dyt[4] = yt[5];
267  nextState.Data() = state.Data() + h6 * (dydx + dyt + 2 * dym);
268  }
269 
278  double Sample(const State& state, const Action& action)
279  {
280  State nextState;
281  return Sample(state, action, nextState);
282  }
283 
290  {
291  stepsPerformed = 0;
292  return State((arma::randu<arma::vec>(6) - 0.5) / 10.0);
293  }
294 
301  bool IsTerminal(const State& state) const
302  {
303  if (maxSteps != 0 && stepsPerformed >= maxSteps)
304  {
305  Log::Info << "Episode terminated due to the maximum number of steps"
306  "being taken.";
307  return true;
308  }
309  if (std::abs(state.Position()) > xThreshold)
310  {
311  Log::Info << "Episode terminated due to cart crossing threshold";
312  return true;
313  }
314  if (std::abs(state.Angle(1)) > thetaThresholdRadians ||
315  std::abs(state.Angle(2)) > thetaThresholdRadians)
316  {
317  Log::Info << "Episode terminated due to pole falling";
318  return true;
319  }
320  return false;
321  }
322 
324  size_t StepsPerformed() const { return stepsPerformed; }
325 
327  size_t MaxSteps() const { return maxSteps; }
329  size_t& MaxSteps() { return maxSteps; }
330 
private:
//! Maximum number of steps per episode; 0 disables the limit (every check
//! is guarded by `maxSteps != 0`).
size_t maxSteps;

//! Mass of the first pole.
double m1;

//! Mass of the second pole.
double m2;

//! Length parameter of the first pole, as used in the Dsdt() equations.
//! NOTE(review): presumably the half-length of the pole; confirm.
double l1;

//! Length parameter of the second pole, as used in the Dsdt() equations.
//! NOTE(review): presumably the half-length of the pole; confirm.
double l2;

//! Gravitational acceleration.
double gravity;

//! Mass of the cart.
double massCart;

//! Magnitude of the force applied to the cart by an action.
double forceMag;

//! Integration time step used by RK4().
double tau;

//! Maximum absolute pole angle (radians) before the episode terminates.
double thetaThresholdRadians;

//! Maximum absolute cart position before the episode terminates.
double xThreshold;

//! Reward returned when the episode ends because maxSteps was reached.
double doneReward;

//! Number of steps performed in the current episode; reset by
//! InitialSample() and incremented by Sample().
size_t stepsPerformed;
370 };
371 
372 } // namespace rl
373 } // namespace mlpack
374 
375 #endif
double Sample(const State &state, const Action &action)
Dynamics of Double Pole Cart.
State(const arma::colvec &data)
Construct a state instance from given data.
size_t MaxSteps() const
Get the maximum number of steps allowed.
double & Velocity()
Modify the velocity of the cart.
Linear algebra utility functions, generally performed on matrices or vectors.
DoublePoleCart(const size_t maxSteps=0, const double m1=0.1, const double m2=0.01, const double l1=0.5, const double l2=0.05, const double gravity=9.8, const double massCart=1.0, const double forceMag=10.0, const double tau=0.02, const double thetaThresholdRadians=36 *2 *3.1416/360, const double xThreshold=2.4, const double doneReward=0.0)
Construct a Double Pole Cart instance using the given constants.
Implementation of Double Pole Cart Balancing task.
double & Angle(const size_t i)
Modify the angle of the $i^{th}$ pole.
void Dsdt(const State &state, const Action &action, arma::vec &dydx)
These are the ordinary differential equations required to estimate the next state through the RK4 method...
The core includes that mlpack expects; standard C++ includes and Armadillo.
double Position() const
Get the position of the cart.
double Angle(const size_t i) const
Get the angle of the $i^{th}$ pole.
size_t StepsPerformed() const
Get the number of steps performed.
size_t & MaxSteps()
Set the maximum number of steps allowed.
State InitialSample()
Initial state representation is randomly generated within [-0.05, 0.05].
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
void RK4(const State &state, const Action &action, arma::vec &dydx, State &nextState)
This function uses the RK4 iterative method to estimate the next state based on the given ordinary differential equations.
double & AngularVelocity(const size_t i)
Modify the angular velocity of the $i^{th}$ pole.
double AngularVelocity(const size_t i) const
Get the angular velocity of the $i^{th}$ pole.
bool IsTerminal(const State &state) const
This function checks if the cart has reached the terminal state.
double Sample(const State &state, const Action &action, State &nextState)
Dynamics of Double Pole Cart instance.
arma::colvec & Data()
Modify the internal representation of the state.
const arma::colvec & Encode() const
Encode the state to a vector.
Implementation of action of Double Pole Cart.
Implementation of the state of Double Pole Cart.
State()
Construct a state instance.
double Velocity() const
Get the velocity of the cart.
static constexpr size_t dimension
Dimension of the encoded state.
double & Position()
Modify the position of the cart.
arma::colvec Data() const
Get the internal representation of the state.